Compare commits

...

79 Commits

Author SHA1 Message Date
公明 52d03dc849 Update config.yaml 2026-05-26 14:45:38 +08:00
公明 9de72d9ad5 Update config.yaml 2026-05-26 14:42:20 +08:00
公明 d95275ffae Add files via upload 2026-05-26 14:37:24 +08:00
公明 6cef93dbb7 Add files via upload 2026-05-26 14:36:40 +08:00
公明 dd3b1ae219 Add files via upload 2026-05-26 14:34:21 +08:00
公明 f42209682a Add files via upload 2026-05-26 14:31:59 +08:00
公明 1b1aed1699 Add files via upload 2026-05-26 14:27:44 +08:00
公明 44ced98863 Add files via upload 2026-05-26 14:24:32 +08:00
公明 97834c162e Update config.yaml 2026-05-23 19:53:40 +08:00
公明 9276f2f144 Add files via upload 2026-05-23 19:49:50 +08:00
公明 a454cada6a Add files via upload 2026-05-23 19:39:03 +08:00
公明 99b53d4fbc Add files via upload 2026-05-23 19:35:30 +08:00
公明 a43a9deaea Add files via upload 2026-05-23 19:33:23 +08:00
公明 ce88da84c9 Add files via upload 2026-05-23 19:31:40 +08:00
公明 15855c7073 Add files via upload 2026-05-23 19:29:49 +08:00
公明 43eb3e546b Add files via upload 2026-05-22 17:23:01 +08:00
公明 2d52c9b6ac Update config.yaml 2026-05-22 17:18:48 +08:00
公明 d5401b8b4c Update config.yaml 2026-05-22 17:17:48 +08:00
公明 5fd4393a2e Add files via upload 2026-05-22 17:14:33 +08:00
公明 a049f6b5c2 Add files via upload 2026-05-22 17:13:55 +08:00
公明 acba8e5a39 Add files via upload 2026-05-22 17:11:34 +08:00
公明 f826b91362 Add files via upload 2026-05-22 17:09:54 +08:00
公明 98c2de2a60 Add files via upload 2026-05-22 17:08:05 +08:00
公明 1c4d4b305b Update config.yaml 2026-05-22 15:15:46 +08:00
公明 f210ac9a03 Add files via upload 2026-05-22 11:36:36 +08:00
公明 6685076dfb Add files via upload 2026-05-22 11:35:02 +08:00
公明 7f322653f6 Add files via upload 2026-05-22 11:32:36 +08:00
公明 66ac2f1357 Add files via upload 2026-05-22 11:30:25 +08:00
公明 c446e22d0c Add files via upload 2026-05-22 11:28:51 +08:00
公明 0358d3a67d Add files via upload 2026-05-22 10:30:19 +08:00
公明 9b82f265fd Add files via upload 2026-05-20 18:24:17 +08:00
公明 3d9cae58e4 Update config.yaml 2026-05-20 17:59:57 +08:00
公明 1f1eadee5e Update config.yaml 2026-05-20 17:58:24 +08:00
公明 0569255189 Add files via upload 2026-05-20 17:54:30 +08:00
公明 8ccf90d067 Add files via upload 2026-05-20 17:52:22 +08:00
公明 b3be89f47d Add files via upload 2026-05-20 17:50:52 +08:00
公明 b9bf8f62d4 Add files via upload 2026-05-20 17:48:42 +08:00
公明 05ca0c1480 Update config.yaml 2026-05-20 16:57:50 +08:00
公明 47a4f3fc5b Add files via upload 2026-05-20 16:52:50 +08:00
公明 a3b378ae9e Add files via upload 2026-05-20 16:49:26 +08:00
公明 a904d26e78 Add files via upload 2026-05-20 16:47:34 +08:00
公明 7ba7476c4f Add files via upload 2026-05-20 16:45:59 +08:00
公明 ae25a243ac Add files via upload 2026-05-20 16:43:38 +08:00
公明 23bd6288ff Add files via upload 2026-05-20 16:39:13 +08:00
公明 fef21d3a24 Add files via upload 2026-05-20 16:36:50 +08:00
公明 933bba4517 Update config.yaml 2026-05-20 16:12:13 +08:00
公明 e1d65437cc Add files via upload 2026-05-20 16:11:10 +08:00
公明 9325aed1eb Add files via upload 2026-05-20 16:09:33 +08:00
公明 dee2b3ab42 Add files via upload 2026-05-20 16:07:33 +08:00
公明 a69bc93fa1 Add files via upload 2026-05-20 16:05:40 +08:00
公明 b1a620bfce Update config.yaml 2026-05-20 14:18:33 +08:00
公明 61b164eec2 Add files via upload 2026-05-20 11:03:38 +08:00
公明 ba77e1837e Update config.yaml 2026-05-19 23:05:52 +08:00
公明 eacad60fd6 Add files via upload 2026-05-19 23:03:04 +08:00
公明 70bf5c93bf Update config.yaml 2026-05-19 19:01:31 +08:00
公明 08bd278d8c Update config.yaml 2026-05-19 18:56:24 +08:00
公明 22746d64a3 Add files via upload 2026-05-19 18:53:46 +08:00
公明 199392a5d5 Add files via upload 2026-05-19 18:52:22 +08:00
公明 aafb4cb584 Add files via upload 2026-05-19 18:50:28 +08:00
公明 96e3dd397c Add files via upload 2026-05-19 18:48:17 +08:00
公明 ec0f17145b Add files via upload 2026-05-19 17:50:38 +08:00
公明 ed53da0999 Delete security directory 2026-05-19 17:49:21 +08:00
公明 dc440fc511 Delete robot directory 2026-05-19 17:49:10 +08:00
公明 009ae59033 Delete logger directory 2026-05-19 17:48:59 +08:00
公明 f348b3245a Delete knowledge directory 2026-05-19 17:48:44 +08:00
公明 0018c5219c Delete config directory 2026-05-19 17:48:33 +08:00
公明 01a3e3677a Delete c2 directory 2026-05-19 17:48:22 +08:00
公明 a12ecdb46f Add files via upload 2026-05-19 17:47:56 +08:00
公明 9f59230d74 Add files via upload 2026-05-19 17:46:33 +08:00
公明 085c6a1c72 Add files via upload 2026-05-19 17:43:45 +08:00
公明 7b3860971f Add files via upload 2026-05-19 17:42:12 +08:00
公明 f6f7b7b237 Add files via upload 2026-05-19 17:40:19 +08:00
公明 d5cf4b3b16 Add files via upload 2026-05-19 16:48:07 +08:00
公明 3e58d8355b Add files via upload 2026-05-19 16:32:38 +08:00
公明 eb01ade63b Add files via upload 2026-05-19 16:29:05 +08:00
公明 d1dc15fa44 Add files via upload 2026-05-19 16:27:29 +08:00
公明 73a39ef868 Add files via upload 2026-05-19 16:25:47 +08:00
公明 a022baef03 Add files via upload 2026-05-19 16:23:21 +08:00
公明 59312d428e Add files via upload 2026-05-19 14:53:07 +08:00
103 changed files with 11914 additions and 994 deletions
+3 -2
View File
@@ -113,6 +113,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
- 🔒 Password-protected web UI, audit logs, and SQLite persistence - 🔒 Password-protected web UI, audit logs, and SQLite persistence
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks - 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
- 📁 Conversation grouping with pinning, rename, and batch management - 📁 Conversation grouping with pinning, rename, and batch management
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …)
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics - 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially - 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions - 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
@@ -285,7 +286,7 @@ Requirements / tips:
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent. - **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only. - **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
- **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`. - **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
- **Config** `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning). - **Config** `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
- **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats). - **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
### Skills System (Agent Skills + Eino) ### Skills System (Agent Skills + Eino)
@@ -536,7 +537,7 @@ agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-age
multi_agent: multi_agent:
enabled: false enabled: false
default_mode: "single" # single | multi (UI default when multi-agent is enabled) default_mode: "single" # single | multi (UI default when multi-agent is enabled)
robot_use_multi_agent: false robot_default_agent_mode: react
batch_use_multi_agent: false batch_use_multi_agent: false
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
+3 -2
View File
@@ -112,6 +112,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
- 🔒 Web 登录保护、审计日志、SQLite 持久化 - 🔒 Web 登录保护、审计日志、SQLite 持久化
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项) - 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作 - 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
- 📂 **项目管理**:按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话间沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact``get_project_fact` 等)
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板 - 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪 - 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制 - 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
@@ -283,7 +284,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。 - **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
- **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。 - **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
- **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。 - **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。
- **配置项**`multi_agent``enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。 - **配置项**`multi_agent``enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
- **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。 - **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
### Skills 技能系统(Agent Skills + Eino ### Skills 技能系统(Agent Skills + Eino
@@ -534,7 +535,7 @@ agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代
multi_agent: multi_agent:
enabled: false enabled: false
default_mode: "single" # single | multi(开启多代理时的界面默认模式) default_mode: "single" # single | multi(开启多代理时的界面默认模式)
robot_use_multi_agent: false robot_default_agent_mode: react
batch_use_multi_agent: false batch_use_multi_agent: false
orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用 orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选 # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
+5 -1
View File
@@ -127,7 +127,11 @@ description: 多代理模式下的 Deep 编排者:在已授权安全场景中
## 工具与 MCP ## 工具与 MCP
- **工具调用失败时**:1) 仔细分析错误信息,理解失败的具体原因;2) 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标;3) 如果参数错误,根据错误提示修正参数后重试;4) 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析;5) 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作;6) 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务。工具返回的错误信息会包含在工具响应中,请仔细阅读并做出合理决策。 - **工具调用失败时**:1) 仔细分析错误信息,理解失败的具体原因;2) 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标;3) 如果参数错误,根据错误提示修正参数后重试;4) 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析;5) 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作;6) 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务。工具返回的错误信息会包含在工具响应中,请仔细阅读并做出合理决策。
- **漏洞记录**:发现**有效漏洞**时,必须使用 **`record_vulnerability`** 记录(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度使用 critical / high / medium / low / info。记录后可在授权范围内继续测试。 - **项目黑板(事实)与漏洞记录(分离)**:当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**
- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新。
- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。
- 同一发现可能需**各记一次**(事实记上下文,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。
- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。
- **编排进度(待办)**:当你的任务包含 3 个或以上步骤,或你准备委派多个子目标并行/串行推进时,优先使用 `write_todos` 来向用户展示“当前在做什么/接下来做什么”。维护约束:同一时刻最多一个条目处于 `in_progress`;完成后立刻标记 `completed`;遇到阻塞就保留为 `in_progress` 并继续推进。 - **编排进度(待办)**:当你的任务包含 3 个或以上步骤,或你准备委派多个子目标并行/串行推进时,优先使用 `write_todos` 来向用户展示“当前在做什么/接下来做什么”。维护约束:同一时刻最多一个条目处于 `in_progress`;完成后立刻标记 `completed`;遇到阻塞就保留为 `in_progress` 并继续推进。
- **强触发建议(提升多 agent 使用率)**:如果你将要进行任何“证据收集/枚举/扫描/验证/复现/整理报告”这类实质执行动作,且不只是单步查询,请优先在第一个工具调用前就用 `write_todos` 建立计划;随后用 `task` 委派至少一个子代理获取结构化证据,而不是自己把全部步骤做完。 - **强触发建议(提升多 agent 使用率)**:如果你将要进行任何“证据收集/枚举/扫描/验证/复现/整理报告”这类实质执行动作,且不只是单步查询,请优先在第一个工具调用前就用 `write_todos` 建立计划;随后用 `task` 委派至少一个子代理获取结构化证据,而不是自己把全部步骤做完。
- **技能库(Skills)与知识库**:技能包位于服务器 `skills/` 目录(各子目录 `SKILL.md`,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。多代理本会话通过内置 **`skill`** 工具渐进加载;子代理同样挂载 skill + 可选本机文件工具时,可在委派说明中提示按需加载。若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话。 - **技能库(Skills)与知识库**:技能包位于服务器 `skills/` 目录(各子目录 `SKILL.md`,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。多代理本会话通过内置 **`skill`** 工具渐进加载;子代理同样挂载 skill + 可选本机文件工具时,可在委派说明中提示按需加载。若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话。
+36 -5
View File
@@ -10,7 +10,7 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.6.16" version: "v1.6.23"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -34,6 +34,12 @@ auth:
log: log:
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误) level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径 output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
# 平台操作审计(系统设置 -> 日志审计;不记录对话正文与每次工具调用)
audit:
enabled: true
retention_days: 15 # 0 表示不自动清理
max_detail_bytes: 8192
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
# ============================================ # ============================================
# 对话相关配置 # 对话相关配置
# ============================================ # ============================================
@@ -55,7 +61,7 @@ openai:
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinkingextended thinking),mode: off 关闭 # Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinkingextended thinking),mode: off 关闭
reasoning: reasoning:
mode: on # auto | on | offoff 时不附加任何推理扩展字段 mode: on # auto | on | offoff 时不附加任何推理扩展字段
effort: max # low | medium | high | max;空表示不指定openai_compat 下 auto 且无强度时不发请求扩展) effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准 allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级) # extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
@@ -76,16 +82,18 @@ agent:
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下 result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起) tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示 # system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
system_prompt_path: ""
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。 # 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
hitl: hitl:
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 [] # 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
tool_whitelist: [read_file, list_dir, glob, grep] tool_whitelist: [read_file, list_dir, glob, grep]
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存) # 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct # 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/streamDeep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep # 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/streamDeep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人按 robot_default_agent_mode
multi_agent: multi_agent:
enabled: true enabled: true
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高) robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认对话模式:react | eino_single | deep | plan_execute | supervisor
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高) batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。 # plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
@@ -108,7 +116,7 @@ multi_agent:
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文 tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用 tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁 tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略) tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过 plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放 plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载 reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
@@ -125,6 +133,8 @@ multi_agent:
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断) plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题 plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接 checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
run_retry_max_attempts: 0 # >0429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入 deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试 deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑 task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
@@ -235,6 +245,14 @@ knowledge:
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用 # 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
# 在系统设置 -> 机器人设置 中可配置 # 在系统设置 -> 机器人设置 中可配置
robots: robots:
wechat: # 微信 iLink(个人微信 ClawBot,扫码绑定)
enabled: false
bot_token: ""
ilink_bot_id: ""
ilink_user_id: ""
base_url: https://ilinkai.weixin.qq.com
bot_type: "3"
bot_agent: CyberStrikeAI/1.0
wecom: # 企业微信 wecom: # 企业微信
enabled: false enabled: false
token: "" token: ""
@@ -246,11 +264,13 @@ robots:
enabled: false enabled: false
client_id: "" client_id: ""
client_secret: "" client_secret: ""
allow_conversation_id_fallback: false
lark: # 飞书 lark: # 飞书
enabled: false enabled: false
app_id: "" app_id: ""
app_secret: "" app_secret: ""
verify_token: "" verify_token: ""
allow_chat_id_fallback: false
# ============================================ # ============================================
# Skills 相关配置 # Skills 相关配置
# ============================================ # ============================================
@@ -272,3 +292,14 @@ agents_dir: agents
# 系统会从该目录加载所有 .yaml 格式的角色配置文件 # 系统会从该目录加载所有 .yaml 格式的角色配置文件
# 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等 # 每个角色应创建独立的配置文件,例如:roles/CTF.yaml, roles/默认.yaml 等
roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录) roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录)
# ============================================
# 项目管理与事实黑板
# ============================================
project:
enabled: true
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
fact_index_max_runes: 3500
fact_summary_max_runes: 120
default_inject_deprecated: false
+1
View File
@@ -27,6 +27,7 @@ require (
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/pkoukk/tiktoken-go v0.1.8 github.com/pkoukk/tiktoken-go v0.1.8
github.com/robfig/cron/v3 v3.0.1 github.com/robfig/cron/v3 v3.0.1
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
go.opentelemetry.io/otel v1.34.0 go.opentelemetry.io/otel v1.34.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
+2
View File
@@ -163,6 +163,8 @@ github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtIS
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY= github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
+37 -11
View File
@@ -17,6 +17,7 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/project"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage" "cyberstrike-ai/internal/storage"
@@ -365,12 +366,12 @@ type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环 // AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) { func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil) return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil, "")
} }
// AgentLoopWithConversationID 执行Agent循环(带对话ID // AgentLoopWithConversationID 执行Agent循环(带对话ID
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) { func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil) return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil, "")
} }
// EinoSingleAgentSystemInstruction 供 Eino adk.ChatModelAgent.Instruction 使用,与 AgentLoopWithProgress 首条 system 对齐(含 system_prompt_path)。 // EinoSingleAgentSystemInstruction 供 Eino adk.ChatModelAgent.Instruction 使用,与 AgentLoopWithProgress 首条 system 对齐(含 system_prompt_path)。
@@ -396,7 +397,7 @@ func (a *Agent) EinoSingleAgentSystemInstruction() string {
} }
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID) // AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) { func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string, systemPromptExtra string) (*AgentLoopResult, error) {
ctx = withAgentConversationID(ctx, conversationID) ctx = withAgentConversationID(ctx, conversationID)
// 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准) // 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准)
a.mu.Lock() a.mu.Lock()
@@ -426,6 +427,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
} }
} }
} }
systemPrompt = project.AppendSystemPromptBlock(systemPrompt, systemPromptExtra)
messages := []ChatMessage{ messages := []ChatMessage{
{ {
@@ -598,11 +600,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
thinkingStreamSeq++ thinkingStreamSeq++
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq) thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
thinkingStreamStarted := false thinkingStreamStarted := false
var thinkingWire string
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error { response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
if delta == "" { if delta == "" {
return nil return nil
} }
var deltaOut string
thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta)
if deltaOut == "" {
return nil
}
if !thinkingStreamStarted { if !thinkingStreamStarted {
thinkingStreamStarted = true thinkingStreamStarted = true
sendProgress("thinking_stream_start", " ", map[string]interface{}{ sendProgress("thinking_stream_start", " ", map[string]interface{}{
@@ -611,10 +619,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"toolStream": false, "toolStream": false,
}) })
} }
sendProgress("thinking_stream_delta", delta, map[string]interface{}{ sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"streamId": thinkingStreamId, "streamId": thinkingStreamId,
"iteration": i + 1, "iteration": i + 1,
}) }, thinkingWire))
return nil return nil
}) })
if err != nil { if err != nil {
@@ -827,10 +835,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "summary", "messageGeneratedBy": "summary",
}) })
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{ var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
}) }, summaryWire))
return nil return nil
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
@@ -874,10 +888,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "summary", "messageGeneratedBy": "summary",
}) })
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{ var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
}) }, summaryWire))
return nil return nil
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
@@ -921,10 +941,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "max_iter_summary", "messageGeneratedBy": "max_iter_summary",
}) })
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{ var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
}) }, summaryWire))
return nil return nil
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
+167
View File
@@ -0,0 +1,167 @@
package agent
import (
"encoding/json"
"strings"
)
// ParseTraceMessages 解析落库的 last_react_inputOpenAI 风格 messages JSON 数组)。
func ParseTraceMessages(traceInputJSON string) ([]ChatMessage, error) {
traceInputJSON = strings.TrimSpace(traceInputJSON)
if traceInputJSON == "" {
return nil, nil
}
var raw []map[string]interface{}
if err := json.Unmarshal([]byte(traceInputJSON), &raw); err != nil {
return nil, err
}
out := make([]ChatMessage, 0, len(raw))
for _, msgMap := range raw {
msg := ChatMessage{}
role, _ := msgMap["role"].(string)
if role == "" {
continue
}
msg.Role = role
if content, ok := msgMap["content"].(string); ok {
msg.Content = content
}
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
msg.ReasoningContent = rc
}
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
for _, tcRaw := range toolCallsArray {
tcMap, ok := tcRaw.(map[string]interface{})
if !ok {
continue
}
toolCall := ToolCall{}
if id, ok := tcMap["id"].(string); ok {
toolCall.ID = id
}
if toolType, ok := tcMap["type"].(string); ok {
toolCall.Type = toolType
}
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
toolCall.Function = FunctionCall{}
if name, ok := funcMap["name"].(string); ok {
toolCall.Function.Name = name
}
if argsRaw, ok := funcMap["arguments"]; ok {
if argsStr, ok := argsRaw.(string); ok {
var argsMap map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
toolCall.Function.Arguments = argsMap
}
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
toolCall.Function.Arguments = argsMap
}
}
}
if toolCall.ID != "" {
msg.ToolCalls = append(msg.ToolCalls, toolCall)
}
}
}
}
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
msg.ToolCallID = toolCallID
}
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
msg.ToolName = strings.TrimSpace(tn)
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
msg.ToolName = strings.TrimSpace(tn)
}
out = append(out, msg)
}
return out, nil
}
// ExtractLastUserTurnMessages 仅保留最后一次 user 提问起的消息(不含更早的用户轮次;跳过 system)。
// 与「继续对话」续跑所用轨迹范围一致:当前任务轮次,而非整段多轮对话历史。
func ExtractLastUserTurnMessages(msgs []ChatMessage) []ChatMessage {
if len(msgs) == 0 {
return msgs
}
lastUser := -1
for i, m := range msgs {
if strings.EqualFold(m.Role, "user") {
lastUser = i
}
}
if lastUser < 0 {
return msgs
}
trimmed := msgs[lastUser:]
out := make([]ChatMessage, 0, len(trimmed))
for _, m := range trimmed {
if strings.EqualFold(m.Role, "system") {
continue
}
out = append(out, m)
}
return out
}
// ExtractLastUserTurnTraceJSON 在 JSON 轨迹上裁剪为最后一次 user 起的片段(供落库格式直接处理)。
func ExtractLastUserTurnTraceJSON(traceInputJSON string) string {
traceInputJSON = strings.TrimSpace(traceInputJSON)
if traceInputJSON == "" {
return traceInputJSON
}
var arr []map[string]interface{}
if err := json.Unmarshal([]byte(traceInputJSON), &arr); err != nil {
return traceInputJSON
}
lastUser := -1
for i, m := range arr {
if r, _ := m["role"].(string); strings.EqualFold(r, "user") {
lastUser = i
}
}
if lastUser <= 0 {
return traceInputJSON
}
trimmed := arr[lastUser:]
b, err := json.Marshal(trimmed)
if err != nil {
return traceInputJSON
}
return string(b)
}
// MergeAssistantTraceOutput 将 last_react_output 合并进轨迹最后一条 assistant(与 loadHistoryFromAgentTrace 一致)。
func MergeAssistantTraceOutput(msgs []ChatMessage, assistantOut string) []ChatMessage {
assistantOut = strings.TrimSpace(assistantOut)
if assistantOut == "" || len(msgs) == 0 {
return msgs
}
out := append([]ChatMessage(nil), msgs...)
last := &out[len(out)-1]
if strings.EqualFold(last.Role, "assistant") && len(last.ToolCalls) == 0 {
last.Content = assistantOut
return out
}
out = append(out, ChatMessage{
Role: "assistant",
Content: assistantOut,
})
return out
}
// MessagesToTraceJSON 将消息带序列化为 JSON(跳过 system)。
func MessagesToTraceJSON(msgs []ChatMessage) (string, error) {
filtered := make([]ChatMessage, 0, len(msgs))
for _, m := range msgs {
if strings.EqualFold(m.Role, "system") {
continue
}
filtered = append(filtered, m)
}
b, err := json.Marshal(filtered)
if err != nil {
return "", err
}
return string(b), nil
}
+57
View File
@@ -0,0 +1,57 @@
package agent
import (
"encoding/json"
"testing"
)
func TestExtractLastUserTurnTraceJSON(t *testing.T) {
raw := []map[string]interface{}{
{"role": "user", "content": "old question"},
{"role": "assistant", "content": "old answer"},
{"role": "user", "content": "new target 1.1.1.1"},
{"role": "assistant", "tool_calls": []interface{}{map[string]interface{}{
"id": "c1", "type": "function",
"function": map[string]interface{}{"name": "nmap", "arguments": "{}"},
}}},
{"role": "tool", "tool_call_id": "c1", "content": "open ports"},
}
b, _ := json.Marshal(raw)
out := ExtractLastUserTurnTraceJSON(string(b))
var trimmed []map[string]interface{}
if err := json.Unmarshal([]byte(out), &trimmed); err != nil {
t.Fatal(err)
}
if len(trimmed) != 3 {
t.Fatalf("expected 3 messages, got %d", len(trimmed))
}
if trimmed[0]["content"] != "new target 1.1.1.1" {
t.Fatalf("unexpected first message: %v", trimmed[0])
}
}
func TestExtractLastUserTurnMessagesSkipsSystem(t *testing.T) {
msgs := []ChatMessage{
{Role: "system", Content: "sys"},
{Role: "user", Content: "q"},
{Role: "assistant", Content: "a"},
}
out := ExtractLastUserTurnMessages(msgs)
if len(out) != 2 {
t.Fatalf("expected 2, got %d", len(out))
}
if out[0].Role != "user" {
t.Fatal("expected user first")
}
}
func TestMergeAssistantTraceOutput(t *testing.T) {
msgs := []ChatMessage{
{Role: "user", Content: "q"},
{Role: "assistant", Content: "draft"},
}
out := MergeAssistantTraceOutput(msgs, "final summary")
if out[len(out)-1].Content != "final summary" {
t.Fatalf("expected merged output, got %q", out[len(out)-1].Content)
}
}
@@ -105,11 +105,15 @@ func DefaultSingleAgentSystemPrompt() string {
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。 - 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。 - 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
## 漏洞记录 ## 项目黑板(事实)与漏洞记录(分离)
发现有效漏洞时,必须使` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。 当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须` + builtin.ToolGetProjectFact + `(fact_key) 获取 body,禁止凭摘要臆造细节。**
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试 - **环境/目标/认证等认知**(非正式漏洞条目):使用 ` + builtin.ToolUpsertProjectFact + `fact_key 建议 ` + "`category/slug`" + `(如 target/primary_domain),同 key 覆盖更新
- **可交付漏洞**:使用 ` + builtin.ToolRecordVulnerability + `,含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ` + builtin.ToolListVulnerabilities + ` 查重,详情用 ` + builtin.ToolGetVulnerability + `(id)(默认仅当前项目/会话)。
- 同一发现可能需**各记一次**(事实记上下文,漏洞记正式 findings)。误报用 ` + builtin.ToolDeprecateProjectFact + ` 或漏洞状态 false_positive。
严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。
## 技能库(Skills)与知识库 ## 技能库(Skills)与知识库
+77 -193
View File
@@ -15,6 +15,7 @@ import (
"time" "time"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/c2" "cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
@@ -56,10 +57,12 @@ type App struct {
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启 dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启 larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
wechatCancel context.CancelFunc // 微信 iLink 长轮询取消函数
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil) c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗 c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
c2WatchdogCancel context.CancelFunc // 看门狗取消函数 c2WatchdogCancel context.CancelFunc // 看门狗取消函数
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步) c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
auditSvc *audit.Service
} }
// New 创建新应用 // New 创建新应用
@@ -92,6 +95,11 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
return nil, fmt.Errorf("初始化数据库失败: %w", err) return nil, fmt.Errorf("初始化数据库失败: %w", err)
} }
auditSvc := audit.NewService(db, cfg, log.Logger)
audit.RegisterConversationCreateHook(auditSvc)
auditSvc.PurgeExpired()
audit.StartRetentionLoop(auditSvc, log.Logger)
// 创建MCP服务器(带数据库持久化) // 创建MCP服务器(带数据库持久化)
mcpServer := mcp.NewServerWithStorage(log.Logger, db) mcpServer := mcp.NewServerWithStorage(log.Logger, db)
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes) mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
@@ -103,7 +111,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
executor.RegisterTools(mcpServer) executor.RegisterTools(mcpServer)
// 注册漏洞记录工具 // 注册漏洞记录工具
registerVulnerabilityTool(mcpServer, db, log.Logger) registerVulnerabilityTools(mcpServer, db, log.Logger)
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
if cfg.Auth.GeneratedPassword != "" { if cfg.Auth.GeneratedPassword != "" {
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr) config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
@@ -221,6 +230,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
// 创建知识库API处理器 // 创建知识库API处理器
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger) knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
knowledgeHandler.SetAudit(auditSvc)
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil)) log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 扫描知识库并建立索引(异步) // 扫描知识库并建立索引(异步)
@@ -317,31 +327,43 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err)) log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
} }
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir) markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
markdownAgentsHandler.SetAudit(auditSvc)
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir)) log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
// 创建处理器 // 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger) agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
agentHandler.SetAudit(auditSvc)
agentHandler.SetAgentsMarkdownDir(agentsDir) agentHandler.SetAgentsMarkdownDir(agentsDir)
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志 // 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil { if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager) agentHandler.SetKnowledgeManager(knowledgeManager)
} }
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger) monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetAudit(auditSvc)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger) notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger) groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
authHandler.SetAudit(auditSvc)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger) vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
projectHandler := handler.NewProjectHandler(db, log.Logger)
vulnerabilityHandler.SetAudit(auditSvc)
webshellHandler := handler.NewWebShellHandler(log.Logger, db) webshellHandler := handler.NewWebShellHandler(log.Logger, db)
webshellHandler.SetAudit(auditSvc)
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger) chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
chatUploadsHandler.SetAudit(auditSvc)
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger) registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger) registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
configHandler.SetAudit(auditSvc)
agentHandler.SetHitlToolWhitelistSaver(configHandler) agentHandler.SetHitlToolWhitelistSaver(configHandler)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
externalMCPHandler.SetAudit(auditSvc)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger) roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
roleHandler.SetAudit(auditSvc)
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger) skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
skillsHandler.SetAudit(auditSvc)
fofaHandler := handler.NewFofaHandler(cfg, log.Logger) fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
terminalHandler := handler.NewTerminalHandler(log.Logger) terminalHandler := handler.NewTerminalHandler(log.Logger)
if db != nil { if db != nil {
@@ -356,9 +378,12 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port) registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
} }
c2Handler := handler.NewC2Handler(c2Manager, log.Logger) c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
c2Handler.SetAudit(auditSvc)
// 创建OpenAPI处理器 // 创建OpenAPI处理器
conversationHandler := handler.NewConversationHandler(db, log.Logger) conversationHandler := handler.NewConversationHandler(db, log.Logger)
conversationHandler.SetAudit(auditSvc)
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger) robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler) openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
@@ -384,13 +409,15 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
c2Watchdog: c2Watchdog, c2Watchdog: c2Watchdog,
c2WatchdogCancel: watchdogCancel, c2WatchdogCancel: watchdogCancel,
c2Handler: c2Handler, c2Handler: c2Handler,
auditSvc: auditSvc,
} }
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启 // 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
app.startRobotConnections() app.startRobotConnections()
// 设置漏洞工具注册器(内置工具,必须设置) // 设置漏洞工具注册器(内置工具,必须设置)
vulnerabilityRegistrar := func() error { vulnerabilityRegistrar := func() error {
registerVulnerabilityTool(mcpServer, db, log.Logger) registerVulnerabilityTools(mcpServer, db, log.Logger)
registerProjectFactTools(mcpServer, db, cfg, log.Logger)
return nil return nil
} }
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar) configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
@@ -449,9 +476,11 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
configHandler.SetRetrieverUpdater(knowledgeRetriever) configHandler.SetRetrieverUpdater(knowledgeRetriever)
} }
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效 // 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书/微信新配置生效
configHandler.SetRobotRestarter(app) configHandler.SetRobotRestarter(app)
wechatRobotHandler := handler.NewWechatRobotHandler(cfg, configHandler, log.Logger)
configHandler.SetC2Runtime(app) configHandler.SetC2Runtime(app)
configHandler.SetC2ToolRegistrar(func() error { configHandler.SetC2ToolRegistrar(func() error {
if app.config.C2.EnabledEffective() && app.c2Manager != nil { if app.config.C2.EnabledEffective() && app.c2Manager != nil {
@@ -469,12 +498,14 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
notificationHandler, notificationHandler,
conversationHandler, conversationHandler,
robotHandler, robotHandler,
wechatRobotHandler,
groupHandler, groupHandler,
configHandler, configHandler,
externalMCPHandler, externalMCPHandler,
attackChainHandler, attackChainHandler,
app, // 传递 App 实例以便动态获取 knowledgeHandler app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler, vulnerabilityHandler,
projectHandler,
webshellHandler, webshellHandler,
chatUploadsHandler, chatUploadsHandler,
roleHandler, roleHandler,
@@ -483,6 +514,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
fofaHandler, fofaHandler,
terminalHandler, terminalHandler,
app.c2Handler, app.c2Handler,
auditHandler,
mcpServer, mcpServer,
authManager, authManager,
openAPIHandler, openAPIHandler,
@@ -675,9 +707,14 @@ func (a *App) startRobotConnections() {
a.dingCancel = cancel a.dingCancel = cancel
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger) go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
} }
if cfg.Robots.Wechat.Enabled && cfg.Robots.Wechat.BotToken != "" {
ctx, cancel := context.WithCancel(context.Background())
a.wechatCancel = cancel
go robot.StartWechat(ctx, cfg.Robots, a.robotHandler, cfg.Version, a.logger.Logger)
}
} }
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter // RestartRobotConnections 重启钉钉/飞书/微信长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter
func (a *App) RestartRobotConnections() { func (a *App) RestartRobotConnections() {
a.robotMu.Lock() a.robotMu.Lock()
if a.dingCancel != nil { if a.dingCancel != nil {
@@ -688,6 +725,10 @@ func (a *App) RestartRobotConnections() {
a.larkCancel() a.larkCancel()
a.larkCancel = nil a.larkCancel = nil
} }
if a.wechatCancel != nil {
a.wechatCancel()
a.wechatCancel = nil
}
a.robotMu.Unlock() a.robotMu.Unlock()
// 给旧 goroutine 一点时间退出 // 给旧 goroutine 一点时间退出
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
@@ -703,12 +744,14 @@ func setupRoutes(
notificationHandler *handler.NotificationHandler, notificationHandler *handler.NotificationHandler,
conversationHandler *handler.ConversationHandler, conversationHandler *handler.ConversationHandler,
robotHandler *handler.RobotHandler, robotHandler *handler.RobotHandler,
wechatRobotHandler *handler.WechatRobotHandler,
groupHandler *handler.GroupHandler, groupHandler *handler.GroupHandler,
configHandler *handler.ConfigHandler, configHandler *handler.ConfigHandler,
externalMCPHandler *handler.ExternalMCPHandler, externalMCPHandler *handler.ExternalMCPHandler,
attackChainHandler *handler.AttackChainHandler, attackChainHandler *handler.AttackChainHandler,
app *App, // 传递 App 实例以便动态获取 knowledgeHandler app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler, vulnerabilityHandler *handler.VulnerabilityHandler,
projectHandler *handler.ProjectHandler,
webshellHandler *handler.WebShellHandler, webshellHandler *handler.WebShellHandler,
chatUploadsHandler *handler.ChatUploadsHandler, chatUploadsHandler *handler.ChatUploadsHandler,
roleHandler *handler.RoleHandler, roleHandler *handler.RoleHandler,
@@ -717,6 +760,7 @@ func setupRoutes(
fofaHandler *handler.FofaHandler, fofaHandler *handler.FofaHandler,
terminalHandler *handler.TerminalHandler, terminalHandler *handler.TerminalHandler,
c2Handler *handler.C2Handler, c2Handler *handler.C2Handler,
auditHandler *handler.AuditHandler,
mcpServer *mcp.Server, mcpServer *mcp.Server,
authManager *security.AuthManager, authManager *security.AuthManager,
openAPIHandler *handler.OpenAPIHandler, openAPIHandler *handler.OpenAPIHandler,
@@ -751,6 +795,12 @@ func setupRoutes(
// 机器人测试(需登录):POST /api/robot/testbody: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑 // 机器人测试(需登录):POST /api/robot/testbody: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
protected.POST("/robot/test", robotHandler.HandleRobotTest) protected.POST("/robot/test", robotHandler.HandleRobotTest)
// 微信 iLink 扫码绑定(需登录)
protected.POST("/robot/wechat/qrcode", wechatRobotHandler.HandleWechatQRCode)
protected.GET("/robot/wechat/qrcode/status", wechatRobotHandler.HandleWechatQRCodeStatus)
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
// Agent Loop // Agent Loop
protected.POST("/agent-loop", agentHandler.AgentLoop) protected.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出 // Agent Loop 流式输出
@@ -806,6 +856,7 @@ func setupRoutes(
protected.GET("/conversations/:id", conversationHandler.GetConversation) protected.GET("/conversations/:id", conversationHandler.GetConversation)
protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails) protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails)
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation) protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
protected.PUT("/conversations/:id/project", conversationHandler.SetConversationProject)
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation) protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn) protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn)
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned) protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
@@ -847,6 +898,13 @@ func setupRoutes(
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream) protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
protected.GET("/terminal/ws", terminalHandler.RunCommandWS) protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
// 平台审计日志
protected.GET("/audit/meta", auditHandler.Meta)
protected.GET("/audit/summary", auditHandler.Summary)
protected.GET("/audit/logs", auditHandler.ListLogs)
protected.GET("/audit/logs/export", auditHandler.ExportLogs)
protected.GET("/audit/logs/:id", auditHandler.GetLog)
// 外部MCP管理 // 外部MCP管理
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs) protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats) protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
@@ -1015,6 +1073,18 @@ func setupRoutes(
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability) protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability) protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// 项目管理与事实黑板
protected.GET("/projects", projectHandler.ListProjects)
protected.POST("/projects", projectHandler.CreateProject)
protected.GET("/projects/:id", projectHandler.GetProject)
protected.PUT("/projects/:id", projectHandler.UpdateProject)
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
protected.DELETE("/projects/:id/facts/:factId", projectHandler.DeleteFact)
protected.POST("/projects/:id/facts/deprecate", projectHandler.DeprecateFact)
// WebShell 管理(代理执行 + 连接配置存 SQLite) // WebShell 管理(代理执行 + 连接配置存 SQLite)
protected.GET("/webshell/connections", webshellHandler.ListConnections) protected.GET("/webshell/connections", webshellHandler.ListConnections)
protected.POST("/webshell/connections", webshellHandler.CreateConnection) protected.POST("/webshell/connections", webshellHandler.CreateConnection)
@@ -1135,195 +1205,6 @@ func setupRoutes(
}) })
} }
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: builtin.ToolRecordVulnerability,
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
"description": "漏洞标题(必需)",
},
"description": map[string]interface{}{
"type": "string",
"description": "漏洞详细描述",
},
"severity": map[string]interface{}{
"type": "string",
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
"enum": []string{"critical", "high", "medium", "low", "info"},
},
"vulnerability_type": map[string]interface{}{
"type": "string",
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
},
"target": map[string]interface{}{
"type": "string",
"description": "受影响的目标(URL、IP地址、服务等)",
},
"proof": map[string]interface{}{
"type": "string",
"description": "漏洞证明(POC、截图、请求/响应等)",
},
"impact": map[string]interface{}{
"type": "string",
"description": "漏洞影响说明",
},
"recommendation": map[string]interface{}{
"type": "string",
"description": "修复建议",
},
},
"required": []string{"title", "severity"},
},
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
// 从参数中获取conversation_id(由Agent自动添加)
conversationID, _ := args["conversation_id"].(string)
if conversationID == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: conversation_id 未设置。这是系统错误,请重试。",
},
},
IsError: true,
}, nil
}
title, ok := args["title"].(string)
if !ok || title == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: title 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
severity, ok := args["severity"].(string)
if !ok || severity == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: severity 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
// 验证严重程度
validSeverities := map[string]bool{
"critical": true,
"high": true,
"medium": true,
"low": true,
"info": true,
}
if !validSeverities[severity] {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity),
},
},
IsError: true,
}, nil
}
// 获取可选参数
description := ""
if d, ok := args["description"].(string); ok {
description = d
}
vulnType := ""
if t, ok := args["vulnerability_type"].(string); ok {
vulnType = t
}
target := ""
if t, ok := args["target"].(string); ok {
target = t
}
proof := ""
if p, ok := args["proof"].(string); ok {
proof = p
}
impact := ""
if i, ok := args["impact"].(string); ok {
impact = i
}
recommendation := ""
if r, ok := args["recommendation"].(string); ok {
recommendation = r
}
// 创建漏洞记录
vuln := &database.Vulnerability{
ConversationID: conversationID,
Title: title,
Description: description,
Severity: severity,
Status: "open",
Type: vulnType,
Target: target,
Proof: proof,
Impact: impact,
Recommendation: recommendation,
}
created, err := db.CreateVulnerability(vuln)
if err != nil {
logger.Error("记录漏洞失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("记录漏洞失败: %v", err),
},
},
IsError: true,
}, nil
}
logger.Info("漏洞记录成功",
zap.String("id", created.ID),
zap.String("title", created.Title),
zap.String("severity", created.Severity),
zap.String("conversation_id", conversationID),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(tool, handler)
logger.Info("漏洞记录工具注册成功")
}
// registerWebshellTools 注册 WebShell 相关 MCP 工具,供 AI 助手在指定连接上执行命令与文件操作 // registerWebshellTools 注册 WebShell 相关 MCP 工具,供 AI 助手在指定连接上执行命令与文件操作
func registerWebshellTools(mcpServer *mcp.Server, db *database.DB, webshellHandler *handler.WebShellHandler, logger *zap.Logger) { func registerWebshellTools(mcpServer *mcp.Server, db *database.DB, webshellHandler *handler.WebShellHandler, logger *zap.Logger) {
if db == nil || webshellHandler == nil { if db == nil || webshellHandler == nil {
@@ -1908,6 +1789,9 @@ func initializeKnowledge(
// 创建知识库API处理器 // 创建知识库API处理器
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger) knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
if app != nil && app.auditSvc != nil {
knowledgeHandler.SetAudit(app.auditSvc)
}
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil)) logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 设置知识库管理器到AgentHandler以便记录检索日志 // 设置知识库管理器到AgentHandler以便记录检索日志
+278
View File
@@ -0,0 +1,278 @@
package app
import (
"context"
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
func projectIDFromConversation(db *database.DB, ctx context.Context) (string, error) {
convID := agent.ConversationIDFromContext(ctx)
if convID == "" {
return "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用项目事实工具")
}
pid, err := db.GetConversationProjectID(convID)
if err != nil {
return "", err
}
if strings.TrimSpace(pid) == "" {
return "", fmt.Errorf("当前对话未绑定项目,请先在对话中选择项目或创建带项目的对话")
}
return pid, nil
}
func textResult(msg string, isErr bool) *mcp.ToolResult {
return &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: msg}},
IsError: isErr,
}
}
// registerProjectFactTools 注册项目黑板 MCP 工具。
func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *config.Config, logger *zap.Logger) {
if db == nil || cfg == nil || !cfg.Project.Enabled {
if logger != nil {
logger.Info("项目黑板工具未注册(未启用)")
}
return
}
upsertTool := mcp.Tool{
Name: builtin.ToolUpsertProjectFact,
Description: "写入或更新项目黑板事实。用于记录环境认知、目标信息、认证特征等(非正式漏洞条目)。同 fact_key 会覆盖更新。需要当前对话已绑定项目。",
ShortDescription: "写入/更新项目事实",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"fact_key": map[string]interface{}{
"type": "string",
"description": "项目内唯一 key,建议格式 category/slug,如 target/primary_domain",
},
"category": map[string]interface{}{
"type": "string",
"description": "分类:target、auth、infra、business、note 等",
},
"summary": map[string]interface{}{
"type": "string",
"description": "单行摘要(会注入到后续对话索引)",
},
"body": map[string]interface{}{
"type": "string",
"description": "完整详情(POC、长文本等,仅 get_project_fact 返回)",
},
"confidence": map[string]interface{}{
"type": "string",
"description": "confirmed | tentative | deprecated",
"enum": []string{"confirmed", "tentative", "deprecated"},
},
"pinned": map[string]interface{}{
"type": "boolean",
"description": "是否优先出现在黑板索引",
},
"related_vulnerability_id": map[string]interface{}{
"type": "string",
"description": "可选:关联的漏洞记录 ID",
},
},
"required": []string{"fact_key", "summary"},
},
}
mcpServer.RegisterTool(upsertTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
projectID, err := projectIDFromConversation(db, ctx)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
factKey, _ := args["fact_key"].(string)
summary, _ := args["summary"].(string)
if strings.TrimSpace(factKey) == "" || strings.TrimSpace(summary) == "" {
return textResult("错误: fact_key 与 summary 必填", true), nil
}
if len([]rune(summary)) > cfg.Project.FactSummaryMaxRunesEffective() {
return textResult(fmt.Sprintf("错误: summary 过长(最多 %d 字)", cfg.Project.FactSummaryMaxRunesEffective()), true), nil
}
f := &database.ProjectFact{
ProjectID: projectID,
FactKey: factKey,
Category: strArg(args, "category"),
Summary: summary,
Body: strArg(args, "body"),
Confidence: strArg(args, "confidence"),
Pinned: boolArg(args, "pinned"),
RelatedVulnerabilityID: strArg(args, "related_vulnerability_id"),
}
if convID := agent.ConversationIDFromContext(ctx); convID != "" {
f.SourceConversationID = convID
}
created, err := db.UpsertProjectFact(f)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
return textResult(fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence), false), nil
})
getTool := mcp.Tool{
Name: builtin.ToolGetProjectFact,
Description: "按 fact_key 获取项目事实完整 body 与元数据。摘要不足时必须调用本工具,禁止臆造细节。",
ShortDescription: "按 key 获取事实详情",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"fact_key": map[string]interface{}{"type": "string", "description": "事实 key"},
},
"required": []string{"fact_key"},
},
}
mcpServer.RegisterTool(getTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
projectID, err := projectIDFromConversation(db, ctx)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
key := strings.TrimSpace(strArg(args, "fact_key"))
if key == "" {
return textResult("错误: fact_key 必填", true), nil
}
f, err := db.GetProjectFactByKey(projectID, key)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
msg := fmt.Sprintf("fact_key: %s\ncategory: %s\nconfidence: %s\nsummary: %s\nupdated_at: %s\n\n--- body ---\n%s",
f.FactKey, f.Category, f.Confidence, f.Summary, f.UpdatedAt.Format("2006-01-02 15:04:05"), f.Body)
return textResult(msg, false), nil
})
listTool := mcp.Tool{
Name: builtin.ToolListProjectFacts,
Description: "列出当前项目的事实(分页)。",
ShortDescription: "列出项目事实",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"category": map[string]interface{}{"type": "string"},
"confidence": map[string]interface{}{"type": "string"},
"limit": map[string]interface{}{"type": "integer"},
"offset": map[string]interface{}{"type": "integer"},
},
},
}
mcpServer.RegisterTool(listTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
projectID, err := projectIDFromConversation(db, ctx)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
limit := intArg(args, "limit", 50)
offset := intArg(args, "offset", 0)
filter := database.ProjectFactListFilter{
Category: strArg(args, "category"),
Confidence: strArg(args, "confidence"),
}
list, err := db.ListProjectFacts(projectID, filter, limit, offset)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
var b strings.Builder
b.WriteString(fmt.Sprintf("共 %d 条(limit=%d offset=%d:\n", len(list), limit, offset))
for _, f := range list {
b.WriteString(fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, f.Summary, f.Confidence))
}
return textResult(b.String(), false), nil
})
searchTool := mcp.Tool{
Name: builtin.ToolSearchProjectFacts,
Description: "按关键词搜索项目事实(summary/body/fact_key)。",
ShortDescription: "搜索项目事实",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{"type": "string"},
"limit": map[string]interface{}{"type": "integer"},
"offset": map[string]interface{}{"type": "integer"},
},
"required": []string{"query"},
},
}
mcpServer.RegisterTool(searchTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
projectID, err := projectIDFromConversation(db, ctx)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
q := strings.TrimSpace(strArg(args, "query"))
if q == "" {
return textResult("错误: query 必填", true), nil
}
list, err := db.ListProjectFacts(projectID, database.ProjectFactListFilter{Search: q}, intArg(args, "limit", 30), intArg(args, "offset", 0))
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
var b strings.Builder
b.WriteString(fmt.Sprintf("搜索 \"%s\" 命中 %d 条:\n", q, len(list)))
for _, f := range list {
b.WriteString(fmt.Sprintf("- [%s] %s — %s\n", f.FactKey, f.Category, f.Summary))
}
return textResult(b.String(), false), nil
})
deprecateTool := mcp.Tool{
Name: builtin.ToolDeprecateProjectFact,
Description: "将事实标记为 deprecated,从黑板索引中排除。",
ShortDescription: "废弃项目事实",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"fact_key": map[string]interface{}{"type": "string"},
},
"required": []string{"fact_key"},
},
}
mcpServer.RegisterTool(deprecateTool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
projectID, err := projectIDFromConversation(db, ctx)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
key := strings.TrimSpace(strArg(args, "fact_key"))
if err := db.DeprecateProjectFact(projectID, key); err != nil {
return textResult("错误: "+err.Error(), true), nil
}
return textResult("事实已标记为 deprecated: "+key, false), nil
})
if logger != nil {
logger.Info("项目黑板 MCP 工具注册成功")
}
}
func strArg(args map[string]interface{}, key string) string {
if v, ok := args[key].(string); ok {
return v
}
return ""
}
func boolArg(args map[string]interface{}, key string) bool {
if v, ok := args[key].(bool); ok {
return v
}
return false
}
func intArg(args map[string]interface{}, key string, def int) int {
switch v := args[key].(type) {
case float64:
return int(v)
case int:
return v
case int64:
return int(v)
default:
return def
}
}
+405
View File
@@ -0,0 +1,405 @@
package app
import (
"context"
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
func conversationIDFromToolCtx(ctx context.Context) string {
if id := agent.ConversationIDFromContext(ctx); id != "" {
return id
}
return mcp.MCPConversationIDFromContext(ctx)
}
// canAccessVulnerability 校验当前对话是否有权查看该漏洞(默认项目隔离,未绑项目则仅本会话)。
func canAccessVulnerability(vuln *database.Vulnerability, convID, projectID string) bool {
if vuln == nil || convID == "" {
return false
}
if projectID != "" {
if strings.TrimSpace(vuln.ProjectID) == projectID {
return true
}
// 历史记录:写入时尚未绑定 project_id,但属于同一会话
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID == convID {
return true
}
return false
}
return vuln.ConversationID == convID
}
func buildVulnerabilityListFilter(db *database.DB, ctx context.Context, args map[string]interface{}) (database.VulnerabilityListFilter, string, error) {
convID := conversationIDFromToolCtx(ctx)
if convID == "" {
return database.VulnerabilityListFilter{}, "", fmt.Errorf("无法确定当前对话,请在对话上下文中使用漏洞查询工具")
}
projectID := ""
if pid, err := db.GetConversationProjectID(convID); err == nil {
projectID = strings.TrimSpace(pid)
}
scope := strings.TrimSpace(strArg(args, "scope"))
if scope == "" {
if projectID != "" {
scope = "project"
} else {
scope = "conversation"
}
}
filter := database.VulnerabilityListFilter{
Severity: strings.TrimSpace(strArg(args, "severity")),
Status: strings.TrimSpace(strArg(args, "status")),
}
if q := strings.TrimSpace(strArg(args, "q")); q != "" {
filter.Search = q
} else {
filter.Search = strings.TrimSpace(strArg(args, "search"))
}
var scopeLabel string
switch scope {
case "project":
if projectID == "" {
return filter, "", fmt.Errorf("当前对话未绑定项目,无法按项目列出漏洞;请使用 scope=conversation,或先在对话中绑定项目")
}
filter.ProjectID = projectID
scopeLabel = fmt.Sprintf("项目 %s", projectID)
case "conversation":
filter.ConversationID = convID
scopeLabel = fmt.Sprintf("会话 %s", convID)
default:
return filter, "", fmt.Errorf("scope 仅支持 project 或 conversation,当前值: %s", scope)
}
return filter, scopeLabel, nil
}
func formatVulnerabilityListItem(v *database.Vulnerability) string {
line := fmt.Sprintf("- id=%s | %s | %s | %s", v.ID, v.Severity, v.Status, v.Title)
if v.Type != "" {
line += fmt.Sprintf(" | type=%s", v.Type)
}
if v.Target != "" {
line += fmt.Sprintf(" | target=%s", truncateRunes(v.Target, 80))
}
return line
}
func formatVulnerabilityDetail(v *database.Vulnerability) string {
var b strings.Builder
b.WriteString(fmt.Sprintf("漏洞ID: %s\n", v.ID))
b.WriteString(fmt.Sprintf("标题: %s\n", v.Title))
b.WriteString(fmt.Sprintf("严重程度: %s\n", v.Severity))
b.WriteString(fmt.Sprintf("状态: %s\n", v.Status))
if v.Type != "" {
b.WriteString(fmt.Sprintf("类型: %s\n", v.Type))
}
if v.Target != "" {
b.WriteString(fmt.Sprintf("目标: %s\n", v.Target))
}
if v.ProjectID != "" {
b.WriteString(fmt.Sprintf("项目ID: %s\n", v.ProjectID))
}
b.WriteString(fmt.Sprintf("会话ID: %s\n", v.ConversationID))
if !v.CreatedAt.IsZero() {
b.WriteString(fmt.Sprintf("创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
}
if v.Description != "" {
b.WriteString("\n--- 描述 ---\n")
b.WriteString(v.Description)
b.WriteString("\n")
}
if v.Proof != "" {
b.WriteString("\n--- 证明(POC ---\n")
b.WriteString(v.Proof)
b.WriteString("\n")
}
if v.Impact != "" {
b.WriteString("\n--- 影响 ---\n")
b.WriteString(v.Impact)
b.WriteString("\n")
}
if v.Recommendation != "" {
b.WriteString("\n--- 修复建议 ---\n")
b.WriteString(v.Recommendation)
b.WriteString("\n")
}
return b.String()
}
func truncateRunes(s string, max int) string {
r := []rune(s)
if len(r) <= max {
return s
}
return string(r[:max]) + "…"
}
// registerVulnerabilityTools 注册漏洞记录与查询 MCP 工具。
func registerVulnerabilityTools(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
registerRecordVulnerabilityTool(mcpServer, db, logger)
registerListVulnerabilitiesTool(mcpServer, db, logger)
registerGetVulnerabilityTool(mcpServer, db, logger)
if logger != nil {
logger.Info("漏洞 MCP 工具注册成功", zap.Strings("tools", []string{
builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
}))
}
}
func registerRecordVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: builtin.ToolRecordVulnerability,
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。记录前可先 list_vulnerabilities 避免重复。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
"description": "漏洞标题(必需)",
},
"description": map[string]interface{}{
"type": "string",
"description": "漏洞详细描述",
},
"severity": map[string]interface{}{
"type": "string",
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
"enum": []string{"critical", "high", "medium", "low", "info"},
},
"vulnerability_type": map[string]interface{}{
"type": "string",
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
},
"target": map[string]interface{}{
"type": "string",
"description": "受影响的目标(URL、IP地址、服务等)",
},
"proof": map[string]interface{}{
"type": "string",
"description": "漏洞证明(POC、截图、请求/响应等)",
},
"impact": map[string]interface{}{
"type": "string",
"description": "漏洞影响说明",
},
"recommendation": map[string]interface{}{
"type": "string",
"description": "修复建议",
},
},
"required": []string{"title", "severity"},
},
}
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
conversationID := strings.TrimSpace(strArg(args, "conversation_id"))
if conversationID == "" {
conversationID = conversationIDFromToolCtx(ctx)
}
if conversationID == "" {
return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true), nil
}
title := strings.TrimSpace(strArg(args, "title"))
if title == "" {
return textResult("错误: title 参数必需且不能为空", true), nil
}
severity := strings.TrimSpace(strArg(args, "severity"))
if severity == "" {
return textResult("错误: severity 参数必需且不能为空", true), nil
}
validSeverities := map[string]bool{
"critical": true, "high": true, "medium": true, "low": true, "info": true,
}
if !validSeverities[severity] {
return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true), nil
}
projectID := ""
if pid, perr := db.GetConversationProjectID(conversationID); perr == nil {
projectID = strings.TrimSpace(pid)
}
vuln := &database.Vulnerability{
ConversationID: conversationID,
ProjectID: projectID,
Title: title,
Description: strArg(args, "description"),
Severity: severity,
Status: "open",
Type: strArg(args, "vulnerability_type"),
Target: strArg(args, "target"),
Proof: strArg(args, "proof"),
Impact: strArg(args, "impact"),
Recommendation: strArg(args, "recommendation"),
}
created, err := db.CreateVulnerability(vuln)
if err != nil {
if logger != nil {
logger.Error("记录漏洞失败", zap.Error(err))
}
return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true), nil
}
if logger != nil {
logger.Info("漏洞记录成功",
zap.String("id", created.ID),
zap.String("title", created.Title),
zap.String("severity", created.Severity),
zap.String("conversation_id", conversationID),
)
}
return textResult(fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n可使用 get_vulnerability(id) 查看详情,或 list_vulnerabilities 查看列表。",
created.ID, created.Title, created.Severity, created.Status), false), nil
})
}
func registerListVulnerabilitiesTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: builtin.ToolListVulnerabilities,
Description: "列出当前授权范围内的漏洞(摘要)。默认:对话已绑定项目时列出该项目下全部漏洞;未绑项目时仅列出当前会话漏洞。可用 scope=conversation 仅看本会话。记录新漏洞前建议先调用以避免重复。",
ShortDescription: "列出漏洞(默认当前项目)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"scope": map[string]interface{}{
"type": "string",
"description": "范围:project(默认,需绑定项目)| conversation(仅当前会话)",
"enum": []string{"project", "conversation"},
},
"severity": map[string]interface{}{
"type": "string",
"description": "按严重程度筛选:critical、high、medium、low、info",
"enum": []string{"critical", "high", "medium", "low", "info"},
},
"status": map[string]interface{}{
"type": "string",
"description": "按状态筛选:open、confirmed、fixed、false_positive",
"enum": []string{"open", "confirmed", "fixed", "false_positive"},
},
"q": map[string]interface{}{
"type": "string",
"description": "关键词搜索(标题、描述、类型、目标等)",
},
"limit": map[string]interface{}{
"type": "integer",
"description": "返回条数上限,默认 30,最大 100",
},
"offset": map[string]interface{}{
"type": "integer",
"description": "分页偏移,默认 0",
},
},
},
}
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
filter, scopeLabel, err := buildVulnerabilityListFilter(db, ctx, args)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
limit := intArg(args, "limit", 30)
if limit <= 0 || limit > 100 {
limit = 30
}
offset := intArg(args, "offset", 0)
if offset < 0 {
offset = 0
}
total, err := db.CountVulnerabilities(filter)
if err != nil {
if logger != nil {
logger.Warn("统计漏洞失败", zap.Error(err))
}
total = 0
}
list, err := db.ListVulnerabilities(limit, offset, filter)
if err != nil {
return textResult("错误: "+err.Error(), true), nil
}
var b strings.Builder
b.WriteString(fmt.Sprintf("范围: %s\n总计: %d | 本页: %d 条 (limit=%d offset=%d)\n\n", scopeLabel, total, len(list), limit, offset))
if len(list) == 0 {
b.WriteString("(暂无漏洞记录)\n")
} else {
for _, v := range list {
b.WriteString(formatVulnerabilityListItem(v))
b.WriteString("\n")
}
if total > offset+len(list) {
b.WriteString(fmt.Sprintf("\n(还有更多,可增大 offset 或使用 q/severity/status 筛选)\n"))
}
}
b.WriteString("\n需要 POC 与完整字段请对具体 id 调用 get_vulnerability。")
return textResult(b.String(), false), nil
})
}
func registerGetVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: builtin.ToolGetVulnerability,
Description: "按漏洞 ID 获取完整详情(含 POC、影响、修复建议)。仅能访问当前项目或当前会话下的漏洞(与 list_vulnerabilities 授权范围一致)。",
ShortDescription: "按 ID 获取漏洞详情",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"id": map[string]interface{}{
"type": "string",
"description": "漏洞 IDlist_vulnerabilities 返回的 id",
},
},
"required": []string{"id"},
},
}
mcpServer.RegisterTool(tool, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
convID := conversationIDFromToolCtx(ctx)
if convID == "" {
return textResult("错误: 无法确定当前对话,请在对话上下文中使用本工具", true), nil
}
id := strings.TrimSpace(strArg(args, "id"))
if id == "" {
return textResult("错误: id 必填", true), nil
}
vuln, err := db.GetVulnerability(id)
if err != nil {
return textResult("错误: 漏洞不存在或查询失败", true), nil
}
projectID := ""
if pid, perr := db.GetConversationProjectID(convID); perr == nil {
projectID = strings.TrimSpace(pid)
}
if !canAccessVulnerability(vuln, convID, projectID) {
return textResult("错误: 无权访问该漏洞(仅可查看当前项目或当前会话下的记录)", true), nil
}
return textResult(formatVulnerabilityDetail(vuln), false), nil
})
}
+85 -66
View File
@@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
} }
} }
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出) // BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) { func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID)) b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
@@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
var reactInputFinal string var reactInputFinal string
var dataSource string // 记录数据来源 var dataSource string // 记录数据来源
// 如果成功获取到保存的ReAct数据,直接使用 // 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」
if reactInputJSON != "" && modelOutput != "" { if reactInputJSON != "" {
// 计算 ReAct 输入的哈希值,用于追踪 trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
hash := sha256.Sum256([]byte(reactInputJSON)) hash := sha256.Sum256([]byte(trimmedJSON))
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识 reactInputHash := hex.EncodeToString(hash[:])[:16]
// 统计消息数量
var messageCount int var messageCount int
var tempMessages []interface{} if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil {
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil { messageCount = len(msgs)
messageCount = len(tempMessages) msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput)
reactInputFinal = b.formatAgentTraceFromChatMessages(msgs)
} else {
b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr))
reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON)
if strings.TrimSpace(modelOutput) != "" {
reactInputFinal += "\n\n## 助手结论(last_react_output\n\n" + modelOutput
}
} }
dataSource = "database_last_agent_trace" dataSource = "last_user_turn_agent_trace"
b.logger.Info("使用保存的ReAct数据构建攻击链", b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)), zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)),
zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)),
zap.Int("messageCount", messageCount), zap.Int("messageCount", messageCount),
zap.String("reactInputHash", reactInputHash), zap.String("reactInputHash", reactInputHash),
zap.Int("modelOutputSize", len(modelOutput))) zap.Int("modelOutputSize", len(modelOutput)))
// 从保存的ReAct输入(JSON格式)中提取用户输入
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
} else { } else {
// 2. 如果没有保存的ReAct数据,从对话消息构建 // 2. 如果没有保存的ReAct数据,从对话消息构建
dataSource = "messages_table" dataSource = "messages_table"
@@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
} }
// 3. 构建简化的prompt,一次性传递给大模型 // 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文)
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput) reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput)
// 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入)
promptAssistantOut := modelOutput
if reactInputJSON != "" {
promptAssistantOut = ""
}
prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut)
// fmt.Println(prompt) // fmt.Println(prompt)
// 6. 调用AI生成攻击链(一次性,不做任何处理) // 6. 调用AI生成攻击链(一次性,不做任何处理)
chainJSON, err := b.callAIForChainGeneration(ctx, prompt) chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
@@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
return strings.TrimSpace(sb.String()) return strings.TrimSpace(sb.String())
} }
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入) // buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。
func (b *Builder) buildAgentTraceInput(messages []database.Message) string { func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
start := 0
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "user") {
start = i
break
}
}
var builder strings.Builder var builder strings.Builder
for _, msg := range messages { for _, msg := range messages[start:] {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content)) builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
} }
return builder.String() return builder.String()
@@ -396,67 +411,66 @@ func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
// return "" // return ""
// } // }
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式 // formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string { func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
var messages []map[string]interface{} trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { msgs, err := agent.ParseTraceMessages(trimmed)
if err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
return reactInputJSON // 如果解析失败,返回原始JSON return trimmed
} }
return b.formatAgentTraceFromChatMessages(msgs)
}
// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。
func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string {
var builder strings.Builder var builder strings.Builder
for _, msg := range messages { for _, msg := range msgs {
role, _ := msg["role"].(string) role := msg.Role
content, _ := msg["content"].(string) content := msg.Content
// 处理assistant消息:提取tool_calls信息 if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 {
if role == "assistant" { if content != "" {
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
// 如果有文本内容,先显示 }
if content != "" { builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls)))
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content)) for i, tc := range msg.ToolCalls {
} args := ""
// 详细显示每个工具调用 if tc.Function.Arguments != nil {
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls))) if b, err := json.Marshal(tc.Function.Arguments); err == nil {
for i, toolCall := range toolCalls { args = string(b)
if tc, ok := toolCall.(map[string]interface{}); ok {
toolCallID, _ := tc["id"].(string)
if funcData, ok := tc["function"].(map[string]interface{}); ok {
toolName, _ := funcData["name"].(string)
arguments, _ := funcData["arguments"].(string)
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
}
} }
} }
builder.WriteString("\n") builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
continue builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", args))
} }
builder.WriteString("\n")
continue
} }
// 处理tool消息:显示tool_call_id和完整内容 if strings.EqualFold(role, "tool") {
if role == "tool" { if msg.ToolCallID != "" {
toolCallID, _ := msg["tool_call_id"].(string) builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content))
if toolCallID != "" {
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
} else { } else {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
} }
continue continue
} }
// 其他消息类型(system, user等)正常显示
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
} }
return builder.String() return builder.String()
} }
// buildSimplePrompt 构建简化的prompt // buildSimplePrompt 构建简化的prompt
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string { func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径 return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据**当前任务轮次**的对话记录和工具执行结果,一次性输出攻击链 JSON(不要分多轮追问)
## 输入范围(与「继续对话」续跑一致)
- 下方「ReAct 轨迹」仅包含**最后一次用户提问之后**的消息与工具结果(last_react 当前任务轮次),不含更早的用户提问轮次。
- 「助手结论」为同轮任务的最终输出摘要(last_react_output);节点须与轨迹中的实际工具执行一致,严禁编造。
## 核心目标 ## 核心目标
@@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability 5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径) 6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
## 最后一轮ReAct输入 ## 当前任务 ReAct 轨迹(含工具执行;助手结论见轨迹末尾 assistant)
%s %s
## 大模型输出
%s %s
## 输出格式 ## 输出格式
@@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。 9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。 10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
现在开始分析并构建攻击链:`, reactInput, modelOutput) 现在开始分析并构建攻击链:`, reactInput, assistantOutSection(modelOutput))
}
func assistantOutSection(modelOutput string) string {
modelOutput = strings.TrimSpace(modelOutput)
if modelOutput == "" {
return ""
}
return "\n## 助手结论(补充)\n\n" + modelOutput + "\n"
} }
// saveChain 保存攻击链到数据库 // saveChain 保存攻击链到数据库
@@ -812,7 +831,7 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
}, },
}, },
"temperature": 0.3, "temperature": 0.3,
"max_completion_tokens": 80000, "max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens),
} }
var apiResponse struct { var apiResponse struct {
+248
View File
@@ -0,0 +1,248 @@
package attackchain
import (
"strings"
"unicode/utf8"
"go.uber.org/zap"
)
const (
attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
attackChainSystemReserve = 256
attackChainSafetyReserve = 2048
)
// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。
func attackChainMaxCompletionTokens(maxTotal int) int {
const capTokens = 16384
if maxTotal <= 0 {
return 8192
}
v := maxTotal / 8
if v < 4096 {
v = 4096
}
if v > capTokens {
v = capTokens
}
return v
}
func (b *Builder) modelName() string {
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
return b.openAIConfig.Model
}
return "gpt-4"
}
func (b *Builder) countTokens(text string) int {
if text == "" {
return 0
}
n, err := b.tokenCounter.Count(b.modelName(), text)
if err != nil {
return utf8.RuneCountInString(text) / 4
}
return n
}
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
func (b *Builder) attackChainPayloadTokenBudget() int {
maxTotal := b.maxTokens
if maxTotal <= 0 {
maxTotal = 100000
}
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
completion := attackChainMaxCompletionTokens(maxTotal)
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
budget := maxTotal - reserve
minBudget := maxTotal * 35 / 100
if budget < minBudget {
budget = minBudget
}
if budget < 4096 {
budget = 4096
}
return budget
}
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
budget := b.attackChainPayloadTokenBudget()
modelBudget := budget * 15 / 100
if modelBudget < 512 {
modelBudget = 512
}
reactBudget := budget - modelBudget
origReactTok := b.countTokens(reactInput)
origModelTok := b.countTokens(modelOutput)
truncated := false
outModel := modelOutput
if origModelTok > modelBudget {
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
truncated = true
}
outReact := reactInput
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
for _, lim := range perToolLimits {
compact := compactFormattedToolBodies(outReact, lim)
if compact != outReact {
outReact = compact
truncated = true
}
if b.countTokens(outReact) <= reactBudget {
break
}
}
if b.countTokens(outReact) > reactBudget {
outReact = truncateTextByTokens(b, outReact, reactBudget)
truncated = true
}
if truncated {
b.logger.Info("攻击链输入已按 token 预算截断",
zap.Int("maxTotalTokens", b.maxTokens),
zap.Int("payloadBudget", budget),
zap.Int("reactBudget", reactBudget),
zap.Int("modelBudget", modelBudget),
zap.Int("reactInputTokensBefore", origReactTok),
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
zap.Int("modelOutputTokensBefore", origModelTok),
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
)
}
return outReact, outModel, truncated
}
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
if maxRunesPerBody <= 0 || s == "" {
return s
}
const marker = "[tool]"
var out strings.Builder
remaining := s
changed := false
for {
idx := strings.Index(remaining, marker)
if idx < 0 {
out.WriteString(remaining)
break
}
out.WriteString(remaining[:idx])
remaining = remaining[idx:]
nl := strings.IndexByte(remaining, '\n')
if nl < 0 {
out.WriteString(remaining)
break
}
header := remaining[:nl+1]
remaining = remaining[nl+1:]
bodyEnd := strings.Index(remaining, "\n\n[")
var body, rest string
if bodyEnd < 0 {
body = remaining
rest = ""
} else {
body = remaining[:bodyEnd]
rest = remaining[bodyEnd:]
}
if runeLen(body) > maxRunesPerBody {
body = truncateRunesWithNotice(body, maxRunesPerBody)
changed = true
}
out.WriteString(header)
out.WriteString(body)
remaining = rest
if rest == "" {
break
}
}
if !changed {
return s
}
return out.String()
}
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
if maxTokens <= 0 || text == "" {
return ""
}
if b.countTokens(text) <= maxTokens {
return text
}
markerTok := b.countTokens(attackChainTruncationMarker)
usable := maxTokens - markerTok
if usable < 256 {
usable = maxTokens / 2
}
headBudget := usable * 60 / 100
tailBudget := usable - headBudget
head := takeTokensFromStart(b, text, headBudget)
tail := takeTokensFromEnd(b, text, tailBudget)
return head + attackChainTruncationMarker + tail
}
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
rs := []rune(text)
if len(rs) == 0 || maxTokens <= 0 {
return ""
}
lo, hi := 0, len(rs)
for lo < hi {
mid := (lo + hi + 1) / 2
if b.countTokens(string(rs[:mid])) <= maxTokens {
lo = mid
} else {
hi = mid - 1
}
}
return string(rs[:lo])
}
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
rs := []rune(text)
if len(rs) == 0 || maxTokens <= 0 {
return ""
}
lo, hi := 0, len(rs)
for lo < hi {
mid := (lo + hi) / 2
if b.countTokens(string(rs[mid:])) <= maxTokens {
hi = mid
} else {
lo = mid + 1
}
}
return string(rs[lo:])
}
func truncateRunesWithNotice(s string, maxRunes int) string {
rs := []rune(s)
if len(rs) <= maxRunes {
return s
}
const notice = "\n...[工具输出已截断 / tool output truncated]...\n"
noticeRunes := []rune(notice)
keep := maxRunes - len(noticeRunes)
if keep < 200 {
keep = maxRunes * 2 / 3
}
if keep < 1 {
return notice
}
head := keep * 70 / 100
tail := keep - head
return string(rs[:head]) + notice + string(rs[len(rs)-tail:])
}
func runeLen(s string) int {
return len([]rune(s))
}
+63
View File
@@ -0,0 +1,63 @@
package attackchain
import (
"strings"
"testing"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
func testBuilder(maxTotal int) *Builder {
return &Builder{
logger: zap.NewNop(),
openAIConfig: &config.OpenAIConfig{Model: "gpt-4"},
tokenCounter: agent.NewTikTokenCounter(),
maxTokens: maxTotal,
}
}
func TestCompactFormattedToolBodies(t *testing.T) {
long := strings.Repeat("x", 20000)
in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n"
out := compactFormattedToolBodies(in, 500)
if strings.Contains(out, strings.Repeat("x", 10000)) {
t.Fatal("expected tool body to be truncated")
}
if !strings.Contains(out, "[user]: hi") {
t.Fatal("expected user header preserved")
}
if !strings.Contains(out, "[assistant]: done") {
t.Fatal("expected assistant header preserved")
}
}
func TestFitAttackChainPayloadWithinBudget(t *testing.T) {
b := testBuilder(32000)
react := strings.Repeat("scan ", 50000)
model := strings.Repeat("result ", 10000)
r, m, truncated := b.fitAttackChainPayload(react, model)
if !truncated {
t.Fatal("expected truncation for large payload")
}
prompt := b.buildSimplePrompt(r, m)
total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve
if total > b.maxTokens+attackChainSafetyReserve {
t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens)
}
_ = m
}
func TestAttackChainMaxCompletionTokens(t *testing.T) {
if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 {
// 120000/8 = 15000
if got < 4096 || got > 16384 {
t.Fatalf("unexpected completion cap: %d", got)
}
}
if got := attackChainMaxCompletionTokens(0); got != 8192 {
t.Fatalf("expected default 8192, got %d", got)
}
}
+55
View File
@@ -0,0 +1,55 @@
package audit
import (
"strings"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
)
// RegisterConversationCreateHook records platform audit rows for every new conversation.
func RegisterConversationCreateHook(s *Service) {
if s == nil {
return
}
database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) {
detail := map[string]interface{}{
"title": conv.Title,
"source": meta.Source,
}
if meta.WebShellConnectionID != "" {
detail["webshell_connection_id"] = meta.WebShellConnectionID
}
s.Record(nil, Entry{
Category: "conversation",
Action: "create",
Result: "success",
Message: "创建对话",
ResourceType: "conversation",
ResourceID: conv.ID,
Detail: detail,
ClientIP: meta.ClientIP,
SessionHint: meta.SessionHint,
})
})
}
// ConversationCreateMeta builds audit metadata for conversation creation.
func ConversationCreateMeta(source string) database.ConversationCreateMeta {
return database.ConversationCreateMeta{Source: strings.TrimSpace(source)}
}
// ConversationCreateMetaFromGin includes client IP and session hint when available.
func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta {
m := ConversationCreateMeta(source)
if c == nil {
return m
}
m.ClientIP = c.ClientIP()
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
m.SessionHint = sessionHint(token)
}
return m
}
+9
View File
@@ -0,0 +1,9 @@
package audit
// RetentionDays returns configured retention; 0 means keep forever.
func (s *Service) RetentionDays() int {
if s == nil || s.cfg == nil {
return 0
}
return s.cfg.Audit.RetentionDaysEffective()
}
+29
View File
@@ -0,0 +1,29 @@
package audit
import "github.com/gin-gonic/gin"
// RecordAction writes a platform audit row with common defaults.
func (s *Service) RecordAction(c *gin.Context, category, action, result, message, resourceType, resourceID string, detail map[string]interface{}) {
if s == nil {
return
}
s.Record(c, Entry{
Category: category,
Action: action,
Result: result,
Message: message,
ResourceType: resourceType,
ResourceID: resourceID,
Detail: detail,
})
}
// RecordOK is a shorthand for successful operations.
func (s *Service) RecordOK(c *gin.Context, category, action, message, resourceType, resourceID string, detail map[string]interface{}) {
s.RecordAction(c, category, action, "success", message, resourceType, resourceID, detail)
}
// RecordFail is a shorthand for failed operations.
func (s *Service) RecordFail(c *gin.Context, category, action, message string, detail map[string]interface{}) {
s.RecordAction(c, category, action, "failure", message, "", "", detail)
}
+86
View File
@@ -0,0 +1,86 @@
package audit
import (
"strings"
"cyberstrike-ai/internal/database"
)
var auditActionsResourceRemoved = map[string]bool{
"delete": true,
"item_delete": true,
"connection_delete": true,
"listener_delete": true,
"session_delete": true,
"task_delete": true,
"execution_delete": true,
"execution_delete_batch": true,
"delete_queue": true,
"delete_batch_task": true,
"markdown_delete": true,
}
// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked.
func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) {
if log == nil || strings.TrimSpace(log.ResourceID) == "" {
return
}
if auditActionsResourceRemoved[log.Action] {
f := false
log.ResourceAvailable = &f
return
}
if db == nil {
return
}
available, known := resourceStillExists(db, log.ResourceType, log.ResourceID)
if known {
log.ResourceAvailable = &available
}
}
func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) {
resourceID = strings.TrimSpace(resourceID)
if resourceID == "" {
return false, false
}
t := strings.TrimSpace(resourceType)
if t == "" {
if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") {
t = "conversation"
} else {
return false, false
}
}
switch t {
case "conversation":
ok, err := db.ConversationExists(resourceID)
return ok, err == nil
case "vulnerability":
_, err := db.GetVulnerability(resourceID)
if err != nil {
return false, strings.Contains(err.Error(), "不存在")
}
return true, true
case "batch_queue":
_, err := db.GetBatchQueue(resourceID)
return err == nil, true
case "c2_listener":
_, err := db.GetC2Listener(resourceID)
return err == nil, true
case "c2_session":
_, err := db.GetC2Session(resourceID)
return err == nil, true
case "c2_task":
_, err := db.GetC2Task(resourceID)
return err == nil, true
case "webshell_connection":
c, err := db.GetWebshellConnection(resourceID)
return err == nil && c != nil, true
case "tool_execution":
_, err := db.GetToolExecution(resourceID)
return err == nil, true
default:
return false, false
}
}
+27
View File
@@ -0,0 +1,27 @@
package audit
import (
"time"
"go.uber.org/zap"
)
// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once).
const auditRetentionPurgeInterval = time.Hour
// StartRetentionLoop periodically purges expired audit rows.
func StartRetentionLoop(s *Service, logger *zap.Logger) {
if s == nil {
return
}
go func() {
ticker := time.NewTicker(auditRetentionPurgeInterval)
defer ticker.Stop()
for range ticker.C {
s.PurgeExpired()
if logger != nil {
logger.Debug("audit retention tick completed")
}
}
}()
}
+58
View File
@@ -0,0 +1,58 @@
package audit
import (
"encoding/json"
"strings"
)
var sensitiveKeySubstrings = []string{
"password", "api_key", "apikey", "secret", "token", "authorization",
"credential", "private_key", "access_key",
}
// SanitizeDetail redacts sensitive keys and truncates serialized size.
func SanitizeDetail(detail map[string]interface{}, maxBytes int) map[string]interface{} {
if detail == nil {
return nil
}
if maxBytes <= 0 {
maxBytes = 8192
}
out := sanitizeValue("", detail)
if m, ok := out.(map[string]interface{}); ok {
b, _ := json.Marshal(m)
if len(b) > maxBytes {
return map[string]interface{}{
"_truncated": true,
"_preview": string(b[:maxBytes]),
}
}
return m
}
return map[string]interface{}{"value": out}
}
func sanitizeValue(key string, v interface{}) interface{} {
kl := strings.ToLower(key)
for _, sub := range sensitiveKeySubstrings {
if strings.Contains(kl, sub) {
return "***"
}
}
switch t := v.(type) {
case map[string]interface{}:
m := make(map[string]interface{}, len(t))
for k, val := range t {
m[k] = sanitizeValue(k, val)
}
return m
case []interface{}:
arr := make([]interface{}, len(t))
for i, val := range t {
arr[i] = sanitizeValue(key, val)
}
return arr
default:
return v
}
}
+172
View File
@@ -0,0 +1,172 @@
package audit
import (
"crypto/sha256"
"encoding/hex"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Service persists platform audit logs.
type Service struct {
db *database.DB
cfg *config.Config
logger *zap.Logger
failThrottle *failureThrottle
}
// NewService creates an audit service.
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
return &Service{
db: db,
cfg: cfg,
logger: logger,
failThrottle: newFailureThrottle(),
}
}
// Enabled reports whether audit persistence is on.
func (s *Service) Enabled() bool {
if s == nil || s.cfg == nil {
return false
}
return s.cfg.Audit.EnabledEffective()
}
// Record writes one audit row from a Gin request context.
func (s *Service) Record(c *gin.Context, e Entry) {
if s == nil || !s.Enabled() || s.db == nil {
return
}
if strings.TrimSpace(e.Category) == "" || strings.TrimSpace(e.Action) == "" {
return
}
if e.Result == "failure" && !s.allowFailureAudit(c, e) {
return
}
if strings.TrimSpace(e.Result) == "" {
e.Result = "success"
}
if strings.TrimSpace(e.Level) == "" {
if e.Result == "failure" {
e.Level = "warn"
} else {
e.Level = "info"
}
}
if strings.TrimSpace(e.Actor) == "" {
e.Actor = "admin"
}
maxDetail := s.cfg.Audit.MaxDetailBytesEffective()
detail := SanitizeDetail(e.Detail, maxDetail)
sessionHintVal := e.SessionHint
if sessionHintVal == "" && c != nil {
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
sessionHintVal = sessionHint(token)
}
}
clientIPVal := e.ClientIP
if clientIPVal == "" {
clientIPVal = clientIP(c)
}
row := &database.AuditLog{
ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""),
CreatedAt: time.Now(),
Level: e.Level,
Category: e.Category,
Action: e.Action,
Result: e.Result,
Actor: e.Actor,
SessionHint: sessionHintVal,
ClientIP: clientIPVal,
UserAgent: userAgent(c),
ResourceType: e.ResourceType,
ResourceID: e.ResourceID,
Message: e.Message,
Detail: detail,
}
if err := s.db.AppendAuditLog(row); err != nil && s.logger != nil {
s.logger.Warn("写入审计日志失败",
zap.String("action", e.Action),
zap.Error(err),
)
}
}
// RecordSystem writes an audit row without HTTP context (e.g. retention cleanup).
func (s *Service) RecordSystem(e Entry) {
s.Record(nil, e)
}
// PurgeExpired deletes rows older than retention_days when configured.
func (s *Service) PurgeExpired() {
if s == nil || s.db == nil || s.cfg == nil {
return
}
days := s.cfg.Audit.RetentionDaysEffective()
if days <= 0 {
return
}
cutoff := time.Now().AddDate(0, 0, -days)
n, err := s.db.DeleteAuditLogsBefore(cutoff)
if err != nil {
if s.logger != nil {
s.logger.Warn("清理过期审计日志失败", zap.Error(err))
}
return
}
if n > 0 && s.logger != nil {
s.logger.Info("已清理过期审计日志", zap.Int64("deleted", n))
}
}
// HintFromToken returns a short stable hash prefix for a session token.
func HintFromToken(token string) string {
return sessionHint(token)
}
func sessionHint(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:4])
}
func (s *Service) allowFailureAudit(c *gin.Context, e Entry) bool {
if !isAuthFailureThrottled(e.Category, e.Action) {
return true
}
cooldown := time.Duration(s.cfg.Audit.AuthFailureCooldownEffective()) * time.Second
key := authFailureThrottleKey(e.Category, e.Action, clientIP(c))
return s.failThrottle.allow(key, cooldown)
}
func clientIP(c *gin.Context) string {
if c == nil {
return ""
}
return c.ClientIP()
}
func userAgent(c *gin.Context) string {
if c == nil {
return ""
}
ua := c.GetHeader("User-Agent")
if len(ua) > 512 {
return ua[:512]
}
return ua
}
+55
View File
@@ -0,0 +1,55 @@
package audit
import (
"sync"
"time"
)
// failureThrottle deduplicates high-frequency failure audit rows (e.g. wrong password).
type failureThrottle struct {
mu sync.Mutex
last map[string]time.Time
}
func newFailureThrottle() *failureThrottle {
return &failureThrottle{last: make(map[string]time.Time)}
}
// allow reports whether a row with the given key may be written now.
func (t *failureThrottle) allow(key string, cooldown time.Duration) bool {
if t == nil || cooldown <= 0 || key == "" {
return true
}
now := time.Now()
t.mu.Lock()
defer t.mu.Unlock()
if prev, ok := t.last[key]; ok && now.Sub(prev) < cooldown {
return false
}
t.last[key] = now
if len(t.last) > 4096 {
for k, ts := range t.last {
if now.Sub(ts) > cooldown*2 {
delete(t.last, k)
}
}
}
return true
}
// authFailureThrottleKey builds a per-IP key for auth failure deduplication.
func authFailureThrottleKey(category, action, clientIP string) string {
return category + ":" + action + ":" + clientIP
}
func isAuthFailureThrottled(category, action string) bool {
if category != "auth" {
return false
}
switch action {
case "login", "change_password":
return true
default:
return false
}
}
+16
View File
@@ -0,0 +1,16 @@
package audit
// Entry describes one platform audit record (not chat/tool execution bodies).
type Entry struct {
Level string
Category string
Action string
Result string // success | failure
Actor string
SessionHint string
ResourceType string
ResourceID string
Message string
Detail map[string]interface{}
ClientIP string // optional when c is nil (robot, batch, DB hook)
}
+123 -11
View File
@@ -26,6 +26,7 @@ type Config struct {
Security SecurityConfig `yaml:"security"` Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"` Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用 C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
@@ -35,13 +36,39 @@ type Config struct {
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录 SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.mdYAML front matter AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.mdYAML front matter
MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"` MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"`
Project ProjectConfig `yaml:"project,omitempty" json:"project,omitempty"`
}
// ProjectConfig 项目黑板(跨对话共享事实)配置。
type ProjectConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
}
// FactIndexMaxRunesEffective 自动注入黑板索引的最大 rune 数。
func (c ProjectConfig) FactIndexMaxRunesEffective() int {
if c.FactIndexMaxRunes <= 0 {
return 3500
}
return c.FactIndexMaxRunes
}
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数。
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
if c.FactSummaryMaxRunes <= 0 {
return 120
}
return c.FactSummaryMaxRunes
} }
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。 // MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
type MultiAgentConfig struct { type MultiAgentConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理 RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // react | eino_single | deep | plan_execute | supervisor
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理 BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。 // Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"` Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor
@@ -227,6 +254,10 @@ type MultiAgentEinoMiddlewareConfig struct {
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"` DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries). // DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"` DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
// RunRetryMaxAttempts > 0429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended). // TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"` TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
} }
@@ -362,9 +393,9 @@ type MultiAgentSubConfig struct {
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。 // MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
type MultiAgentPublic struct { type MultiAgentPublic struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"` RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"` BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
SubAgentCount int `json:"sub_agent_count"` SubAgentCount int `json:"sub_agent_count"`
Orchestration string `json:"orchestration,omitempty"` Orchestration string `json:"orchestration,omitempty"`
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"` PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
@@ -372,6 +403,18 @@ type MultiAgentPublic struct {
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"` ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
} }
// NormalizeRobotAgentMode 解析机器人默认对话模式(react | eino_single | deep | plan_execute | supervisor);空值视为 react。
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
s := strings.TrimSpace(strings.ToLower(ma.RobotDefaultAgentMode))
if s == "" || s == "single" || s == "react" {
return "react"
}
if s == "eino_single" {
return "eino_single"
}
return NormalizeMultiAgentOrchestration(s)
}
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。 // NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
func NormalizeMultiAgentOrchestration(s string) string { func NormalizeMultiAgentOrchestration(s string) string {
v := strings.TrimSpace(strings.ToLower(s)) v := strings.TrimSpace(strings.ToLower(s))
@@ -387,22 +430,35 @@ func NormalizeMultiAgentOrchestration(s string) string {
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。 // MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
type MultiAgentAPIUpdate struct { type MultiAgentAPIUpdate struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"` RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"` BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"` PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。 // 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"` ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
} }
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等) // RobotsConfig 机器人配置(企业微信、钉钉、飞书、微信 iLink 等)
type RobotsConfig struct { type RobotsConfig struct {
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略 Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
Wechat RobotWechatConfig `yaml:"wechat,omitempty" json:"wechat,omitempty"` // 微信(iLink 扫码绑定)
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信 Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉 Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书 Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
} }
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
type RobotWechatConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
}
// RobotSessionConfig 机器人会话隔离策略 // RobotSessionConfig 机器人会话隔离策略
type RobotSessionConfig struct { type RobotSessionConfig struct {
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底 StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
@@ -484,7 +540,7 @@ type OpenAIConfig struct {
type OpenAIReasoningConfig struct { type OpenAIReasoningConfig struct {
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。 // Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning // Effort: low | medium | high | max | xhighmax/xhigh 为不同网关最高档命名,原样下发、不互转。空表示不单独指定强度
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"` Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
// AllowClientReasoning 为 false 时忽略请求体 reasoningnil 或未设置等同于 true。 // AllowClientReasoning 为 false 时忽略请求体 reasoningnil 或未设置等同于 true。
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"` AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
@@ -562,6 +618,51 @@ type AuthConfig struct {
GeneratedPasswordPersistErr string `yaml:"-" json:"-"` GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
} }
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
type AuditConfig struct {
// Enabled nil or true enables persistence; explicit false disables.
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
}
// EnabledEffective returns true unless audit.enabled is explicitly false.
func (a AuditConfig) EnabledEffective() bool {
if a.Enabled == nil {
return true
}
return *a.Enabled
}
// RetentionDaysEffective returns retention; 0 means keep forever.
func (a AuditConfig) RetentionDaysEffective() int {
if a.RetentionDays < 0 {
return 0
}
return a.RetentionDays
}
// MaxDetailBytesEffective caps serialized detail JSON size.
func (a AuditConfig) MaxDetailBytesEffective() int {
if a.MaxDetailBytes <= 0 {
return 8192
}
return a.MaxDetailBytes
}
// AuthFailureCooldownEffective returns seconds between duplicate auth-failure audit rows per IP (default 60; -1 disables).
func (a AuditConfig) AuthFailureCooldownEffective() int {
if a.AuthFailureCooldownSeconds < 0 {
return 0
}
if a.AuthFailureCooldownSeconds == 0 {
return 60
}
return a.AuthFailureCooldownSeconds
}
// ExternalMCPConfig 外部MCP配置 // ExternalMCPConfig 外部MCP配置
type ExternalMCPConfig struct { type ExternalMCPConfig struct {
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"` Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
@@ -654,6 +755,9 @@ func Load(path string) (*Config, error) {
if cfg.Auth.SessionDurationHours <= 0 { if cfg.Auth.SessionDurationHours <= 0 {
cfg.Auth.SessionDurationHours = 12 cfg.Auth.SessionDurationHours = 12
} }
if cfg.Audit.MaxDetailBytes <= 0 {
cfg.Audit.MaxDetailBytes = 8192
}
if strings.TrimSpace(cfg.Auth.Password) == "" { if strings.TrimSpace(cfg.Auth.Password) == "" {
password, err := generateStrongPassword(24) password, err := generateStrongPassword(24)
if err != nil { if err != nil {
@@ -1157,6 +1261,14 @@ func Default() *Config {
Auth: AuthConfig{ Auth: AuthConfig{
SessionDurationHours: 12, SessionDurationHours: 12,
}, },
Audit: func() AuditConfig {
on := true
return AuditConfig{
RetentionDays: 90,
MaxDetailBytes: 8192,
Enabled: &on,
}
}(),
Robots: RobotsConfig{ Robots: RobotsConfig{
Session: RobotSessionConfig{ Session: RobotSessionConfig{
StrictUserIdentity: &strictRobotIdentity, StrictUserIdentity: &strictRobotIdentity,
+210
View File
@@ -0,0 +1,210 @@
package database
import (
"encoding/json"
"errors"
"strings"
"time"
)
// AuditLog platform operation audit record.
type AuditLog struct {
ID string `json:"id"`
CreatedAt time.Time `json:"createdAt"`
Level string `json:"level"`
Category string `json:"category"`
Action string `json:"action"`
Result string `json:"result"`
Actor string `json:"actor"`
SessionHint string `json:"sessionHint,omitempty"`
ClientIP string `json:"clientIp,omitempty"`
UserAgent string `json:"userAgent,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
ResourceID string `json:"resourceId,omitempty"`
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
Message string `json:"message"`
Detail map[string]interface{} `json:"detail,omitempty"`
}
// ListAuditLogsFilter query parameters.
type ListAuditLogsFilter struct {
Level string
Category string
Action string
Result string
Query string
ResourceType string
ResourceID string
Since *time.Time
Until *time.Time
Limit int
Offset int
}
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
conditions := []string{"1=1"}
args := []interface{}{}
if filter.Level != "" {
conditions = append(conditions, "level = ?")
args = append(args, filter.Level)
}
if filter.Category != "" {
conditions = append(conditions, "category = ?")
args = append(args, filter.Category)
}
if filter.Action != "" {
conditions = append(conditions, "action = ?")
args = append(args, filter.Action)
}
if filter.Result != "" {
conditions = append(conditions, "result = ?")
args = append(args, filter.Result)
}
if filter.ResourceType != "" {
conditions = append(conditions, "resource_type = ?")
args = append(args, filter.ResourceType)
}
if filter.ResourceID != "" {
conditions = append(conditions, "resource_id = ?")
args = append(args, filter.ResourceID)
}
if filter.Since != nil {
conditions = append(conditions, "created_at >= ?")
args = append(args, *filter.Since)
}
if filter.Until != nil {
conditions = append(conditions, "created_at <= ?")
args = append(args, *filter.Until)
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
args = append(args, like, like, like, like)
}
return strings.Join(conditions, " AND "), args
}
// AppendAuditLog inserts one audit row.
func (db *DB) AppendAuditLog(row *AuditLog) error {
if row == nil {
return errors.New("audit log is nil")
}
if strings.TrimSpace(row.ID) == "" {
return errors.New("audit id is required")
}
if row.CreatedAt.IsZero() {
row.CreatedAt = time.Now()
}
if strings.TrimSpace(row.Level) == "" {
row.Level = "info"
}
detailJSON := ""
if len(row.Detail) > 0 {
if b, err := json.Marshal(row.Detail); err == nil {
detailJSON = string(b)
}
}
query := `
INSERT INTO audit_logs (
id, created_at, level, category, action, result, actor, session_hint,
client_ip, user_agent, resource_type, resource_id, message, detail_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
row.ResourceType, row.ResourceID, row.Message, detailJSON,
)
return err
}
// GetAuditLogByID returns one row.
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
id = strings.TrimSpace(id)
if id == "" {
return nil, errors.New("id is required")
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs WHERE id = ?
`
var row AuditLog
var detailJSON string
err := db.QueryRow(query, id).Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
)
if err != nil {
return nil, err
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
return &row, nil
}
// CountAuditLogs counts rows matching filter.
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
where, args := buildAuditLogsWhere(filter)
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
var n int64
err := db.QueryRow(query, args...).Scan(&n)
return n, err
}
// ListAuditLogs lists audit rows newest first.
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
where, args := buildAuditLogsWhere(filter)
limit := filter.Limit
if limit <= 0 || limit > 500 {
limit = 50
}
offset := filter.Offset
if offset < 0 {
offset = 0
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs
WHERE ` + where + `
ORDER BY created_at DESC
LIMIT ? OFFSET ?
`
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*AuditLog
for rows.Next() {
var row AuditLog
var detailJSON string
if err := rows.Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
); err != nil {
continue
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
list = append(list, &row)
}
return list, rows.Err()
}
// DeleteAuditLogsBefore removes rows older than cutoff.
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
+14 -8
View File
@@ -22,6 +22,7 @@ type BatchTaskQueueRow struct {
LastScheduleTriggerAt sql.NullTime LastScheduleTriggerAt sql.NullTime
LastScheduleError sql.NullString LastScheduleError sql.NullString
LastRunError sql.NullString LastRunError sql.NullString
ProjectID sql.NullString
Status string Status string
CreatedAt time.Time CreatedAt time.Time
StartedAt sql.NullTime StartedAt sql.NullTime
@@ -51,6 +52,7 @@ func (db *DB) CreateBatchQueue(
scheduleMode string, scheduleMode string,
cronExpr string, cronExpr string,
nextRunAt *time.Time, nextRunAt *time.Time,
projectID string,
tasks []map[string]interface{}, tasks []map[string]interface{},
) error { ) error {
tx, err := db.Begin() tx, err := db.Begin()
@@ -65,9 +67,13 @@ func (db *DB) CreateBatchQueue(
nextRunAtValue = *nextRunAt nextRunAtValue = *nextRunAt
} }
var projectIDVal interface{}
if strings.TrimSpace(projectID) != "" {
projectIDVal = strings.TrimSpace(projectID)
}
_, err = tx.Exec( _, err = tx.Exec(
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0, queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
) )
if err != nil { if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err) return fmt.Errorf("创建批量任务队列失败: %w", err)
@@ -101,9 +107,9 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
err := db.QueryRow( err := db.QueryRow(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
queueID, queueID,
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@@ -127,7 +133,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
// GetAllBatchQueues 获取所有批量任务队列 // GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
@@ -138,7 +144,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -158,7 +164,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
// ListBatchQueues 列出批量任务队列(支持筛选和分页) // ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
args := []interface{}{} args := []interface{}{}
// 状态筛选 // 状态筛选
@@ -186,7 +192,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
+70 -15
View File
@@ -17,6 +17,7 @@ import (
type Conversation struct { type Conversation struct {
ID string `json:"id"` ID string `json:"id"`
Title string `json:"title"` Title string `json:"title"`
ProjectID string `json:"projectId,omitempty"`
Pinned bool `json:"pinned"` Pinned bool `json:"pinned"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"` UpdatedAt time.Time `json:"updatedAt"`
@@ -37,22 +38,41 @@ type Message struct {
} }
// CreateConversation 创建新对话 // CreateConversation 创建新对话
func (db *DB) CreateConversation(title string) (*Conversation, error) { func (db *DB) CreateConversation(title string, meta ConversationCreateMeta) (*Conversation, error) {
return db.CreateConversationWithWebshell("", title) return db.CreateConversationWithWebshell("", title, meta)
} }
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话) // CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) { func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, meta ConversationCreateMeta) (*Conversation, error) {
id := uuid.New().String() id := uuid.New().String()
now := time.Now() now := time.Now()
projectID := strings.TrimSpace(meta.ProjectID)
if projectID != "" {
if _, err := db.GetProject(projectID); err != nil {
return nil, err
}
}
var err error var err error
if webshellConnectionID != "" { wsID := strings.TrimSpace(webshellConnectionID)
switch {
case wsID != "" && projectID != "":
_, err = db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id, project_id) VALUES (?, ?, ?, ?, ?, ?)",
id, title, now, now, wsID, projectID,
)
case wsID != "":
_, err = db.Exec( _, err = db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)", "INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)",
id, title, now, now, webshellConnectionID, id, title, now, now, wsID,
) )
} else { case projectID != "":
_, err = db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at, project_id) VALUES (?, ?, ?, ?, ?)",
id, title, now, now, projectID,
)
default:
_, err = db.Exec( _, err = db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)", "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
id, title, now, now, id, title, now, now,
@@ -62,12 +82,18 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string)
return nil, fmt.Errorf("创建对话失败: %w", err) return nil, fmt.Errorf("创建对话失败: %w", err)
} }
return &Conversation{ conv := &Conversation{
ID: id, ID: id,
Title: title, Title: title,
ProjectID: projectID,
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
}, nil }
if wsID != "" {
meta.WebShellConnectionID = wsID
}
notifyConversationCreated(conv, meta)
return conv, nil
} }
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化) // GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
@@ -182,22 +208,43 @@ func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]We
return list, rows.Err() return list, rows.Err()
} }
// ConversationExists reports whether a conversation row exists (lightweight check for audit links).
func (db *DB) ConversationExists(id string) (bool, error) {
id = strings.TrimSpace(id)
if id == "" {
return false, nil
}
var one int
err := db.QueryRow("SELECT 1 FROM conversations WHERE id = ? LIMIT 1", id).Scan(&one)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
// GetConversation 获取对话 // GetConversation 获取对话
func (db *DB) GetConversation(id string) (*Conversation, error) { func (db *DB) GetConversation(id string) (*Conversation, error) {
var conv Conversation var conv Conversation
var createdAt, updatedAt string var createdAt, updatedAt string
var pinned int var pinned int
var projectID sql.NullString
err := db.QueryRow( err := db.QueryRow(
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", "SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
id, id,
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, fmt.Errorf("对话不存在") return nil, fmt.Errorf("对话不存在")
} }
return nil, fmt.Errorf("查询对话失败: %w", err) return nil, fmt.Errorf("查询对话失败: %w", err)
} }
if projectID.Valid {
conv.ProjectID = strings.TrimSpace(projectID.String)
}
// 尝试多种时间格式解析 // 尝试多种时间格式解析
var err1, err2 error var err1, err2 error
@@ -270,16 +317,20 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
var createdAt, updatedAt string var createdAt, updatedAt string
var pinned int var pinned int
var projectID sql.NullString
err := db.QueryRow( err := db.QueryRow(
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", "SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?",
id, id,
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, fmt.Errorf("对话不存在") return nil, fmt.Errorf("对话不存在")
} }
return nil, fmt.Errorf("查询对话失败: %w", err) return nil, fmt.Errorf("查询对话失败: %w", err)
} }
if projectID.Valid {
conv.ProjectID = strings.TrimSpace(projectID.String)
}
// 尝试多种时间格式解析 // 尝试多种时间格式解析
var err1, err2 error var err1, err2 error
@@ -319,7 +370,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
searchPattern := "%" + search + "%" searchPattern := "%" + search + "%"
rows, err = db.Query( rows, err = db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
FROM conversations c FROM conversations c
WHERE c.title LIKE ? WHERE c.title LIKE ?
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?) OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
@@ -329,7 +380,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
) )
} else { } else {
rows, err = db.Query( rows, err = db.Query(
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
limit, offset, limit, offset,
) )
} }
@@ -344,10 +395,14 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
var conv Conversation var conv Conversation
var createdAt, updatedAt string var createdAt, updatedAt string
var pinned int var pinned int
var projectID sql.NullString
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil { if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err) return nil, fmt.Errorf("扫描对话失败: %w", err)
} }
if projectID.Valid {
conv.ProjectID = strings.TrimSpace(projectID.String)
}
// 尝试多种时间格式解析 // 尝试多种时间格式解析
var err1, err2 error var err1, err2 error
@@ -0,0 +1,30 @@
package database
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
type ConversationCreateMeta struct {
Source string
WebShellConnectionID string
ProjectID string
ClientIP string
SessionHint string
}
// ConversationCreateHook is invoked after a conversation row is inserted.
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
var conversationCreateHook ConversationCreateHook
// SetConversationCreateHook registers a global hook (e.g. platform audit).
func SetConversationCreateHook(h ConversationCreateHook) {
conversationCreateHook = h
}
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
if conversationCreateHook == nil || conv == nil {
return
}
if meta.Source == "" {
meta.Source = "unknown"
}
conversationCreateHook(conv, meta)
}
+124
View File
@@ -213,6 +213,40 @@ func (db *DB) initTables() error {
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);` );`
// 创建项目表
createProjectsTable := `
CREATE TABLE IF NOT EXISTS projects (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT,
scope_json TEXT,
status TEXT NOT NULL DEFAULT 'active',
pinned INTEGER NOT NULL DEFAULT 0,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
);`
// 创建项目事实表(黑板)
createProjectFactsTable := `
CREATE TABLE IF NOT EXISTS project_facts (
id TEXT PRIMARY KEY,
project_id TEXT NOT NULL,
fact_key TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'note',
summary TEXT NOT NULL DEFAULT '',
body TEXT,
confidence TEXT NOT NULL DEFAULT 'tentative',
source_conversation_id TEXT,
source_message_id TEXT,
pinned INTEGER NOT NULL DEFAULT 0,
supersedes_fact_id TEXT,
related_vulnerability_id TEXT,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
UNIQUE(project_id, fact_key)
);`
// 创建漏洞表 // 创建漏洞表
createVulnerabilitiesTable := ` createVulnerabilitiesTable := `
CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE TABLE IF NOT EXISTS vulnerabilities (
@@ -387,6 +421,24 @@ func (db *DB) initTables() error {
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);` );`
createAuditLogsTable := `
CREATE TABLE IF NOT EXISTS audit_logs (
id TEXT PRIMARY KEY,
created_at DATETIME NOT NULL,
level TEXT NOT NULL DEFAULT 'info',
category TEXT NOT NULL,
action TEXT NOT NULL,
result TEXT NOT NULL,
actor TEXT NOT NULL DEFAULT 'admin',
session_hint TEXT,
client_ip TEXT,
user_agent TEXT,
resource_type TEXT,
resource_id TEXT,
message TEXT NOT NULL,
detail_json TEXT
);`
createC2ProfilesTable := ` createC2ProfilesTable := `
CREATE TABLE IF NOT EXISTS c2_profiles ( CREATE TABLE IF NOT EXISTS c2_profiles (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
@@ -427,6 +479,12 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status);
CREATE INDEX IF NOT EXISTS idx_projects_updated_at ON projects(updated_at);
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id); CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at); CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title); CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
@@ -445,6 +503,10 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at); CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category); CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id); CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
` `
if _, err := db.Exec(createConversationsTable); err != nil { if _, err := db.Exec(createConversationsTable); err != nil {
@@ -494,6 +556,14 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建robot_user_sessions表失败: %w", err) return fmt.Errorf("创建robot_user_sessions表失败: %w", err)
} }
if _, err := db.Exec(createProjectsTable); err != nil {
return fmt.Errorf("创建projects表失败: %w", err)
}
if _, err := db.Exec(createProjectFactsTable); err != nil {
return fmt.Errorf("创建project_facts表失败: %w", err)
}
if _, err := db.Exec(createVulnerabilitiesTable); err != nil { if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
return fmt.Errorf("创建vulnerabilities表失败: %w", err) return fmt.Errorf("创建vulnerabilities表失败: %w", err)
} }
@@ -514,6 +584,10 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建webshell_connection_states表失败: %w", err) return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
} }
if _, err := db.Exec(createAuditLogsTable); err != nil {
return fmt.Errorf("创建audit_logs表失败: %w", err)
}
for tableName, ddl := range map[string]string{ for tableName, ddl := range map[string]string{
"c2_listeners": createC2ListenersTable, "c2_listeners": createC2ListenersTable,
"c2_sessions": createC2SessionsTable, "c2_sessions": createC2SessionsTable,
@@ -557,6 +631,10 @@ func (db *DB) initTables() error {
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
} }
if err := db.migrateProjectsTable(); err != nil {
db.logger.Warn("迁移projects相关表失败", zap.Error(err))
}
if err := db.migrateWebshellConnectionsTable(); err != nil { if err := db.migrateWebshellConnectionsTable(); err != nil {
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err)) db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
@@ -904,6 +982,51 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
} }
} }
var projectIDCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='project_id'").Scan(&projectIDCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(addErr))
}
}
} else if projectIDCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); err != nil {
db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(err))
}
}
return nil
}
// migrateProjectsTable 迁移 projects / conversations / vulnerabilities 的项目关联字段。
func (db *DB) migrateProjectsTable() error {
for _, col := range []struct {
table string
name string
stmt string
}{
{"conversations", "project_id", "ALTER TABLE conversations ADD COLUMN project_id TEXT REFERENCES projects(id) ON DELETE SET NULL"},
{"vulnerabilities", "project_id", "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
} {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info(?) WHERE name=?", col.table, col.name).Scan(&count)
if err != nil {
if _, addErr := db.Exec(col.stmt); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
}
}
continue
}
if count == 0 {
if _, addErr := db.Exec(col.stmt); addErr != nil {
db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr))
}
}
}
return nil return nil
} }
@@ -915,6 +1038,7 @@ func (db *DB) migrateVulnerabilitiesTable() error {
}{ }{
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"}, {name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"}, {name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
{name: "project_id", stmt: "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"},
} }
for _, col := range columns { for _, col := range columns {
+451
View File
@@ -0,0 +1,451 @@
package database
import (
"database/sql"
"fmt"
"regexp"
"strings"
"time"
"github.com/google/uuid"
)
var factKeyPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._/-]*$`)
// ValidateFactKey 校验事实 key(项目内唯一标识)。
func ValidateFactKey(key string) error {
key = strings.TrimSpace(key)
if key == "" {
return fmt.Errorf("fact_key 不能为空")
}
if len(key) > 128 {
return fmt.Errorf("fact_key 过长(最多 128 字符)")
}
if !factKeyPattern.MatchString(key) {
return fmt.Errorf("fact_key 格式无效,仅允许小写字母、数字及 . _ / -,且须以小写字母或数字开头")
}
return nil
}
// Project 渗透测试项目(跨对话共享黑板)。
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
ScopeJSON string `json:"scope_json,omitempty"`
Status string `json:"status"` // active | archived
Pinned bool `json:"pinned"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ProjectFact 项目事实(黑板条目)。
type ProjectFact struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
FactKey string `json:"fact_key"`
Category string `json:"category"`
Summary string `json:"summary"`
Body string `json:"body"`
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
SourceConversationID string `json:"source_conversation_id,omitempty"`
SourceMessageID string `json:"source_message_id,omitempty"`
Pinned bool `json:"pinned"`
SupersedesFactID string `json:"supersedes_fact_id,omitempty"`
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ProjectFactListFilter 事实列表筛选。
type ProjectFactListFilter struct {
Category string
Confidence string
Search string
}
// CreateProject 创建项目。
func (db *DB) CreateProject(p *Project) (*Project, error) {
if p.ID == "" {
p.ID = uuid.New().String()
}
if strings.TrimSpace(p.Status) == "" {
p.Status = "active"
}
now := time.Now()
if p.CreatedAt.IsZero() {
p.CreatedAt = now
}
p.UpdatedAt = now
_, err := db.Exec(
`INSERT INTO projects (id, name, description, scope_json, status, pinned, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
p.ID, p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.CreatedAt, p.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("创建项目失败: %w", err)
}
return p, nil
}
// GetProject 获取项目。
func (db *DB) GetProject(id string) (*Project, error) {
var p Project
var pinned int
var createdAt, updatedAt string
err := db.QueryRow(
`SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
FROM projects WHERE id = ?`, id,
).Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("项目不存在")
}
return nil, fmt.Errorf("获取项目失败: %w", err)
}
p.Pinned = pinned != 0
p.CreatedAt = parseDBTime(createdAt)
p.UpdatedAt = parseDBTime(updatedAt)
return &p, nil
}
// ListProjects 列出项目。
func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) {
if limit <= 0 {
limit = 200
}
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
FROM projects WHERE 1=1`
args := []interface{}{}
if s := strings.TrimSpace(status); s != "" {
query += " AND status = ?"
args = append(args, s)
}
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("列出项目失败: %w", err)
}
defer rows.Close()
var out []*Project
for rows.Next() {
var p Project
var pinned int
var createdAt, updatedAt string
if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt); err != nil {
return nil, err
}
p.Pinned = pinned != 0
p.CreatedAt = parseDBTime(createdAt)
p.UpdatedAt = parseDBTime(updatedAt)
out = append(out, &p)
}
return out, rows.Err()
}
// UpdateProject 更新项目。
func (db *DB) UpdateProject(p *Project) error {
p.UpdatedAt = time.Now()
_, err := db.Exec(
`UPDATE projects SET name = ?, description = ?, scope_json = ?, status = ?, pinned = ?, updated_at = ? WHERE id = ?`,
p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.UpdatedAt, p.ID,
)
if err != nil {
return fmt.Errorf("更新项目失败: %w", err)
}
return nil
}
// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理)。
func (db *DB) DeleteProject(id string) error {
_, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id)
if err != nil {
return fmt.Errorf("删除项目失败: %w", err)
}
return nil
}
// GetConversationProjectID 返回对话绑定的项目 ID。
func (db *DB) GetConversationProjectID(conversationID string) (string, error) {
var pid sql.NullString
err := db.QueryRow(`SELECT project_id FROM conversations WHERE id = ?`, conversationID).Scan(&pid)
if err != nil {
if err == sql.ErrNoRows {
return "", fmt.Errorf("对话不存在")
}
return "", err
}
if pid.Valid {
return strings.TrimSpace(pid.String), nil
}
return "", nil
}
// SetConversationProjectID 设置对话所属项目(空字符串表示解除绑定)。
func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
projectID = strings.TrimSpace(projectID)
if projectID != "" {
if _, err := db.GetProject(projectID); err != nil {
return err
}
}
var val interface{}
if projectID == "" {
val = nil
} else {
val = projectID
}
_, err := db.Exec(`UPDATE conversations SET project_id = ?, updated_at = ? WHERE id = ?`, val, time.Now(), conversationID)
if err != nil {
return fmt.Errorf("设置对话项目失败: %w", err)
}
return nil
}
// ListProjectFactsForIndex 列出用于黑板索引注入的事实(不含 deprecated,除非 includeDeprecated)。
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ?`
args := []interface{}{projectID}
if !includeDeprecated {
query += " AND confidence != 'deprecated'"
}
query += " ORDER BY pinned DESC, updated_at DESC"
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanProjectFacts(rows)
}
// ListProjectFacts 分页列出项目事实。
func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, limit, offset int) ([]*ProjectFact, error) {
if limit <= 0 {
limit = 100
}
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ?`
args := []interface{}{projectID}
if c := strings.TrimSpace(filter.Category); c != "" {
query += " AND category = ?"
args = append(args, c)
}
if c := strings.TrimSpace(filter.Confidence); c != "" {
query += " AND confidence = ?"
args = append(args, c)
}
if s := strings.TrimSpace(filter.Search); s != "" {
pat := "%" + s + "%"
query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)"
args = append(args, pat, pat, pat)
}
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanProjectFacts(rows)
}
// GetProjectFactByKey 按 key 获取事实。
func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, error) {
row := db.QueryRow(
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
projectID, factKey,
)
return scanProjectFactRow(row)
}
// GetProjectFact 按 ID 获取事实。
func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
row := db.QueryRow(
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE id = ?`, id,
)
return scanProjectFactRow(row)
}
// UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。
func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
if err := ValidateFactKey(f.FactKey); err != nil {
return nil, err
}
if strings.TrimSpace(f.Category) == "" {
f.Category = "note"
}
if strings.TrimSpace(f.Confidence) == "" {
f.Confidence = "tentative"
}
now := time.Now()
existing, err := db.GetProjectFactByKey(f.ProjectID, f.FactKey)
if err == nil && existing != nil {
f.ID = existing.ID
f.CreatedAt = existing.CreatedAt
f.UpdatedAt = now
_, err = db.Exec(
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
source_conversation_id = ?, source_message_id = ?, pinned = ?,
supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ?
WHERE id = ?`,
f.Category, f.Summary, f.Body, f.Confidence,
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
)
if err != nil {
return nil, fmt.Errorf("更新事实失败: %w", err)
}
return f, nil
}
if f.ID == "" {
f.ID = uuid.New().String()
}
f.CreatedAt = now
f.UpdatedAt = now
_, err = db.Exec(
`INSERT INTO project_facts (
id, project_id, fact_key, category, summary, body, confidence,
source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID),
f.CreatedAt, f.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("创建事实失败: %w", err)
}
return f, nil
}
// DeprecateProjectFact 将事实标记为 deprecated。
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
res, err := db.Exec(
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
time.Now(), projectID, factKey,
)
if err != nil {
return err
}
n, _ := res.RowsAffected()
if n == 0 {
return fmt.Errorf("事实不存在")
}
return nil
}
// DeleteProjectFact 删除事实。
func (db *DB) DeleteProjectFact(id string) error {
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
return err
}
func scanProjectFacts(rows *sql.Rows) ([]*ProjectFact, error) {
var out []*ProjectFact
for rows.Next() {
f, err := scanProjectFactFromRows(rows)
if err != nil {
return nil, err
}
out = append(out, f)
}
return out, rows.Err()
}
func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
var f ProjectFact
var pinned int
var createdAt, updatedAt string
err := row.Scan(
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
&f.SourceConversationID, &f.SourceMessageID, &pinned,
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("事实不存在")
}
return nil, err
}
f.Pinned = pinned != 0
f.CreatedAt = parseDBTime(createdAt)
f.UpdatedAt = parseDBTime(updatedAt)
return &f, nil
}
func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
var f ProjectFact
var pinned int
var createdAt, updatedAt string
err := rows.Scan(
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
&f.SourceConversationID, &f.SourceMessageID, &pinned,
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
f.Pinned = pinned != 0
f.CreatedAt = parseDBTime(createdAt)
f.UpdatedAt = parseDBTime(updatedAt)
return &f, nil
}
func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
func nullIfEmpty(s string) interface{} {
if strings.TrimSpace(s) == "" {
return nil
}
return s
}
func parseDBTime(s string) time.Time {
s = strings.TrimSpace(s)
if s == "" {
return time.Time{}
}
// go-sqlite3 读 DATETIME 常返回 RFC3339(含 T),写入时可能是空格分隔格式,需兼容多种形态
layouts := []string{
time.RFC3339Nano,
time.RFC3339,
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05-07:00",
"2006-01-02T15:04:05.999999999-07:00",
"2006-01-02T15:04:05-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05.999999999",
"2006-01-02T15:04:05",
}
for _, layout := range layouts {
if t, e := time.Parse(layout, s); e == nil {
return t
}
}
return time.Time{}
}
+93
View File
@@ -0,0 +1,93 @@
package database
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"go.uber.org/zap"
)
func TestParseDBTime_projectFactFormats(t *testing.T) {
cases := []string{
"2026-05-26 11:13:07.442143+08:00",
"2026-05-26 11:13:07",
"2026-05-26T11:13:07.442143+08:00",
}
for _, s := range cases {
got := parseDBTime(s)
if got.IsZero() {
t.Fatalf("parseDBTime(%q) returned zero", s)
}
}
}
func TestListProjectFacts_updatedAtJSON(t *testing.T) {
root, err := os.Getwd()
if err != nil {
t.Skip(err)
}
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
if _, err := os.Stat(dbPath); err != nil {
t.Skip("conversations.db not found")
}
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
projects, err := db.ListProjects("", 1, 0)
if err != nil || len(projects) == 0 {
t.Skip("no projects")
}
pid := projects[0].ID
list, err := db.ListProjectFacts(pid, ProjectFactListFilter{}, 5, 0)
if err != nil {
t.Fatal(err)
}
if len(list) == 0 {
t.Skip("no facts")
}
for _, f := range list {
if f.UpdatedAt.IsZero() {
t.Fatalf("fact %s UpdatedAt is zero after ListProjectFacts", f.FactKey)
}
b, err := json.Marshal(f)
if err != nil {
t.Fatal(err)
}
var m map[string]interface{}
if err := json.Unmarshal(b, &m); err != nil {
t.Fatal(err)
}
raw, ok := m["updated_at"].(string)
if !ok || raw == "" || raw[:4] == "0001" {
t.Fatalf("bad updated_at in JSON: %v", m["updated_at"])
}
}
}
func TestParseDBTime_zeroOnGarbage(t *testing.T) {
if !parseDBTime("").IsZero() {
t.Fatal("expected zero for empty")
}
}
// Ensure RFC3339 round-trip used by API is after year 2000.
func TestParseDBTime_marshalRoundTrip(t *testing.T) {
s := "2026-05-26 11:13:07.442143+08:00"
tm := parseDBTime(s)
b, err := json.Marshal(tm)
if err != nil {
t.Fatal(err)
}
var back time.Time
if err := json.Unmarshal(b, &back); err != nil {
t.Fatal(err)
}
if back.IsZero() {
t.Fatalf("unmarshal zero from %s", string(b))
}
}
+104 -78
View File
@@ -3,16 +3,94 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
) )
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
type VulnerabilityListFilter struct {
ID string
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
ConversationID string
ProjectID string
Severity string
Status string
TaskID string
ConversationTag string
TaskTag string
}
func escapeVulnerabilityLikePattern(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return "%" + s + "%"
}
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
if f.ID != "" {
query += " AND id = ?"
args = append(args, f.ID)
}
if f.ConversationID != "" {
query += " AND conversation_id = ?"
args = append(args, f.ConversationID)
}
if f.ProjectID != "" {
query += " AND project_id = ?"
args = append(args, f.ProjectID)
}
if f.TaskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, f.TaskID, f.TaskID)
}
if f.ConversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, f.ConversationTag)
}
if f.TaskTag != "" {
query += " AND task_tag = ?"
args = append(args, f.TaskTag)
}
if f.Severity != "" {
query += " AND severity = ?"
args = append(args, f.Severity)
}
if f.Status != "" {
query += " AND status = ?"
args = append(args, f.Status)
}
search := strings.TrimSpace(f.Search)
if search != "" {
pattern := escapeVulnerabilityLikePattern(search)
query += ` AND (
LOWER(id) LIKE LOWER(?) OR
LOWER(title) LIKE LOWER(?) OR
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
)`
for i := 0; i < 11; i++ {
args = append(args, pattern)
}
}
return query, args
}
// Vulnerability 漏洞 // Vulnerability 漏洞
type Vulnerability struct { type Vulnerability struct {
ID string `json:"id"` ID string `json:"id"`
ConversationID string `json:"conversation_id"` ConversationID string `json:"conversation_id"`
ProjectID string `json:"project_id,omitempty"`
ConversationTag string `json:"conversation_tag,omitempty"` ConversationTag string `json:"conversation_tag,omitempty"`
TaskTag string `json:"task_tag,omitempty"` TaskTag string `json:"task_tag,omitempty"`
TaskID string `json:"task_id,omitempty"` TaskID string `json:"task_id,omitempty"`
@@ -44,17 +122,23 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
} }
vuln.UpdatedAt = now vuln.UpdatedAt = now
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID != "" {
if pid, err := db.GetConversationProjectID(vuln.ConversationID); err == nil {
vuln.ProjectID = pid
}
}
query := ` query := `
INSERT INTO vulnerabilities ( INSERT INTO vulnerabilities (
id, conversation_id, conversation_tag, task_tag, title, description, severity, status, id, conversation_id, project_id, conversation_tag, task_tag, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt, vuln.CreatedAt, vuln.UpdatedAt,
@@ -70,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability var vuln Vulnerability
query := ` query := `
SELECT id, conversation_id, title, description, severity, status, SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status,
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
@@ -80,7 +164,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
` `
err := db.QueryRow(query, id).Scan( err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID, &vuln.TaskID, &vuln.TaskQueueID,
@@ -97,9 +181,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
} }
// ListVulnerabilities 列出漏洞 // ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) { func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
query := ` query := `
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag, SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
@@ -108,35 +192,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
WHERE 1=1 WHERE 1=1
` `
args := []interface{}{} args := []interface{}{}
query, args = filter.appendWhere(query, args)
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset) args = append(args, limit, offset)
@@ -151,7 +207,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
for rows.Next() { for rows.Next() {
var vuln Vulnerability var vuln Vulnerability
err := rows.Scan( err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID, &vuln.TaskID, &vuln.TaskQueueID,
@@ -168,38 +224,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
} }
// CountVulnerabilities 统计漏洞总数(支持筛选条件) // CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) { func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{} args := []interface{}{}
query, args = filter.appendWhere(query, args)
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
var count int var count int
err := db.QueryRow(query, args...).Scan(&count) err := db.QueryRow(query, args...).Scan(&count)
@@ -216,7 +244,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
query := ` query := `
UPDATE vulnerabilities UPDATE vulnerabilities
SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?, SET project_id = ?, conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?, vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ? recommendation = ?, updated_at = ?
WHERE id = ? WHERE id = ?
@@ -224,7 +252,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id, vuln.Recommendation, vuln.UpdatedAt, id,
) )
@@ -245,19 +273,12 @@ func (db *DB) DeleteVulnerability(id string) error {
} }
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致) // GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) { func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
stats := make(map[string]interface{}) stats := make(map[string]interface{})
where := "WHERE 1=1" where := "WHERE 1=1"
args := []interface{}{} args := []interface{}{}
if conversationID != "" { where, args = filter.appendWhere(where, args)
where += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
// 总漏洞数 // 总漏洞数
var totalCount int var totalCount int
@@ -357,10 +378,15 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("查询任务标签建议失败: %w", err) return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
} }
projectIDs, err := collect(`SELECT DISTINCT project_id FROM vulnerabilities WHERE project_id IS NOT NULL AND project_id <> '' ORDER BY created_at DESC LIMIT 200`)
if err != nil {
return nil, fmt.Errorf("查询项目ID建议失败: %w", err)
}
return map[string][]string{ return map[string][]string{
"vulnerability_ids": vulnIDs, "vulnerability_ids": vulnIDs,
"conversation_ids": conversationIDs, "conversation_ids": conversationIDs,
"project_ids": projectIDs,
"task_ids": taskIDs, "task_ids": taskIDs,
"queue_ids": queueIDs, "queue_ids": queueIDs,
"conversation_tags": conversationTags, "conversation_tags": conversationTags,
+210 -55
View File
@@ -17,12 +17,14 @@ import (
"unicode/utf8" "unicode/utf8"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"cyberstrike-ai/internal/openai"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/robfig/cron/v3" "github.com/robfig/cron/v3"
@@ -130,6 +132,12 @@ type AgentHandler struct {
batchRunning map[string]struct{} batchRunning map[string]struct{}
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
hitlWhitelistSaver HitlToolWhitelistSaver hitlWhitelistSaver HitlToolWhitelistSaver
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *AgentHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 // HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
@@ -206,7 +214,7 @@ type ChatAttachment struct {
type ChatReasoningRequest struct { type ChatReasoningRequest struct {
// Mode: default(跟随系统)| off | on | auto // Mode: default(跟随系统)| off | on | auto
Mode string `json:"mode,omitempty"` Mode string `json:"mode,omitempty"`
// Effort: low | medium | high | max;空表示不指定(由系统默认与各 profile 决定) // Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定
Effort string `json:"effort,omitempty"` Effort string `json:"effort,omitempty"`
} }
@@ -214,6 +222,7 @@ type ChatReasoningRequest struct {
type ChatRequest struct { type ChatRequest struct {
Message string `json:"message" binding:"required"` Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"` ConversationID string `json:"conversationId,omitempty"`
ProjectID string `json:"projectId,omitempty"` // 新对话绑定的项目(可选;未指定时可用 config.project.default_project_id
Role string `json:"role,omitempty"` // 角色名称 Role string `json:"role,omitempty"` // 角色名称
Attachments []ChatAttachment `json:"attachments,omitempty"` Attachments []ChatAttachment `json:"attachments,omitempty"`
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具 WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
@@ -552,7 +561,9 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
conversationID := req.ConversationID conversationID := req.ConversationID
if conversationID == "" { if conversationID == "" {
title := safeTruncateString(req.Message, 50) title := safeTruncateString(req.Message, 50)
conv, err := h.db.CreateConversation(title) meta := audit.ConversationCreateMetaFromGin(c, "agent_loop")
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
conv, err := h.db.CreateConversation(title, meta)
if err != nil { if err != nil {
h.logger.Error("创建对话失败", zap.Error(err)) h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -627,6 +638,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
builtin.ToolWebshellFileRead, builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite, builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability, builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
builtin.ToolListKnowledgeRiskTypes, builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase, builtin.ToolSearchKnowledgeBase,
} }
@@ -674,7 +687,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, "", nil) taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, "", nil)
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表) // 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools) result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
if err != nil { if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err)) h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -716,11 +729,45 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
}) })
} }
func (h *AgentHandler) finalizeRobotAgentError(ctx context.Context, assistantMessageID, conversationID string, resultMA *multiagent.RunResult, errMA error) (string, string, error) {
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
return "", conversationID, errMA
}
func (h *AgentHandler) finalizeRobotAgentSuccess(assistantMessageID, conversationID string, resultMA *multiagent.RunResult) (string, string, error) {
if assistantMessageID != "" {
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
}
} else {
if _, err := h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
}
}
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
}
return resultMA.Response, conversationID, nil
}
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复 // ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) { func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) {
if conversationID == "" { if conversationID == "" {
title := safeTruncateString(message, 50) title := safeTruncateString(message, 50)
conv, createErr := h.db.CreateConversation(title) src := "robot"
if strings.TrimSpace(platform) != "" {
src = "robot:" + strings.TrimSpace(platform)
}
meta := audit.ConversationCreateMeta(src)
meta.ProjectID = effectiveProjectID(h.config, "")
conv, createErr := h.db.CreateConversation(title, meta)
if createErr != nil { if createErr != nil {
return "", "", fmt.Errorf("创建对话失败: %w", createErr) return "", "", fmt.Errorf("创建对话失败: %w", createErr)
} }
@@ -768,53 +815,92 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
if assistantMsg != nil { if assistantMsg != nil {
assistantMessageID = assistantMsg.ID assistantMessageID = assistantMsg.ID
} }
progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil)
useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent // 注册运行中任务并向 taskEventBus 镜像进度事件,供 Web 端 task-events 补流(与 agent-loop/stream 一致)。
if useRobotMulti { taskCtx, cancelWithCause := context.WithCancelCause(ctx)
resultMA, errMA := multiagent.RunDeepAgent( defer cancelWithCause(nil)
ctx, taskStatus := "completed"
h.config, defer func() {
&h.config.MultiAgent, h.tasks.FinishTask(conversationID, taskStatus)
h.agent, }()
h.logger, if _, err := h.tasks.StartTask(conversationID, message, cancelWithCause); err != nil {
conversationID, if errors.Is(err, ErrTaskAlreadyRunning) {
finalMessage, return "", conversationID, fmt.Errorf("当前会话已有任务正在执行中,请稍后再试")
agentHistoryMessages,
roleTools,
progressCallback,
h.agentsMarkdownDir,
"deep",
nil,
)
if errMA != nil {
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
return "", conversationID, errMA
} }
if assistantMessageID != "" { return "", conversationID, fmt.Errorf("无法启动任务: %w", err)
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil { }
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU)) progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
robotMode := "react"
if h.config != nil {
robotMode = config.NormalizeRobotAgentMode(h.config.MultiAgent)
}
switch robotMode {
case "eino_single":
curHist := agentHistoryMessages
curMsg := finalMessage
segmentUserMessage := finalMessage
var resultMA *multiagent.RunResult
var errMA error
var transientRunAttempts int
for {
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
)
if errMA == nil {
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
transientRunAttempts = 0
break
} }
} else { if handled, _ := h.handleEinoTransientRetryContinue(
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil { taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) &curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
} }
taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
} }
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" { return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput) case "deep", "plan_execute", "supervisor":
if h.config == nil || !h.config.MultiAgent.Enabled {
h.logger.Warn("机器人配置为多代理模式但未启用 multi_agent,回退原生 ReAct",
zap.String("robot_mode", robotMode))
break
} }
return resultMA.Response, conversationID, nil curHist := agentHistoryMessages
curMsg := finalMessage
segmentUserMessage := finalMessage
var resultMA *multiagent.RunResult
var errMA error
var transientRunAttempts int
for {
resultMA, errMA = multiagent.RunDeepAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback,
h.agentsMarkdownDir, robotMode, nil, h.projectBlackboardBlock(conversationID),
)
if errMA == nil {
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
transientRunAttempts = 0
break
}
if handled, _ := h.handleEinoTransientRetryContinue(
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
}
taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
}
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
} }
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools) result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
if err != nil { if err != nil {
taskStatus = "failed"
errMsg := "执行失败: " + err.Error() errMsg := "执行失败: " + err.Error()
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
@@ -846,6 +932,23 @@ type StreamEvent struct {
Data interface{} `json:"data,omitempty"` Data interface{} `json:"data,omitempty"`
} }
// publishProgressToTaskEventBus 将进度事件镜像到 taskEventBus(机器人/无 HTTP SSE 客户端时供 Web task-events 订阅)。
func (h *AgentHandler) publishProgressToTaskEventBus(conversationID, eventType, message string, data interface{}) {
if h == nil || h.taskEventBus == nil || strings.TrimSpace(conversationID) == "" {
return
}
event := StreamEvent{Type: eventType, Message: message, Data: data}
eventJSON, err := json.Marshal(event)
if err != nil {
return
}
sseLine := make([]byte, 0, len(eventJSON)+8)
sseLine = append(sseLine, []byte("data: ")...)
sseLine = append(sseLine, eventJSON...)
sseLine = append(sseLine, '\n', '\n')
h.taskEventBus.Publish(conversationID, sseLine)
}
// createProgressCallback 创建进度回调函数,用于保存processDetails // createProgressCallback 创建进度回调函数,用于保存processDetails
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件 // sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback { func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
@@ -955,9 +1058,11 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
return func(eventType, message string, data interface{}) { return func(eventType, message string, data interface{}) {
// 如果提供了sendEventFunc,发送流式事件 // 流式:写 HTTP SSE;非流式(机器人等):镜像到 taskEventBus 供 Web 订阅
if sendEventFunc != nil { if sendEventFunc != nil {
sendEventFunc(eventType, message, data) sendEventFunc(eventType, message, data)
} else {
h.publishProgressToTaskEventBus(conversationID, eventType, message, data)
} }
// 保存tool_call事件中的参数 // 保存tool_call事件中的参数
@@ -1158,7 +1263,16 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
return return
} }
if eventType == "response_delta" { if eventType == "response_delta" {
respPlan.b.WriteString(message) if dataMap, ok := data.(map[string]interface{}); ok {
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
respPlan.b.Reset()
respPlan.b.WriteString(acc)
} else {
respPlan.b.WriteString(message)
}
} else {
respPlan.b.WriteString(message)
}
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil { if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
respPlan.meta = make(map[string]interface{}, len(dataMap)) respPlan.meta = make(map[string]interface{}, len(dataMap))
for k, v := range dataMap { for k, v := range dataMap {
@@ -1213,8 +1327,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} else if tb.persistAs == "" { } else if tb.persistAs == "" {
tb.persistAs = persistAs tb.persistAs = persistAs
} }
// delta 片段直接拼接 if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
tb.b.WriteString(message) tb.b.Reset()
tb.b.WriteString(acc)
} else {
tb.b.WriteString(message)
}
// 有时 delta 先到 start 未到,补充元信息 // 有时 delta 先到 start 未到,补充元信息
for k, v := range dataMap { for k, v := range dataMap {
tb.meta[k] = v tb.meta[k] = v
@@ -1406,10 +1524,13 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
title := safeTruncateString(req.Message, 50) title := safeTruncateString(req.Message, 50)
var conv *database.Conversation var conv *database.Conversation
var err error var err error
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
if req.WebShellConnectionID != "" { if req.WebShellConnectionID != "" {
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title) meta.Source = "webshell_chat"
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
} else { } else {
conv, err = h.db.CreateConversation(title) conv, err = h.db.CreateConversation(title, meta)
} }
if err != nil { if err != nil {
h.logger.Error("创建对话失败", zap.Error(err)) h.logger.Error("创建对话失败", zap.Error(err))
@@ -1482,6 +1603,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
builtin.ToolWebshellFileRead, builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite, builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability, builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
builtin.ToolListKnowledgeRiskTypes, builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase, builtin.ToolSearchKnowledgeBase,
} }
@@ -1612,7 +1735,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
go sseKeepalive(c, stopKeepalive, &sseWriteMu) go sseKeepalive(c, stopKeepalive, &sseWriteMu)
defer close(stopKeepalive) defer close(stopKeepalive)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools) result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
if err != nil { if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err)) h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
@@ -1924,6 +2047,7 @@ type BatchTaskRequest struct {
ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
} }
func normalizeBatchQueueAgentMode(mode string) string { func normalizeBatchQueueAgentMode(mode string) string {
@@ -2004,7 +2128,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
nextRunAt = &next nextRunAt = &next
} }
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks) queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
if createErr != nil { if createErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
return return
@@ -2025,6 +2149,11 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
queue = refreshed queue = refreshed
} }
} }
if h.audit != nil {
h.audit.RecordOK(c, "task", "create_queue", "创建批量任务队列", "batch_queue", queue.ID, map[string]interface{}{
"task_count": len(validTasks), "started": started,
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"queueId": queue.ID, "queueId": queue.ID,
"queue": queue, "queue": queue,
@@ -2132,6 +2261,9 @@ func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "task", "start_queue", "启动批量任务队列", "batch_queue", queueID, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID}) c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
} }
@@ -2160,6 +2292,9 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "task", "rerun_queue", "重跑批量任务队列", "batch_queue", queueID, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID}) c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID})
} }
@@ -2171,6 +2306,9 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"}) c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "task", "pause_queue", "暂停批量任务队列", "batch_queue", queueID, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"}) c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
} }
@@ -2266,6 +2404,16 @@ func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return return
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "task",
Action: "delete_queue",
Result: "success",
ResourceType: "batch_queue",
ResourceID: queueID,
Message: "删除批量任务队列",
})
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"}) c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
} }
@@ -2351,6 +2499,11 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "task", "delete_batch_task", "删除批量子任务", "batch_task", taskID, map[string]interface{}{
"batch_queue_id": queueID,
})
}
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
} }
@@ -2509,7 +2662,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
// 创建新对话 // 创建新对话
title := safeTruncateString(task.Message, 50) title := safeTruncateString(task.Message, 50)
conv, err := h.db.CreateConversation(title) batchMeta := audit.ConversationCreateMeta("batch_task")
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
conv, err := h.db.CreateConversation(title, batchMeta)
var conversationID string var conversationID string
if err != nil { if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
@@ -2659,15 +2814,15 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
var runErr error var runErr error
switch { switch {
case useBatchMulti: case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil) resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
case useEinoSingle: case useEinoSingle:
if h.config == nil { if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载") runErr = fmt.Errorf("服务器配置未加载")
} else { } else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil) resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
} }
default: default:
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools) result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
} }
if runErr != nil { if runErr != nil {
+147
View File
@@ -0,0 +1,147 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AuditHandler serves platform audit log APIs.
type AuditHandler struct {
db *database.DB
audit *audit.Service
logger *zap.Logger
}
// NewAuditHandler creates an audit log handler.
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
}
// Meta GET /api/audit/meta
func (h *AuditHandler) Meta(c *gin.Context) {
enabled := false
retentionDays := 0
if h.audit != nil {
enabled = h.audit.Enabled()
retentionDays = h.audit.RetentionDays()
}
c.JSON(http.StatusOK, gin.H{
"enabled": enabled,
"retention_days": retentionDays,
"default_page_size": 20,
"max_page_size": 100,
"max_export": 5000,
})
}
// Summary GET /api/audit/summary
func (h *AuditHandler) Summary(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
base := auditFilterFromQuery(c)
total, err := h.db.CountAuditLogs(base)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
failFilter := base
failFilter.Result = "failure"
failures, err := h.db.CountAuditLogs(failFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
since := time.Now().AddDate(0, 0, -7)
recentFilter := base
recentFilter.Since = &since
recent7d, err := h.db.CountAuditLogs(recentFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"total": total,
"failures": failures,
"recent_7d": recent7d,
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
})
}
// ListLogs GET /api/audit/logs
func (h *AuditHandler) ListLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
page, pageSize := auditPaginationFromQuery(c)
filter.Limit = pageSize
filter.Offset = (page - 1) * pageSize
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
total, err := h.db.CountAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"logs": logs,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// GetLog GET /api/audit/logs/:id
func (h *AuditHandler) GetLog(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
row, err := h.db.GetAuditLogByID(c.Param("id"))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
return
}
audit.ApplyResourceAvailability(h.db, row)
c.JSON(http.StatusOK, gin.H{"log": row})
}
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
func (h *AuditHandler) ExportLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
filter.Limit = 5000
filter.Offset = 0
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if c.Query("format") == "csv" {
writeAuditLogsCSV(c, logs)
return
}
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
c.JSON(http.StatusOK, gin.H{
"exported_at": time.Now().UTC().Format(time.RFC3339),
"logs": logs,
})
}
+42
View File
@@ -0,0 +1,42 @@
package handler
import (
"encoding/csv"
"fmt"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
c.Header("Content-Type", "text/csv; charset=utf-8")
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
w := csv.NewWriter(c.Writer)
_ = w.Write([]string{
"id", "created_at", "level", "category", "action", "result", "actor",
"session_hint", "client_ip", "resource_type", "resource_id", "message",
})
for _, row := range logs {
if row == nil {
continue
}
_ = w.Write([]string{
row.ID,
row.CreatedAt.UTC().Format(time.RFC3339),
row.Level,
row.Category,
row.Action,
row.Result,
row.Actor,
row.SessionHint,
row.ClientIP,
row.ResourceType,
row.ResourceID,
row.Message,
})
}
w.Flush()
}
+48
View File
@@ -0,0 +1,48 @@
package handler
import (
"strconv"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
filter := database.ListAuditLogsFilter{
Level: c.Query("level"),
Category: c.Query("category"),
Action: c.Query("action"),
Result: c.Query("result"),
Query: c.Query("q"),
ResourceType: c.Query("resource_type"),
ResourceID: c.Query("resource_id"),
}
if since := c.Query("since"); since != "" {
if t, err := time.Parse(time.RFC3339, since); err == nil {
filter.Since = &t
}
}
if until := c.Query("until"); until != "" {
if t, err := time.Parse(time.RFC3339, until); err == nil {
filter.Until = &t
}
}
return filter
}
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
page = p
}
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
pageSize = ps
if pageSize > 100 {
pageSize = 100
}
}
return page, pageSize
}
+55
View File
@@ -5,6 +5,7 @@ import (
"strings" "strings"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
@@ -18,6 +19,12 @@ type AuthHandler struct {
config *config.Config config *config.Config
configPath string configPath string
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *AuthHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewAuthHandler creates a new AuthHandler. // NewAuthHandler creates a new AuthHandler.
@@ -49,10 +56,32 @@ func (h *AuthHandler) Login(c *gin.Context) {
token, expiresAt, err := h.manager.Authenticate(req.Password) token, expiresAt, err := h.manager.Authenticate(req.Password)
if err != nil { if err != nil {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "login",
Result: "failure",
Message: "登录失败:密码错误",
})
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
return return
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "login",
Result: "success",
SessionHint: audit.HintFromToken(token),
Message: "登录成功",
Detail: map[string]interface{}{
"expires_at": expiresAt.UTC().Format(time.RFC3339),
},
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"token": token, "token": token,
"expires_at": expiresAt.UTC().Format(time.RFC3339), "expires_at": expiresAt.UTC().Format(time.RFC3339),
@@ -73,6 +102,14 @@ func (h *AuthHandler) Logout(c *gin.Context) {
} }
h.manager.RevokeToken(token) h.manager.RevokeToken(token)
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "logout",
Result: "success",
Message: "退出登录",
})
}
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"}) c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
} }
@@ -103,6 +140,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
} }
if !h.manager.CheckPassword(oldPassword) { if !h.manager.CheckPassword(oldPassword) {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "change_password",
Result: "failure",
Message: "修改密码失败:当前密码不正确",
})
}
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"}) c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
return return
} }
@@ -132,6 +178,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
h.logger.Info("登录密码已更新,所有会话已失效") h.logger.Info("登录密码已更新,所有会话已失效")
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "change_password",
Result: "success",
Message: "登录密码已修改",
})
}
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"}) c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
} }
+10 -1
View File
@@ -65,6 +65,7 @@ type BatchTaskQueue struct {
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"` LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
LastScheduleError string `json:"lastScheduleError,omitempty"` LastScheduleError string `json:"lastScheduleError,omitempty"`
LastRunError string `json:"lastRunError,omitempty"` LastRunError string `json:"lastRunError,omitempty"`
ProjectID string `json:"projectId,omitempty"`
Tasks []*BatchTask `json:"tasks"` Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
@@ -103,7 +104,7 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
// CreateBatchQueue 创建批量任务队列 // CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue( func (m *BatchTaskManager) CreateBatchQueue(
title, role, agentMode, scheduleMode, cronExpr string, title, role, agentMode, scheduleMode, cronExpr, projectID string,
nextRunAt *time.Time, nextRunAt *time.Time,
tasks []string, tasks []string,
) (*BatchTaskQueue, error) { ) (*BatchTaskQueue, error) {
@@ -126,6 +127,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
ID: queueID, ID: queueID,
Title: title, Title: title,
Role: role, Role: role,
ProjectID: strings.TrimSpace(projectID),
AgentMode: normalizeBatchQueueAgentMode(agentMode), AgentMode: normalizeBatchQueueAgentMode(agentMode),
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode), ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
CronExpr: strings.TrimSpace(cronExpr), CronExpr: strings.TrimSpace(cronExpr),
@@ -171,6 +173,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
queue.ScheduleMode, queue.ScheduleMode,
queue.CronExpr, queue.CronExpr,
queue.NextRunAt, queue.NextRunAt,
queue.ProjectID,
dbTasks, dbTasks,
); err != nil { ); err != nil {
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
@@ -263,6 +266,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
if queueRow.LastRunError.Valid { if queueRow.LastRunError.Valid {
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
} }
if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
}
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
@@ -499,6 +505,9 @@ func (m *BatchTaskManager) LoadFromDB() error {
if queueRow.LastRunError.Valid { if queueRow.LastRunError.Valid {
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
} }
if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
}
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
+6 -1
View File
@@ -176,6 +176,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"type": "boolean", "type": "boolean",
"description": "创建后是否立即开始执行队列,默认 falsepending,需 batch_task_start", "description": "创建后是否立即开始执行队列,默认 falsepending,需 batch_task_start",
}, },
"project_id": map[string]interface{}{
"type": "string",
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id",
},
}, },
}, },
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
@@ -204,7 +208,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
if !ok { if !ok {
executeNow = false executeNow = false
} }
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks) projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
if createErr != nil { if createErr != nil {
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
} }
+37
View File
@@ -13,6 +13,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/c2" "cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
@@ -25,6 +26,12 @@ import (
type C2Handler struct { type C2Handler struct {
mgrPtr atomic.Pointer[c2.Manager] mgrPtr atomic.Pointer[c2.Manager]
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *C2Handler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时) // NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
@@ -104,6 +111,11 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
implantToken := listener.ImplantToken implantToken := listener.ImplantToken
listener.EncryptionKey = "" listener.EncryptionKey = ""
listener.ImplantToken = "" listener.ImplantToken = ""
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_create", "创建 C2 监听器", "c2_listener", listener.ID, map[string]interface{}{
"name": listener.Name, "bind": listener.BindHost, "port": listener.BindPort,
})
}
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken}) c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
} }
@@ -205,6 +217,9 @@ func (h *C2Handler) DeleteListener(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()}) c.JSON(code, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_delete", "删除 C2 监听器", "c2_listener", id, nil)
}
c.JSON(http.StatusOK, gin.H{"deleted": true}) c.JSON(http.StatusOK, gin.H{"deleted": true})
} }
@@ -222,6 +237,9 @@ func (h *C2Handler) StartListener(c *gin.Context) {
} }
listener.EncryptionKey = "" listener.EncryptionKey = ""
listener.ImplantToken = "" listener.ImplantToken = ""
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_start", "启动 C2 监听器", "c2_listener", id, nil)
}
c.JSON(http.StatusOK, gin.H{"listener": listener}) c.JSON(http.StatusOK, gin.H{"listener": listener})
} }
@@ -236,6 +254,9 @@ func (h *C2Handler) StopListener(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()}) c.JSON(code, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_stop", "停止 C2 监听器", "c2_listener", id, nil)
}
c.JSON(http.StatusOK, gin.H{"stopped": true}) c.JSON(http.StatusOK, gin.H{"stopped": true})
} }
@@ -297,6 +318,9 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "c2", "session_delete", "删除 C2 会话", "c2_session", id, nil)
}
c.JSON(http.StatusOK, gin.H{"deleted": true}) c.JSON(http.StatusOK, gin.H{"deleted": true})
} }
@@ -407,6 +431,11 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "c2", "task_delete", "批量删除 C2 任务", "c2_task", "", map[string]interface{}{
"count": n, "ids": req.IDs,
})
}
c.JSON(http.StatusOK, gin.H{"deleted": n}) c.JSON(http.StatusOK, gin.H{"deleted": n})
} }
@@ -457,6 +486,11 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()}) c.JSON(code, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "c2", "task_create", "创建 C2 任务", "c2_task", task.ID, map[string]interface{}{
"session_id": req.SessionID, "task_type": req.TaskType,
})
}
c.JSON(http.StatusOK, gin.H{"task": task}) c.JSON(http.StatusOK, gin.H{"task": task})
} }
@@ -471,6 +505,9 @@ func (h *C2Handler) CancelTask(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()}) c.JSON(code, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "c2", "task_cancel", "取消 C2 任务", "c2_task", id, nil)
}
c.JSON(http.StatusOK, gin.H{"cancelled": true}) c.JSON(http.StatusOK, gin.H{"cancelled": true})
} }
+16
View File
@@ -12,6 +12,8 @@ import (
"time" "time"
"unicode/utf8" "unicode/utf8"
"cyberstrike-ai/internal/audit"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -24,6 +26,12 @@ const (
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API // ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
type ChatUploadsHandler struct { type ChatUploadsHandler struct {
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewChatUploadsHandler 创建处理器 // NewChatUploadsHandler 创建处理器
@@ -230,6 +238,9 @@ func (h *ChatUploadsHandler) Delete(c *gin.Context) {
return return
} }
} }
if h.audit != nil {
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
}
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
} }
@@ -503,6 +514,11 @@ func (h *ChatUploadsHandler) Upload(c *gin.Context) {
} }
rel, _ := filepath.Rel(root, fullPath) rel, _ := filepath.Rel(root, fullPath)
absSaved, _ := filepath.Abs(fullPath) absSaved, _ := filepath.Abs(fullPath)
if h.audit != nil {
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
"name": unique,
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"ok": true, "ok": true,
"relativePath": filepath.ToSlash(rel), "relativePath": filepath.ToSlash(rel),
+72 -4
View File
@@ -14,6 +14,7 @@ import (
"time" "time"
"cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
@@ -87,6 +88,7 @@ type ConfigHandler struct {
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选) appUpdater AppUpdater // App更新器(可选)
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书 robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
audit *audit.Service
logger *zap.Logger logger *zap.Logger
mu sync.RWMutex mu sync.RWMutex
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更) lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
@@ -206,6 +208,32 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
h.robotRestarter = restarter h.robotRestarter = restarter
} }
// SetAudit wires platform audit logging.
func (h *ConfigHandler) SetAudit(s *audit.Service) {
h.mu.Lock()
defer h.mu.Unlock()
h.audit = s
}
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
h.mu.Lock()
wc.Enabled = true
h.config.Robots.Wechat = wc
h.mu.Unlock()
if err := h.saveConfig(); err != nil {
return err
}
if h.robotRestarter != nil {
h.robotRestarter.RestartRobotConnections()
}
h.logger.Info("微信机器人绑定已保存",
zap.String("ilink_bot_id", wc.ILinkBotID),
zap.Bool("enabled", wc.Enabled),
)
return nil
}
// GetConfigResponse 获取配置响应 // GetConfigResponse 获取配置响应
type GetConfigResponse struct { type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"` OpenAI config.OpenAIConfig `json:"openai"`
@@ -291,7 +319,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
} }
multiPub := config.MultiAgentPublic{ multiPub := config.MultiAgentPublic{
Enabled: h.config.MultiAgent.Enabled, Enabled: h.config.MultiAgent.Enabled,
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent, RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent, BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
SubAgentCount: subAgentCount, SubAgentCount: subAgentCount,
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration), Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
@@ -735,6 +763,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
if req.Robots != nil { if req.Robots != nil {
h.config.Robots = *req.Robots h.config.Robots = *req.Robots
h.logger.Info("更新机器人配置", h.logger.Info("更新机器人配置",
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled), zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled), zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled), zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
@@ -750,8 +779,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// 多代理标量(sub_agents 等仍由 config.yaml 维护) // 多代理标量(sub_agents 等仍由 config.yaml 维护)
if req.MultiAgent != nil { if req.MultiAgent != nil {
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
h.config.MultiAgent.RobotDefaultAgentMode = mode
} else {
h.config.MultiAgent.RobotDefaultAgentMode = "react"
}
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil { if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
} }
@@ -760,7 +793,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
} }
h.logger.Info("更新多代理配置", h.logger.Info("更新多代理配置",
zap.Bool("enabled", h.config.MultiAgent.Enabled), zap.Bool("enabled", h.config.MultiAgent.Enabled),
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent), zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent), zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations), zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)), zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
@@ -883,6 +916,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil)
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
} }
@@ -1013,6 +1049,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件") h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
if _, err := knowledgeInitializer(); err != nil { if _, err := knowledgeInitializer(); err != nil {
h.logger.Error("动态初始化知识库失败", zap.Error(err)) h.logger.Error("动态初始化知识库失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
return return
} }
@@ -1047,6 +1086,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)") h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
if _, err := reinitKnowledgeInitializer(); err != nil { if _, err := reinitKnowledgeInitializer(); err != nil {
h.logger.Error("重新初始化知识库失败", zap.Error(err)) h.logger.Error("重新初始化知识库失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
return return
} }
@@ -1060,6 +1102,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
if c2Rt != nil { if c2Rt != nil {
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil { if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
h.logger.Error("C2 配置应用失败", zap.Error(err)) h.logger.Error("C2 配置应用失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
return return
} }
@@ -1201,6 +1246,20 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
zap.Int("tools_count", len(h.config.Security.Tools)), zap.Int("tools_count", len(h.config.Security.Tools)),
) )
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "config",
Action: "apply",
Result: "success",
Message: "配置已应用",
Detail: map[string]interface{}{
"tools_count": len(h.config.Security.Tools),
"knowledge_enabled": h.config.Knowledge.Enabled,
"c2_enabled": h.config.C2.EnabledEffective(),
},
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "配置已应用", "message": "配置已应用",
"tools_count": len(h.config.Security.Tools), "tools_count": len(h.config.Security.Tools),
@@ -1481,6 +1540,15 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity) setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
} }
wechatNode := ensureMap(robotsNode, "wechat")
setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled)
setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken)
setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID)
setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID)
setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL)
setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType)
setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent)
wecomNode := ensureMap(robotsNode, "wecom") wecomNode := ensureMap(robotsNode, "wecom")
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled) setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
setStringInMap(wecomNode, "token", cfg.Wecom.Token) setStringInMap(wecomNode, "token", cfg.Wecom.Token)
@@ -1507,7 +1575,7 @@ func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
root := doc.Content[0] root := doc.Content[0]
maNode := ensureMap(root, "multi_agent") maNode := ensureMap(root, "multi_agent")
setBoolInMap(maNode, "enabled", cfg.Enabled) setBoolInMap(maNode, "enabled", cfg.Enabled)
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent) setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent) setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations) setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
mwNode := ensureMap(maNode, "eino_middleware") mwNode := ensureMap(maNode, "eino_middleware")
+54 -2
View File
@@ -4,7 +4,9 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -14,6 +16,12 @@ import (
type ConversationHandler struct { type ConversationHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ConversationHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewConversationHandler 创建新的对话处理器 // NewConversationHandler 创建新的对话处理器
@@ -26,7 +34,13 @@ func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHa
// CreateConversationRequest 创建对话请求 // CreateConversationRequest 创建对话请求
type CreateConversationRequest struct { type CreateConversationRequest struct {
Title string `json:"title"` Title string `json:"title"`
ProjectID string `json:"projectId,omitempty"`
}
// SetConversationProjectRequest 设置对话所属项目
type SetConversationProjectRequest struct {
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
} }
// CreateConversation 创建新对话 // CreateConversation 创建新对话
@@ -42,7 +56,9 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
title = "新对话" title = "新对话"
} }
conv, err := h.db.CreateConversation(title) meta := audit.ConversationCreateMetaFromGin(c, "api")
meta.ProjectID = strings.TrimSpace(req.ProjectID)
conv, err := h.db.CreateConversation(title, meta)
if err != nil { if err != nil {
h.logger.Error("创建对话失败", zap.Error(err)) h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -52,6 +68,25 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
c.JSON(http.StatusOK, conv) c.JSON(http.StatusOK, conv)
} }
// SetConversationProject 设置或清除对话绑定的项目
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
id := c.Param("id")
var req SetConversationProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := h.db.GetConversation(id); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
}
// ListConversations 列出对话 // ListConversations 列出对话
func (h *ConversationHandler) ListConversations(c *gin.Context) { func (h *ConversationHandler) ListConversations(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50") limitStr := c.DefaultQuery("limit", "50")
@@ -189,6 +224,17 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "conversation",
Action: "delete",
Result: "success",
ResourceType: "conversation",
ResourceID: id,
Message: "删除对话",
})
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
} }
@@ -227,6 +273,12 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
"message_id": req.MessageID,
"deleted": len(deletedIDs),
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"deletedMessageIds": deletedIDs, "deletedMessageIds": deletedIDs,
"message": "ok", "message": "ok",
+122
View File
@@ -0,0 +1,122 @@
package handler
import (
"context"
"errors"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/multiagent"
)
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
if h.config != nil {
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
}
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
}
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
}
return 0
}
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
func (h *AgentHandler) applyEinoTraceResumeSegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if segmentUserMessage != "" {
*curFinalMessage = segmentUserMessage
}
}
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
func (h *AgentHandler) applyEinoTransientRetrySegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if s := strings.TrimSpace(segmentUserMessage); s != "" {
*curFinalMessage = segmentUserMessage
}
}
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
func (h *AgentHandler) handleEinoTransientRetryContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
transientAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, fatal error) {
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
return false, nil
}
maxAttempts := h.einoRunRetryMaxAttempts()
*transientAttempts++
if *transientAttempts > maxAttempts {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, errors.New("transient retry exhausted: " + runErr.Error())
}
attemptNo := *transientAttempts
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
if progressCallback != nil {
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
}
select {
case <-baseCtx.Done():
return false, context.Cause(baseCtx)
case <-time.After(backoff):
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if progressCallback != nil {
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
})
}
if sendProgress != nil {
sendProgress("正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "transient_retry",
})
}
return true, nil
}
+69 -5
View File
@@ -90,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
zap.String("conversationId", req.ConversationID), zap.String("conversationId", req.ConversationID),
) )
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
if err != nil { if err != nil {
sendEvent("error", err.Error(), nil) sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil) sendEvent("done", "", nil)
@@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History curHistory := prep.History
roleTools := prep.RoleTools roleTools := prep.RoleTools
@@ -176,9 +177,41 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
taskOwned = true taskOwned = true
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int
for { for {
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) segmentMainIterationMax := 0
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
progressCallback := func(eventType, message string, data interface{}) {
if eventType == "iteration" {
if m, ok := data.(map[string]interface{}); ok {
if scope, _ := m["einoScope"].(string); scope == "main" {
raw := 0
switch v := m["iteration"].(type) {
case int:
raw = v
case int32:
raw = int(v)
case int64:
raw = int(v)
case float64:
raw = int(v)
case float32:
raw = int(v)
}
if raw > 0 {
if raw > segmentMainIterationMax {
segmentMainIterationMax = raw
}
m["iteration"] = raw + mainIterationOffset
}
}
}
}
rawProgressCallback(eventType, message, data)
}
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
@@ -197,17 +230,38 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
roleTools, roleTools,
progressCallback, progressCallback,
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID),
) )
timeoutCancel()
if result != nil && len(result.MCPExecutionIDs) > 0 { if result != nil && len(result.MCPExecutionIDs) > 0 {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
} }
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
timeoutCancel()
break break
} }
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) { if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -231,10 +285,14 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"conversationId": conversationID, "conversationId": conversationID,
"source": "interrupt_continue", "source": "interrupt_continue",
}) })
h.tasks.UpdateTaskStatus(conversationID, "running") mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause) h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue continue
} }
@@ -261,6 +319,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID, "messageId": assistantMessageID,
}) })
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return return
} }
@@ -278,6 +337,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"errorType": "timeout", "errorType": "timeout",
}) })
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return return
} }
@@ -294,9 +354,12 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID, "messageId": assistantMessageID,
}) })
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return return
} }
timeoutCancel()
if assistantMessageID != "" { if assistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
} }
@@ -326,7 +389,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID)) h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@@ -367,6 +430,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.RoleTools, prep.RoleTools,
progressCallback, progressCallback,
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
) )
if runErr != nil { if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
+27
View File
@@ -6,6 +6,7 @@ import (
"os" "os"
"sync" "sync"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
@@ -20,9 +21,15 @@ type ExternalMCPHandler struct {
config *config.Config config *config.Config
configPath string configPath string
logger *zap.Logger logger *zap.Logger
audit *audit.Service
mu sync.RWMutex mu sync.RWMutex
} }
// SetAudit wires platform audit logging.
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewExternalMCPHandler 创建外部MCP处理器 // NewExternalMCPHandler 创建外部MCP处理器
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler { func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
return &ExternalMCPHandler{ return &ExternalMCPHandler{
@@ -180,6 +187,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
} }
h.logger.Info("外部MCP配置已更新", zap.String("name", name)) h.logger.Info("外部MCP配置已更新", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "upsert",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "更新外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
} }
@@ -209,6 +226,16 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
} }
h.logger.Info("外部MCP配置已删除", zap.String("name", name)) h.logger.Info("外部MCP配置已删除", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "delete",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "删除外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"}) c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
} }
+5
View File
@@ -616,6 +616,11 @@ func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
"decision": req.Decision,
})
}
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
} }
+13
View File
@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
@@ -20,6 +21,12 @@ type KnowledgeHandler struct {
indexer *knowledge.Indexer indexer *knowledge.Indexer
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewKnowledgeHandler 创建新的知识库处理器 // NewKnowledgeHandler 创建新的知识库处理器
@@ -303,6 +310,9 @@ func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
} }
@@ -316,6 +326,9 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
} }
}() }()
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
}
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"}) c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
} }
+17 -1
View File
@@ -9,6 +9,7 @@ import (
"strings" "strings"
"cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -18,7 +19,8 @@ var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.m
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。 // MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
type MarkdownAgentsHandler struct { type MarkdownAgentsHandler struct {
dir string dir string
audit *audit.Service
} }
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。 // NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
@@ -26,6 +28,11 @@ func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)} return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
} }
// SetAudit wires platform audit logging.
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) { func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
filename = strings.TrimSpace(filename) filename = strings.TrimSpace(filename)
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) { if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
@@ -227,6 +234,9 @@ func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
}
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"}) c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
} }
@@ -294,6 +304,9 @@ func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已保存"}) c.JSON(http.StatusOK, gin.H{"message": "已保存"})
} }
@@ -313,5 +326,8 @@ func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已删除"}) c.JSON(http.StatusOK, gin.H{"message": "已删除"})
} }
+17
View File
@@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
@@ -23,6 +24,12 @@ type MonitorHandler struct {
executor *security.Executor executor *security.Executor
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *MonitorHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewMonitorHandler 创建新的监控处理器 // NewMonitorHandler 创建新的监控处理器
@@ -365,6 +372,11 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
} }
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName)) h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
"tool_name": exec.ToolName,
})
}
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
return return
} }
@@ -440,6 +452,11 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
} }
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs))) h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
"count": len(request.IDs),
})
}
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)}) c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
return return
} }
+69 -5
View File
@@ -107,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
zap.String("conversationId", req.ConversationID), zap.String("conversationId", req.ConversationID),
) )
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
if err != nil { if err != nil {
sendEvent("error", err.Error(), nil) sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil) sendEvent("done", "", nil)
@@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History curHistory := prep.History
roleTools := prep.RoleTools roleTools := prep.RoleTools
orch := strings.TrimSpace(req.Orchestration) orch := strings.TrimSpace(req.Orchestration)
@@ -186,9 +187,41 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表 // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int
for { for {
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) segmentMainIterationMax := 0
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
progressCallback := func(eventType, message string, data interface{}) {
if eventType == "iteration" {
if m, ok := data.(map[string]interface{}); ok {
if scope, _ := m["einoScope"].(string); scope == "main" {
raw := 0
switch v := m["iteration"].(type) {
case int:
raw = v
case int32:
raw = int(v)
case int64:
raw = int(v)
case float64:
raw = int(v)
case float32:
raw = int(v)
}
if raw > 0 {
if raw > segmentMainIterationMax {
segmentMainIterationMax = raw
}
m["iteration"] = raw + mainIterationOffset
}
}
}
}
rawProgressCallback(eventType, message, data)
}
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
@@ -209,17 +242,38 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
h.agentsMarkdownDir, h.agentsMarkdownDir,
orch, orch,
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID),
) )
timeoutCancel()
if result != nil && len(result.MCPExecutionIDs) > 0 { if result != nil && len(result.MCPExecutionIDs) > 0 {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
} }
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
timeoutCancel()
break break
} }
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) { if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -243,10 +297,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"conversationId": conversationID, "conversationId": conversationID,
"source": "interrupt_continue", "source": "interrupt_continue",
}) })
h.tasks.UpdateTaskStatus(conversationID, "running") mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause) h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue continue
} }
@@ -273,6 +331,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID, "messageId": assistantMessageID,
}) })
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return return
} }
@@ -290,6 +349,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"errorType": "timeout", "errorType": "timeout",
}) })
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return return
} }
@@ -306,9 +366,12 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID, "messageId": assistantMessageID,
}) })
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return return
} }
timeoutCancel()
if assistantMessageID != "" { if assistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
} }
@@ -347,7 +410,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID)) h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
if err != nil { if err != nil {
status, msg := multiAgentHTTPErrorStatus(err) status, msg := multiAgentHTTPErrorStatus(err)
c.JSON(status, gin.H{"error": msg}) c.JSON(status, gin.H{"error": msg})
@@ -381,6 +444,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
h.agentsMarkdownDir, h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration), strings.TrimSpace(req.Orchestration),
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
) )
if runErr != nil { if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
+16 -3
View File
@@ -5,9 +5,11 @@ import (
"strings" "strings"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -22,7 +24,7 @@ type multiAgentPrepared struct {
UserMessageID string UserMessageID string
} }
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) { func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
if len(req.Attachments) > maxAttachments { if len(req.Attachments) > maxAttachments {
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments) return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
} }
@@ -33,10 +35,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
title := safeTruncateString(req.Message, 50) title := safeTruncateString(req.Message, 50)
var conv *database.Conversation var conv *database.Conversation
var err error var err error
meta := audit.ConversationCreateMetaFromGin(c, source)
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
if strings.TrimSpace(req.WebShellConnectionID) != "" { if strings.TrimSpace(req.WebShellConnectionID) != "" {
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title) meta.Source = source + "_webshell"
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
} else { } else {
conv, err = h.db.CreateConversation(title) conv, err = h.db.CreateConversation(title, meta)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err) return nil, fmt.Errorf("创建对话失败: %w", err)
@@ -85,6 +91,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
builtin.ToolWebshellFileRead, builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite, builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability, builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
builtin.ToolUpsertProjectFact,
builtin.ToolGetProjectFact,
builtin.ToolListProjectFacts,
builtin.ToolSearchProjectFacts,
builtin.ToolDeprecateProjectFact,
builtin.ToolListKnowledgeRiskTypes, builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase, builtin.ToolSearchKnowledgeBase,
} }
+139 -1
View File
@@ -73,8 +73,22 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"description": "对话标题", "description": "对话标题",
"example": "Web应用安全测试", "example": "Web应用安全测试",
}, },
"projectId": map[string]interface{}{
"type": "string",
"description": "绑定的项目 ID(可选,共享事实黑板)",
},
}, },
}, },
"SetConversationProjectRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"projectId": map[string]interface{}{
"type": "string",
"description": "项目 ID;空字符串表示解除绑定",
},
},
"required": []string{"projectId"},
},
"Conversation": map[string]interface{}{ "Conversation": map[string]interface{}{
"type": "object", "type": "object",
"properties": map[string]interface{}{ "properties": map[string]interface{}{
@@ -98,6 +112,10 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"format": "date-time", "format": "date-time",
"description": "更新时间", "description": "更新时间",
}, },
"projectId": map[string]interface{}{
"type": "string",
"description": "绑定的项目 ID(可选)",
},
}, },
}, },
"ConversationDetail": map[string]interface{}{ "ConversationDetail": map[string]interface{}{
@@ -1326,6 +1344,37 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
}, },
}, },
}, },
"/api/conversations/{id}/project": map[string]interface{}{
"put": map[string]interface{}{
"tags": []string{"对话管理"},
"summary": "设置对话所属项目",
"description": "绑定或解除对话与项目的关联,用于共享事实黑板",
"operationId": "setConversationProject",
"parameters": []map[string]interface{}{
{
"name": "id", "in": "path", "required": true,
"description": "对话ID",
"schema": map[string]interface{}{"type": "string"},
},
},
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"$ref": "#/components/schemas/SetConversationProjectRequest",
},
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{"description": "设置成功"},
"400": map[string]interface{}{"description": "项目不存在或参数错误"},
"404": map[string]interface{}{"description": "对话不存在"},
"401": map[string]interface{}{"description": "未授权"},
},
},
},
"/api/conversations/{id}/results": map[string]interface{}{ "/api/conversations/{id}/results": map[string]interface{}{
"get": map[string]interface{}{ "get": map[string]interface{}{
"tags": []string{"对话管理"}, "tags": []string{"对话管理"},
@@ -2444,6 +2493,86 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
}, },
}, },
}, },
"/api/projects": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"},
"summary": "列出项目",
"operationId": "listProjects",
"parameters": []map[string]interface{}{
{"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}},
{"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{"description": "项目列表"},
"401": map[string]interface{}{"description": "未授权"},
},
},
"post": map[string]interface{}{
"tags": []string{"项目管理"},
"summary": "创建项目",
"operationId": "createProject",
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{"type": "string"},
"description": map[string]interface{}{"type": "string"},
"scope_json": map[string]interface{}{"type": "string"},
},
"required": []string{"name"},
},
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{"description": "创建成功"},
"401": map[string]interface{}{"description": "未授权"},
},
},
},
"/api/projects/{id}": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}},
},
"put": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}},
},
"delete": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
},
},
"/api/projects/{id}/facts": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
},
"post": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
},
},
"/api/vulnerabilities": map[string]interface{}{ "/api/vulnerabilities": map[string]interface{}{
"get": map[string]interface{}{ "get": map[string]interface{}{
"tags": []string{"漏洞管理"}, "tags": []string{"漏洞管理"},
@@ -2502,6 +2631,15 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"type": "string", "type": "string",
}, },
}, },
{
"name": "project_id",
"in": "query",
"required": false,
"description": "项目ID",
"schema": map[string]interface{}{
"type": "string",
},
},
{ {
"name": "severity", "name": "severity",
"in": "query", "in": "query",
@@ -6254,7 +6392,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
} }
// 获取漏洞列表 // 获取漏洞列表
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "") vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
if err != nil { if err != nil {
h.logger.Warn("获取漏洞列表失败", zap.Error(err)) h.logger.Warn("获取漏洞列表失败", zap.Error(err))
vulnList = []*database.Vulnerability{} vulnList = []*database.Vulnerability{}
+262
View File
@@ -0,0 +1,262 @@
package handler
import (
"net/http"
"strconv"
"strings"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProjectHandler 项目管理处理器。
type ProjectHandler struct {
db *database.DB
logger *zap.Logger
}
// NewProjectHandler 创建项目管理处理器。
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
return &ProjectHandler{db: db, logger: logger}
}
type createProjectRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
ScopeJSON string `json:"scope_json"`
Status string `json:"status"`
}
type updateProjectRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ScopeJSON string `json:"scope_json"`
Status string `json:"status"`
Pinned *bool `json:"pinned"`
}
// CreateProject POST /api/projects
func (h *ProjectHandler) CreateProject(c *gin.Context) {
var req createProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
p := &database.Project{
Name: strings.TrimSpace(req.Name),
Description: req.Description,
ScopeJSON: req.ScopeJSON,
Status: strings.TrimSpace(req.Status),
}
created, err := h.db.CreateProject(p)
if err != nil {
h.logger.Error("创建项目失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// ListProjects GET /api/projects
func (h *ProjectHandler) ListProjects(c *gin.Context) {
status := c.Query("status")
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200"))
offset, _ := strconv.Atoi(c.Query("offset"))
list, err := h.db.ListProjects(status, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.Project{}
}
c.JSON(http.StatusOK, list)
}
// GetProject GET /api/projects/:id
func (h *ProjectHandler) GetProject(c *gin.Context) {
p, err := h.db.GetProject(c.Param("id"))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
c.JSON(http.StatusOK, p)
}
// UpdateProject PUT /api/projects/:id
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
id := c.Param("id")
p, err := h.db.GetProject(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
var req updateProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if s := strings.TrimSpace(req.Name); s != "" {
p.Name = s
}
if req.Description != "" || c.Request.ContentLength > 0 {
p.Description = req.Description
}
if req.ScopeJSON != "" || c.GetHeader("Content-Type") != "" {
p.ScopeJSON = req.ScopeJSON
}
if s := strings.TrimSpace(req.Status); s != "" {
p.Status = s
}
if req.Pinned != nil {
p.Pinned = *req.Pinned
}
if err := h.db.UpdateProject(p); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, p)
}
// DeleteProject DELETE /api/projects/:id
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
if err := h.db.DeleteProject(c.Param("id")); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
type upsertFactRequest struct {
FactKey string `json:"fact_key" binding:"required"`
Category string `json:"category"`
Summary string `json:"summary" binding:"required"`
Body string `json:"body"`
Confidence string `json:"confidence"`
Pinned bool `json:"pinned"`
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
}
// ListFacts GET /api/projects/:id/facts fact_key 查询参数可获取单条详情)
func (h *ProjectHandler) ListFacts(c *gin.Context) {
projectID := c.Param("id")
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
f, err := h.db.GetProjectFactByKey(projectID, key)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, f)
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
offset, _ := strconv.Atoi(c.Query("offset"))
filter := database.ProjectFactListFilter{
Category: c.Query("category"),
Confidence: c.Query("confidence"),
Search: c.Query("search"),
}
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.ProjectFact{}
}
c.JSON(http.StatusOK, list)
}
// CreateFact POST /api/projects/:id/facts
func (h *ProjectHandler) CreateFact(c *gin.Context) {
var req upsertFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
f := &database.ProjectFact{
ProjectID: c.Param("id"),
FactKey: req.FactKey,
Category: req.Category,
Summary: req.Summary,
Body: req.Body,
Confidence: req.Confidence,
Pinned: req.Pinned,
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
}
created, err := h.db.UpsertProjectFact(f)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// UpdateFact PUT /api/projects/:id/facts/:factId
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
var req upsertFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if k := strings.TrimSpace(req.FactKey); k != "" {
existing.FactKey = k
}
if req.Category != "" {
existing.Category = req.Category
}
if req.Summary != "" {
existing.Summary = req.Summary
}
existing.Body = req.Body
if req.Confidence != "" {
existing.Confidence = req.Confidence
}
existing.Pinned = req.Pinned
existing.RelatedVulnerabilityID = req.RelatedVulnerabilityID
updated, err := h.db.UpsertProjectFact(existing)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, updated)
}
// DeleteFact DELETE /api/projects/:id/facts/:factId
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
type deprecateFactRequest struct {
FactKey string `json:"fact_key" binding:"required"`
}
// DeprecateFact POST /api/projects/:id/facts/deprecate
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
var req deprecateFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
+32
View File
@@ -0,0 +1,32 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/project"
"go.uber.org/zap"
)
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
if h == nil || h.db == nil || h.config == nil {
return ""
}
if !h.config.Project.Enabled {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
projectID, err := h.db.GetConversationProjectID(conversationID)
if err != nil || projectID == "" {
return ""
}
block, err := project.BuildFactIndexBlock(h.db, projectID, h.config.Project)
if err != nil {
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
return ""
}
return strings.TrimSpace(block)
}
+18
View File
@@ -0,0 +1,18 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/config"
)
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
func effectiveProjectID(cfg *config.Config, explicit string) string {
if pid := strings.TrimSpace(explicit); pid != "" {
return pid
}
if cfg != nil {
return strings.TrimSpace(cfg.Project.DefaultProjectID)
}
return ""
}
+7 -3
View File
@@ -133,7 +133,9 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
} else { } else {
t = safeTruncateString(t, 50) t = safeTruncateString(t, 50)
} }
conv, err := h.db.CreateConversation(t) meta := database.ConversationCreateMeta{Source: "robot:" + platform}
meta.ProjectID = effectiveProjectID(h.config, "")
conv, err := h.db.CreateConversation(t, meta)
if err != nil { if err != nil {
h.logger.Warn("创建机器人会话失败", zap.Error(err)) h.logger.Warn("创建机器人会话失败", zap.Error(err))
return "", false return "", false
@@ -188,7 +190,9 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
// clearConversation 清空当前会话(切换到新对话) // clearConversation 清空当前会话(切换到新对话)
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) { func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
title := "新对话 " + time.Now().Format("01-02 15:04") title := "新对话 " + time.Now().Format("01-02 15:04")
conv, err := h.db.CreateConversation(title) meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}
meta.ProjectID = effectiveProjectID(h.config, "")
conv, err := h.db.CreateConversation(title, meta)
if err != nil { if err != nil {
h.logger.Warn("创建新对话失败", zap.Error(err)) h.logger.Warn("创建新对话失败", zap.Error(err))
return "" return ""
@@ -242,7 +246,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
h.cancelMu.Unlock() h.cancelMu.Unlock()
}() }()
role := h.getRole(platform, userID) role := h.getRole(platform, userID)
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role) resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role)
if err != nil { if err != nil {
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err)) h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
+16
View File
@@ -8,6 +8,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -21,6 +22,12 @@ type RoleHandler struct {
config *config.Config config *config.Config
configPath string configPath string
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *RoleHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewRoleHandler 创建新的角色处理器 // NewRoleHandler 创建新的角色处理器
@@ -174,6 +181,9 @@ func (h *RoleHandler) UpdateRole(c *gin.Context) {
} }
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name)) h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "角色已更新", "message": "角色已更新",
"role": req, "role": req,
@@ -219,6 +229,9 @@ func (h *RoleHandler) CreateRole(c *gin.Context) {
} }
h.logger.Info("创建角色", zap.String("roleName", req.Name)) h.logger.Info("创建角色", zap.String("roleName", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "角色已创建", "message": "角色已创建",
"role": req, "role": req,
@@ -287,6 +300,9 @@ func (h *RoleHandler) DeleteRole(c *gin.Context) {
} }
h.logger.Info("删除角色", zap.String("roleName", roleName)) h.logger.Info("删除角色", zap.String("roleName", roleName))
if h.audit != nil {
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "角色已删除", "message": "角色已删除",
}) })
+18
View File
@@ -8,6 +8,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skillpackage" "cyberstrike-ai/internal/skillpackage"
@@ -23,6 +24,12 @@ type SkillsHandler struct {
configPath string configPath string
logger *zap.Logger logger *zap.Logger
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除) db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *SkillsHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewSkillsHandler 创建新的Skills处理器 // NewSkillsHandler 创建新的Skills处理器
@@ -365,6 +372,9 @@ func (h *SkillsHandler) CreateSkill(c *gin.Context) {
} }
h.logger.Info("创建skill成功", zap.String("skill", req.Name)) h.logger.Info("创建skill成功", zap.String("skill", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "skill已创建", "message": "skill已创建",
"skill": map[string]interface{}{ "skill": map[string]interface{}{
@@ -425,6 +435,9 @@ func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
} }
h.logger.Info("更新skill成功", zap.String("skill", skillName)) h.logger.Info("更新skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "skill已更新", "message": "skill已更新",
}) })
@@ -459,6 +472,11 @@ func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
} }
h.logger.Info("删除skill成功", zap.String("skill", skillName)) h.logger.Info("删除skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
"affected_roles": affectedRoles,
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": responseMsg, "message": responseMsg,
"affected_roles": affectedRoles, "affected_roles": affectedRoles,
+1 -1
View File
@@ -253,5 +253,5 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
flusher.Flush() flusher.Flush()
} }
runCommandStreamImpl(cmd, sendEvent, ctx) _ = runCommandStreamImpl(cmd, sendEvent, ctx)
} }
+3 -2
View File
@@ -15,11 +15,11 @@ const ptyCols = 256
const ptyRows = 40 const ptyRows = 40
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真) // runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) { func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows}) ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
if err != nil { if err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
defer ptmx.Close() defer ptmx.Close()
@@ -43,4 +43,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
exitCode = -1 exitCode = -1
} }
sendEvent(streamEvent{T: "exit", C: exitCode}) sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
} }
+5 -4
View File
@@ -11,20 +11,20 @@ import (
) )
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行 // runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) { func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
stdoutPipe, err := cmd.StdoutPipe() stdoutPipe, err := cmd.StdoutPipe()
if err != nil { if err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
stderrPipe, err := cmd.StderrPipe() stderrPipe, err := cmd.StderrPipe()
if err != nil { if err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
normalize := func(s string) string { normalize := func(s string) string {
@@ -62,4 +62,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
exitCode = -1 exitCode = -1
} }
sendEvent(streamEvent{T: "exit", C: exitCode}) sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
} }
+71 -32
View File
@@ -7,6 +7,7 @@ import (
"strings" "strings"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -16,6 +17,12 @@ import (
type VulnerabilityHandler struct { type VulnerabilityHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewVulnerabilityHandler 创建新的漏洞处理器 // NewVulnerabilityHandler 创建新的漏洞处理器
@@ -29,6 +36,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
// CreateVulnerabilityRequest 创建漏洞请求 // CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct { type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"` ConversationID string `json:"conversation_id" binding:"required"`
ProjectID string `json:"project_id"`
ConversationTag string `json:"conversation_tag"` ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"` TaskTag string `json:"task_tag"`
Title string `json:"title" binding:"required"` Title string `json:"title" binding:"required"`
@@ -52,6 +60,7 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
vuln := &database.Vulnerability{ vuln := &database.Vulnerability{
ConversationID: req.ConversationID, ConversationID: req.ConversationID,
ProjectID: strings.TrimSpace(req.ProjectID),
ConversationTag: req.ConversationTag, ConversationTag: req.ConversationTag,
TaskTag: req.TaskTag, TaskTag: req.TaskTag,
Title: req.Title, Title: req.Title,
@@ -72,6 +81,11 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
"severity": created.Severity, "title": created.Title,
})
}
c.JSON(http.StatusOK, created) c.JSON(http.StatusOK, created)
} }
@@ -98,18 +112,30 @@ type ListVulnerabilitiesResponse struct {
TotalPages int `json:"total_pages"` TotalPages int `json:"total_pages"`
} }
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
q := strings.TrimSpace(c.Query("q"))
if q == "" {
q = strings.TrimSpace(c.Query("search"))
}
return database.VulnerabilityListFilter{
ProjectID: c.Query("project_id"),
ID: c.Query("id"),
Search: q,
ConversationID: c.Query("conversation_id"),
Severity: c.Query("severity"),
Status: c.Query("status"),
TaskID: c.Query("task_id"),
ConversationTag: c.Query("conversation_tag"),
TaskTag: c.Query("task_tag"),
}
}
// ListVulnerabilities 列出漏洞 // ListVulnerabilities 列出漏洞
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "20") limitStr := c.DefaultQuery("limit", "20")
offsetStr := c.DefaultQuery("offset", "0") offsetStr := c.DefaultQuery("offset", "0")
pageStr := c.Query("page") pageStr := c.Query("page")
id := c.Query("id") filter := parseVulnerabilityListFilter(c)
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
limit, _ := strconv.Atoi(limitStr) limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr) offset, _ := strconv.Atoi(offsetStr)
@@ -131,7 +157,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
} }
// 获取总数 // 获取总数
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag) total, err := h.db.CountVulnerabilities(filter)
if err != nil { if err != nil {
h.logger.Error("获取漏洞总数失败", zap.Error(err)) h.logger.Error("获取漏洞总数失败", zap.Error(err))
// 继续执行,使用0作为总数 // 继续执行,使用0作为总数
@@ -139,7 +165,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
} }
// 获取漏洞列表 // 获取漏洞列表
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag) vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
if err != nil { if err != nil {
h.logger.Error("获取漏洞列表失败", zap.Error(err)) h.logger.Error("获取漏洞列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -170,17 +196,18 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
// UpdateVulnerabilityRequest 更新漏洞请求 // UpdateVulnerabilityRequest 更新漏洞请求
type UpdateVulnerabilityRequest struct { type UpdateVulnerabilityRequest struct {
ConversationTag string `json:"conversation_tag"` ProjectID *string `json:"project_id"`
TaskTag string `json:"task_tag"` ConversationTag string `json:"conversation_tag"`
Title string `json:"title"` TaskTag string `json:"task_tag"`
Description string `json:"description"` Title string `json:"title"`
Severity string `json:"severity"` Description string `json:"description"`
Status string `json:"status"` Severity string `json:"severity"`
Type string `json:"type"` Status string `json:"status"`
Target string `json:"target"` Type string `json:"type"`
Proof string `json:"proof"` Target string `json:"target"`
Impact string `json:"impact"` Proof string `json:"proof"`
Recommendation string `json:"recommendation"` Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
} }
// UpdateVulnerability 更新漏洞 // UpdateVulnerability 更新漏洞
@@ -201,6 +228,9 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
} }
// 更新字段 // 更新字段
if req.ProjectID != nil {
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
}
if req.ConversationTag != "" { if req.ConversationTag != "" {
existing.ConversationTag = req.ConversationTag existing.ConversationTag = req.ConversationTag
} }
@@ -249,6 +279,11 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
})
}
c.JSON(http.StatusOK, updated) c.JSON(http.StatusOK, updated)
} }
@@ -262,15 +297,25 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "vulnerability",
Action: "delete",
Result: "success",
ResourceType: "vulnerability",
ResourceID: id,
Message: "删除漏洞记录",
})
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
} }
// GetVulnerabilityStats 获取漏洞统计 // GetVulnerabilityStats 获取漏洞统计
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
conversationID := c.Query("conversation_id") filter := parseVulnerabilityListFilter(c)
taskID := c.Query("task_id")
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID) stats, err := h.db.GetVulnerabilityStats(filter)
if err != nil { if err != nil {
h.logger.Error("获取漏洞统计失败", zap.Error(err)) h.logger.Error("获取漏洞统计失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -304,15 +349,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
return return
} }
id := c.Query("id") filter := parseVulnerabilityListFilter(c)
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag) total, err := h.db.CountVulnerabilities(filter)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -322,7 +361,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
return return
} }
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag) items, err := h.db.ListVulnerabilities(total, 0, filter)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
+28 -3
View File
@@ -2,6 +2,7 @@ package handler
import ( import (
"bytes" "bytes"
"crypto/tls"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@@ -12,6 +13,7 @@ import (
"time" "time"
"unicode/utf8" "unicode/utf8"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -304,6 +306,12 @@ type WebShellHandler struct {
logger *zap.Logger logger *zap.Logger
client *http.Client client *http.Client
db *database.DB db *database.DB
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *WebShellHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用) // NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
@@ -311,8 +319,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
return &WebShellHandler{ return &WebShellHandler{
logger: logger, logger: logger,
client: &http.Client{ client: &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: &http.Transport{DisableKeepAlives: false}, Transport: &http.Transport{
DisableKeepAlives: false,
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
},
}, },
db: db, db: db,
} }
@@ -403,6 +415,15 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
host := req.URL
if u, err := url.Parse(req.URL); err == nil {
host = u.Host
}
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
"host": host, "type": shellType,
})
}
c.JSON(http.StatusOK, conn) c.JSON(http.StatusOK, conn)
} }
@@ -485,6 +506,9 @@ func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
}
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
} }
@@ -714,8 +738,9 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
output := decodeWebshellOutput(out, req.Encoding) output := decodeWebshellOutput(out, req.Encoding)
httpCode := resp.StatusCode httpCode := resp.StatusCode
ok := resp.StatusCode == http.StatusOK
c.JSON(http.StatusOK, ExecResponse{ c.JSON(http.StatusOK, ExecResponse{
OK: resp.StatusCode == http.StatusOK, OK: ok,
Output: output, Output: output,
HTTPCode: httpCode, HTTPCode: httpCode,
}) })
+1 -1
View File
@@ -15,7 +15,7 @@ const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `s
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。 // webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。 // 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base" const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、list_knowledge_risk_types、search_knowledge_base"
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。 // BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、 // 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
+293
View File
@@ -0,0 +1,293 @@
package handler
import (
"context"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/robot/ilink"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
const wechatLoginTTL = 5 * time.Minute
// WechatConfigSaver 绑定成功后写入配置并重启机器人连接
type WechatConfigSaver interface {
ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error
}
type wechatLoginSession struct {
QRCode string
QRCodeImgURL string
PendingVerify string
CurrentBaseURL string
StartedAt time.Time
}
// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置)
type WechatRobotHandler struct {
config *config.Config
configSaver WechatConfigSaver
logger *zap.Logger
mu sync.Mutex
logins map[string]*wechatLoginSession
}
// NewWechatRobotHandler 创建微信机器人处理器
func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler {
return &WechatRobotHandler{
config: cfg,
configSaver: saver,
logger: logger,
logins: make(map[string]*wechatLoginSession),
}
}
func (h *WechatRobotHandler) purgeExpiredLogins() {
now := time.Now()
for k, v := range h.logins {
if now.Sub(v.StartedAt) > wechatLoginTTL {
delete(h.logins, k)
}
}
}
func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client {
ver := h.config.Version
if ver == "" {
ver = "1.0.0"
}
ver = strings.TrimPrefix(strings.TrimSpace(ver), "v")
ver = strings.TrimPrefix(ver, "V")
wc := h.config.Robots.Wechat
return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver))
}
// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码
func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) {
h.mu.Lock()
h.purgeExpiredLogins()
h.mu.Unlock()
var req struct {
BotType string `json:"bot_type"`
}
_ = c.ShouldBindJSON(&req)
botType := req.BotType
if botType == "" {
botType = h.config.Robots.Wechat.BotType
}
if botType == "" {
botType = ilink.DefaultBotType
}
baseURL := h.config.Robots.Wechat.BaseURL
if baseURL == "" {
baseURL = ilink.DefaultBaseURL
}
var localTokens []string
if t := h.config.Robots.Wechat.BotToken; t != "" {
localTokens = []string{t}
}
client := h.ilinkClient(baseURL)
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
defer cancel()
qr, err := client.GetBotQRCode(ctx, botType, localTokens)
if err != nil {
h.logger.Warn("获取微信二维码失败", zap.Error(err))
c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()})
return
}
if qr.QRCode == "" || qr.QRCodeImgContent == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"})
return
}
sessionKey := uuid.New().String()
h.mu.Lock()
h.logins[sessionKey] = &wechatLoginSession{
QRCode: qr.QRCode,
QRCodeImgURL: qr.QRCodeImgContent,
CurrentBaseURL: baseURL,
StartedAt: time.Now(),
}
h.mu.Unlock()
resp := gin.H{
"session_key": sessionKey,
"qrcode": qr.QRCode,
"qrcode_open_url": qr.QRCodeImgContent,
"message": "请使用微信扫描二维码并确认绑定",
}
if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil {
h.logger.Warn("生成二维码图片失败", zap.Error(err))
} else {
resp["qrcode_image_data_url"] = dataURL
}
c.JSON(http.StatusOK, resp)
}
// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态
func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) {
sessionKey := c.Query("session_key")
verifyCode := c.Query("verify_code")
if sessionKey == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"})
return
}
h.mu.Lock()
sess, ok := h.logins[sessionKey]
h.mu.Unlock()
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"})
return
}
if time.Since(sess.StartedAt) > wechatLoginTTL {
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"})
return
}
baseURL := sess.CurrentBaseURL
if baseURL == "" {
baseURL = ilink.DefaultBaseURL
}
vc := verifyCode
if vc == "" {
vc = sess.PendingVerify
}
client := h.ilinkClient(baseURL)
ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second)
defer cancel()
st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc)
if err != nil {
h.logger.Warn("轮询微信二维码状态失败", zap.Error(err))
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
switch st.Status {
case "wait", "scaned":
c.JSON(http.StatusOK, gin.H{"status": st.Status})
return
case "need_verifycode":
c.JSON(http.StatusOK, gin.H{
"status": st.Status,
"message": "请在手机微信查看配对数字,并在下方输入",
})
return
case "scaned_but_redirect":
if st.RedirectHost != "" {
h.mu.Lock()
if s, ok := h.logins[sessionKey]; ok {
s.CurrentBaseURL = "https://" + st.RedirectHost
}
h.mu.Unlock()
}
c.JSON(http.StatusOK, gin.H{"status": st.Status})
return
case "binded_redirect":
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusOK, gin.H{
"status": st.Status,
"already_connected": true,
"message": "该微信已绑定过,无需重复绑定",
})
return
case "confirmed":
if st.BotToken == "" || st.ILinkBotID == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"})
return
}
saveBase := st.BaseURL
if saveBase == "" {
saveBase = baseURL
}
wc := h.config.Robots.Wechat
wc.Enabled = true
wc.BotToken = st.BotToken
wc.ILinkBotID = st.ILinkBotID
wc.ILinkUserID = st.ILinkUserID
wc.BaseURL = saveBase
if wc.BotType == "" {
wc.BotType = ilink.DefaultBotType
}
if wc.BotAgent == "" {
wc.BotAgent = ilink.DefaultBotAgent
}
if h.configSaver != nil {
if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil {
h.logger.Warn("保存微信机器人配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
} else {
h.config.Robots.Wechat = wc
}
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusOK, gin.H{
"status": "confirmed",
"message": "绑定成功,微信机器人已启用",
"ilink_bot_id": st.ILinkBotID,
"ilink_user_id": st.ILinkUserID,
})
return
default:
c.JSON(http.StatusOK, gin.H{"status": st.Status})
}
}
// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字
func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) {
var req struct {
SessionKey string `json:"session_key"`
VerifyCode string `json:"verify_code"`
}
if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"})
return
}
h.mu.Lock()
sess, ok := h.logins[req.SessionKey]
if ok {
sess.PendingVerify = req.VerifyCode
}
h.mu.Unlock()
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"})
}
// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示)
func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) {
wc := h.config.Robots.Wechat
bound := wc.BotToken != "" && wc.ILinkBotID != ""
c.JSON(http.StatusOK, gin.H{
"enabled": wc.Enabled,
"bound": bound,
"ilink_bot_id": wc.ILinkBotID,
"ilink_user_id": wc.ILinkUserID,
"base_url": wc.BaseURL,
})
}
+24 -1
View File
@@ -4,7 +4,16 @@ package builtin
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串 // 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
const ( const (
// 漏洞管理工具 // 漏洞管理工具
ToolRecordVulnerability = "record_vulnerability" ToolRecordVulnerability = "record_vulnerability"
ToolListVulnerabilities = "list_vulnerabilities"
ToolGetVulnerability = "get_vulnerability"
// 项目黑板(事实)工具
ToolUpsertProjectFact = "upsert_project_fact"
ToolGetProjectFact = "get_project_fact"
ToolListProjectFacts = "list_project_facts"
ToolSearchProjectFacts = "search_project_facts"
ToolDeprecateProjectFact = "deprecate_project_fact"
// 知识库工具 // 知识库工具
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types" ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
@@ -53,6 +62,13 @@ const (
func IsBuiltinTool(toolName string) bool { func IsBuiltinTool(toolName string) bool {
switch toolName { switch toolName {
case ToolRecordVulnerability, case ToolRecordVulnerability,
ToolListVulnerabilities,
ToolGetVulnerability,
ToolUpsertProjectFact,
ToolGetProjectFact,
ToolListProjectFacts,
ToolSearchProjectFacts,
ToolDeprecateProjectFact,
ToolListKnowledgeRiskTypes, ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase, ToolSearchKnowledgeBase,
ToolWebshellExec, ToolWebshellExec,
@@ -96,6 +112,13 @@ func IsBuiltinTool(toolName string) bool {
func GetAllBuiltinTools() []string { func GetAllBuiltinTools() []string {
return []string{ return []string{
ToolRecordVulnerability, ToolRecordVulnerability,
ToolListVulnerabilities,
ToolGetVulnerability,
ToolUpsertProjectFact,
ToolGetProjectFact,
ToolListProjectFacts,
ToolSearchProjectFacts,
ToolDeprecateProjectFact,
ToolListKnowledgeRiskTypes, ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase, ToolSearchKnowledgeBase,
ToolWebshellExec, ToolWebshellExec,
+63 -25
View File
@@ -77,6 +77,9 @@ type einoADKRunLoopArgs struct {
StreamsMainAssistant func(agent string) bool StreamsMainAssistant func(agent string) bool
EinoRoleTag func(agent string) string EinoRoleTag func(agent string) string
CheckpointDir string CheckpointDir string
// RunRetryMaxAttempts / RunRetryMaxBackoffSec429、5xx、网络抖动时的指数退避续跑(0=默认 10 次 / 30s 上限)。
RunRetryMaxAttempts int
RunRetryMaxBackoffSec int
McpIDsMu *sync.Mutex McpIDsMu *sync.Mutex
McpIDs *[]string McpIDs *[]string
@@ -177,6 +180,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
var einoMainRound int var einoMainRound int
var einoLastAgent string var einoLastAgent string
subAgentToolStep := make(map[string]int) subAgentToolStep := make(map[string]int)
// mainAgentToolStep:主代理每次工具调用批次递增,供 UI 显示「第 N 轮」(单代理无子代理切换时原先会一直停在第 1 轮)。
mainAgentToolStep := make(map[string]int)
pendingByID := make(map[string]toolCallPendingInfo) pendingByID := make(map[string]toolCallPendingInfo)
pendingQueueByAgent := make(map[string][]string) pendingQueueByAgent := make(map[string][]string)
markPending := func(tc toolCallPendingInfo) { markPending := func(tc toolCallPendingInfo) {
@@ -435,6 +440,28 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
return runErr return runErr
} }
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
if runErr == nil || !isEinoTransientRunError(runErr) {
return false, handleRunErr(runErr)
}
if logger != nil {
logger.Warn("eino transient error, ending run segment for handler resume",
zap.Error(runErr),
zap.String("orchestration", orchMode))
}
if progress != nil {
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"orchestration": orchMode,
"error": runErr.Error(),
"resumeKind": "trace_segment",
})
}
return false, ErrTransientRetryContinue
}
takePartial := func(runErr error) (*RunResult, error) { takePartial := func(runErr error) (*RunResult, error) {
if len(runAccumulatedMsgs) <= baseAccumulatedCount { if len(runAccumulatedMsgs) <= baseAccumulatedCount {
return nil, runErr return nil, runErr
@@ -517,7 +544,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue continue
} }
if ev.Err != nil { if ev.Err != nil {
if retErr := handleRunErr(ev.Err); retErr != nil { if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
return takePartial(retErr) return takePartial(retErr)
} }
} }
@@ -529,8 +556,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
if streamsMainAssistant(ev.AgentName) { if streamsMainAssistant(ev.AgentName) {
mainIterKey := einoMainIterationKey(iterEinoAgent, orchestratorName)
if einoMainRound == 0 { if einoMainRound == 0 {
einoMainRound = 1 einoMainRound = 1
mainAgentToolStep[mainIterKey] = 1
progress("iteration", "", map[string]interface{}{ progress("iteration", "", map[string]interface{}{
"iteration": 1, "iteration": 1,
"einoScope": "main", "einoScope": "main",
@@ -540,17 +569,26 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"conversationId": conversationID, "conversationId": conversationID,
"source": "eino", "source": "eino",
}) })
} else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) { } else if einoLastAgent != "" {
einoMainRound++ needBump := false
progress("iteration", "", map[string]interface{}{ if !streamsMainAssistant(einoLastAgent) {
"iteration": einoMainRound, needBump = true // 子代理 → 主代理
"einoScope": "main", } else if einoLastAgent != ev.AgentName {
"einoRole": "orchestrator", needBump = true // plan_executeplanner ↔ executor 等主代理切换
"einoAgent": iterEinoAgent, }
"orchestration": orchMode, if needBump {
"conversationId": conversationID, einoMainRound++
"source": "eino", mainAgentToolStep[mainIterKey] = einoMainRound
}) progress("iteration", "", map[string]interface{}{
"iteration": einoMainRound,
"einoScope": "main",
"einoRole": "orchestrator",
"einoAgent": iterEinoAgent,
"orchestration": orchMode,
"conversationId": conversationID,
"source": "eino",
})
}
} }
} }
einoLastAgent = ev.AgentName einoLastAgent = ev.AgentName
@@ -644,9 +682,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"orchestration": orchMode, "orchestration": orchMode,
}) })
} }
progress("reasoning_chain_stream_delta", displayDelta, map[string]interface{}{ progress("reasoning_chain_stream_delta", displayDelta, openai.WithSSEAccumulated(map[string]interface{}{
"streamId": reasoningStreamID, "streamId": reasoningStreamID,
}) }, fullDisplay))
} }
} }
} }
@@ -676,13 +714,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
}) })
streamHeaderSent = true streamHeaderSent = true
} }
progress("response_delta", contentDelta, map[string]interface{}{ progress("response_delta", contentDelta, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"mcpExecutionIds": snapshotMCPIDs(), "mcpExecutionIds": snapshotMCPIDs(),
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
}) }, mainAssistantBuf))
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta) mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
} }
} }
@@ -701,10 +739,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"source": "eino", "source": "eino",
}) })
} }
progress("eino_agent_reply_stream_delta", subDelta, map[string]interface{}{ progress("eino_agent_reply_stream_delta", subDelta, openai.WithSSEAccumulated(map[string]interface{}{
"streamId": subReplyStreamID, "streamId": subReplyStreamID,
"conversationId": conversationID, "conversationId": conversationID,
}) }, subAssistantBuf))
} }
} }
} }
@@ -743,13 +781,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"orchestration": orchMode, "orchestration": orchMode,
}) })
} }
progress("response_delta", eofTail, map[string]interface{}{ progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"mcpExecutionIds": snapshotMCPIDs(), "mcpExecutionIds": snapshotMCPIDs(),
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
}) }, mainAssistantBuf))
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail) mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
} }
} }
@@ -791,7 +829,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 { if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged}) lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
} }
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending) tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。 // 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 { if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls)) runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
@@ -808,7 +846,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": einoRoleTag(ev.AgentName), "einoRole": einoRoleTag(ev.AgentName),
}) })
} }
if retErr := handleRunErr(streamRecvErr); retErr != nil { if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
return takePartial(retErr) return takePartial(retErr)
} }
} }
@@ -820,7 +858,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue continue
} }
runAccumulatedMsgs = append(runAccumulatedMsgs, msg) runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending) tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
if mv.Role == schema.Assistant { if mv.Role == schema.Assistant {
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" { if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
@@ -859,13 +897,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
}) })
progress("response_delta", body, map[string]interface{}{ progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"mcpExecutionIds": snapshotMCPIDs(), "mcpExecutionIds": snapshotMCPIDs(),
"einoRole": "orchestrator", "einoRole": "orchestrator",
"einoAgent": ev.AgentName, "einoAgent": ev.AgentName,
"orchestration": orchMode, "orchestration": orchMode,
}) }, body))
} }
lastAssistant = body lastAssistant = body
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") { if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
+7 -3
View File
@@ -13,12 +13,12 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/project"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai" einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/compose" "github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -39,6 +39,7 @@ func RunEinoSingleChatModelAgent(
roleTools []string, roleTools []string,
progress func(eventType, message string, data interface{}), progress func(eventType, message string, data interface{}),
reasoningClient *reasoning.ClientIntent, reasoningClient *reasoning.ClientIntent,
systemPromptExtra string,
) (*RunResult, error) { ) (*RunResult, error) {
if appCfg == nil || ag == nil { if appCfg == nil || ag == nil {
return nil, fmt.Errorf("eino single: 配置或 Agent 为空") return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
@@ -178,7 +179,8 @@ func RunEinoSingleChatModelAgent(
}, },
EmitInternalEvents: true, EmitInternalEvents: true,
} }
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive) ins := project.AppendSystemPromptBlock(ag.EinoSingleAgentSystemInstruction(), systemPromptExtra)
ins = injectToolNamesOnlyInstruction(ctx, ins, mainTools, singleToolSearchActive)
if logger != nil { if logger != nil {
names := collectToolNames(ctx, mainTools) names := collectToolNames(ctx, mainTools)
mountedNames := collectToolNames(ctx, mainToolsForCfg) mountedNames := collectToolNames(ctx, mainToolsForCfg)
@@ -213,7 +215,7 @@ func RunEinoSingleChatModelAgent(
} }
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware) baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage)) baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
streamsMainAssistant := func(agent string) bool { streamsMainAssistant := func(agent string) bool {
return agent == "" || agent == einoSingleAgentName return agent == "" || agent == einoSingleAgentName
@@ -233,6 +235,8 @@ func RunEinoSingleChatModelAgent(
StreamsMainAssistant: streamsMainAssistant, StreamsMainAssistant: streamsMainAssistant,
EinoRoleTag: einoRoleTag, EinoRoleTag: einoRoleTag,
CheckpointDir: ma.EinoMiddleware.CheckpointDir, CheckpointDir: ma.EinoMiddleware.CheckpointDir,
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
McpIDsMu: &mcpIDsMu, McpIDsMu: &mcpIDsMu,
McpIDs: &mcpIDs, McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag, FilesystemMonitorAgent: ag,
+173
View File
@@ -0,0 +1,173 @@
package multiagent
import (
"context"
"errors"
"strings"
"time"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
const (
defaultEinoRunRetryMaxAttempts = 10
defaultEinoRunRetryMaxBackoff = 30 * time.Second
)
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
func isEinoTransientRunError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if isEinoIterationLimitError(err) {
return false
}
msg := strings.ToLower(strings.TrimSpace(err.Error()))
if msg == "" {
return false
}
transientMarkers := []string{
"406",
"429",
"too many requests",
"rate limit",
"rate_limit",
"ratelimit",
"quota exceeded",
"overloaded",
"capacity",
"temporarily unavailable",
"service unavailable",
"bad gateway",
"gateway timeout",
"internal server error",
"connection reset",
"connection refused",
"connection closed",
"i/o timeout",
"no such host",
"network is unreachable",
"broken pipe",
"eof",
"read tcp",
"write tcp",
"dial tcp",
"tls handshake timeout",
"stream error",
"unexpected eof",
"unexpected end of json",
"status code: 406",
"status code: 502",
"502",
"503",
"504",
"500",
}
for _, m := range transientMarkers {
if strings.Contains(msg, m) {
return true
}
}
return false
}
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
if args != nil && args.RunRetryMaxAttempts > 0 {
return args.RunRetryMaxAttempts
}
return defaultEinoRunRetryMaxAttempts
}
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
if mw != nil && mw.RunRetryMaxAttempts > 0 {
return mw.RunRetryMaxAttempts
}
return defaultEinoRunRetryMaxAttempts
}
// TransientRetryBackoff 供 handler 在分段续跑前退避。
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
max := defaultEinoRunRetryMaxBackoff
if maxBackoffSec > 0 {
max = time.Duration(maxBackoffSec) * time.Second
}
return einoTransientRetryBackoff(attempt, max)
}
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
if args != nil && args.RunRetryMaxBackoffSec > 0 {
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
}
return defaultEinoRunRetryMaxBackoff
}
// einoRunRestartContextSource 描述无 checkpoint Resume 时 Run 使用的消息来源(日志/SSE)。
type einoRunRestartContextSource string
const (
einoRestartContextInitial einoRunRestartContextSource = "initial"
einoRestartContextAccumulated einoRunRestartContextSource = "accumulated"
einoRestartContextModelTrace einoRunRestartContextSource = "model_trace"
)
// einoMessagesForRunRestart 在退避后重新 Run 时选用最完整的上下文:
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
if trace := persistTraceSource(args, nil); len(trace) > 0 {
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
}
if len(accumulated) > baseCount {
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
}
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
}
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
want = strings.TrimSpace(want)
if want == "" {
return true
}
for i := len(msgs) - 1; i >= 0; i-- {
m := msgs[i]
if m == nil {
continue
}
if m.Role == schema.User {
return strings.TrimSpace(m.Content) == want
}
if m.Role == schema.Assistant || m.Role == schema.Tool {
continue
}
break
}
return false
}
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
return msgs
}
return append(msgs, schema.UserMessage(userMessage))
}
// einoTransientRetryBackoff 指数退避:2s, 4s, 8s… capped by maxBackoff。
func einoTransientRetryBackoff(attempt int, maxBackoff time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
if maxBackoff > 0 && backoff > maxBackoff {
backoff = maxBackoff
}
return backoff
}
@@ -0,0 +1,104 @@
package multiagent
import (
"context"
"errors"
"testing"
"time"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func TestIsEinoTransientRunError(t *testing.T) {
t.Parallel()
cases := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"429", errors.New("HTTP 429 Too Many Requests"), true},
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
{"503", errors.New("upstream returned 503"), true},
{"iteration limit", errors.New("max iteration reached"), false},
{"canceled", context.Canceled, false},
{"deadline", context.DeadlineExceeded, false},
{"auth", errors.New("invalid api key"), false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isEinoTransientRunError(tc.err); got != tc.want {
t.Fatalf("isEinoTransientRunError(%v) = %v, want %v", tc.err, got, tc.want)
}
})
}
}
func TestEinoTransientRetryBackoff(t *testing.T) {
t.Parallel()
max := 30 * time.Second
if got := einoTransientRetryBackoff(0, max); got != 2*time.Second {
t.Fatalf("attempt 0: got %v", got)
}
if got := einoTransientRetryBackoff(4, max); got != 30*time.Second {
t.Fatalf("attempt 4 capped: got %v", got)
}
}
func TestEinoMessagesForRunRestart(t *testing.T) {
t.Parallel()
base := []adk.Message{schema.UserMessage("hi")}
acc := append([]adk.Message(nil), base...)
acc = append(acc, schema.AssistantMessage("step1", nil))
got, src := einoMessagesForRunRestart(nil, base, acc, len(base))
if src != einoRestartContextAccumulated || len(got) != 2 {
t.Fatalf("accumulated: src=%s len=%d", src, len(got))
}
holder := newModelFacingTraceHolder()
holder.storeFromState(&adk.ChatModelAgentState{
Messages: []adk.Message{schema.UserMessage("u"), schema.AssistantMessage("model-view", nil)},
})
got2, src2 := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, base, acc, len(base))
if src2 != einoRestartContextModelTrace || len(got2) != 2 {
t.Fatalf("model trace: src=%s len=%d", src2, len(got2))
}
}
func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
t.Parallel()
if einoRunRetryMaxAttempts(nil) != defaultEinoRunRetryMaxAttempts {
t.Fatal("nil args should use default")
}
if einoRunRetryMaxAttempts(&einoADKRunLoopArgs{RunRetryMaxAttempts: 3}) != 3 {
t.Fatal("custom max attempts")
}
if RunRetryMaxAttemptsFromConfig(nil) != defaultEinoRunRetryMaxAttempts {
t.Fatal("config nil should use default")
}
}
func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Parallel()
msgs := []adk.Message{schema.UserMessage("old task")}
out := appendUserMessageIfNeeded(msgs, "你好,你是谁")
if len(out) != 2 || out[1].Content != "你好,你是谁" {
t.Fatalf("should append user: len=%d", len(out))
}
dup := appendUserMessageIfNeeded(out, "你好,你是谁")
if len(dup) != 2 {
t.Fatalf("should not duplicate user message: len=%d", len(dup))
}
}
func TestErrTransientRetryContinue(t *testing.T) {
t.Parallel()
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
t.Fatal("sentinel should match")
}
}
+4
View File
@@ -5,3 +5,7 @@ import "errors"
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时, // ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。 // 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context") var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
@@ -106,16 +106,16 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
当工具返回错误时错误信息会包含在工具响应中请仔细阅读并做出合理的决策 当工具返回错误时错误信息会包含在工具响应中请仔细阅读并做出合理的决策
## 漏洞记录 ## 项目黑板事实与漏洞记录分离
发现有效漏洞时必须使用 ` + builtin.ToolRecordVulnerability + ` 记录标题描述严重程度类型目标证明POC影响修复建议 绑定项目时会自动注入黑板索引fact_key + 摘要**摘要不足必须 ` + builtin.ToolGetProjectFact + `(fact_key) body禁止臆造** 环境认知用 ` + builtin.ToolUpsertProjectFact + `key target/primary_domain正式漏洞用 ` + builtin.ToolRecordVulnerability + `记前可先 ` + builtin.ToolListVulnerabilities + ` 防重复详情用 ` + builtin.ToolGetVulnerability + `二者可各记一次误报用 ` + builtin.ToolDeprecateProjectFact + `漏洞查询默认仅当前项目未绑项目则仅当前会话
严重程度critical / high / medium / low / info证明须含足够证据请求响应截图命令输出等记录后可在授权范围内继续测试 严重程度critical / high / medium / low / info证明须含足够证据
## 技能库Skills与知识库 ## 技能库Skills与知识库
- 技能包位于服务器 skills/ 目录各子目录 SKILL.md遵循 agentskills.io知识库用于向量检索片段Skills 为可执行工作流指令 - 技能包位于服务器 skills/ 目录各子目录 SKILL.md遵循 agentskills.io知识库用于向量检索片段Skills 为可执行工作流指令
- plan_execute 执行器通过 MCP 使用知识库与漏洞记录等Skills 的渐进式加载在多代理 / Eino DeepAgent等模式中由内置 skill 工具完成 multi_agent.eino_skills - plan_execute 执行器通过 MCP 使用知识库项目事实与漏洞记录等Skills 的渐进式加载在多代理 / Eino DeepAgent等模式中由内置 skill 工具完成 multi_agent.eino_skills
- 若需要完整 Skill 工作流而当前会话无 skill 工具请在计划或对用户说明中建议切换多代理或 Eino 编排会话 - 若需要完整 Skill 工作流而当前会话无 skill 工具请在计划或对用户说明中建议切换多代理或 Eino 编排会话
## 执行器对用户输出重要 ## 执行器对用户输出重要
@@ -206,7 +206,7 @@ func DefaultSupervisorOrchestratorInstruction() string {
- **委派优先**可独立封装需要专项上下文的子目标枚举验证归纳报告素材优先 transfer 给匹配子代理并在委派说明中写清子目标约束期望交付物结构证据要求 - **委派优先**可独立封装需要专项上下文的子目标枚举验证归纳报告素材优先 transfer 给匹配子代理并在委派说明中写清子目标约束期望交付物结构证据要求
- **亲自执行**仅当无合适专家需全局衔接或子代理结果不足时由你直接调用工具 - **亲自执行**仅当无合适专家需全局衔接或子代理结果不足时由你直接调用工具
- **汇总**子代理输出是证据来源你要对齐矛盾补全上下文给出统一结论与可复现验证步骤避免机械拼接 - **汇总**子代理输出是证据来源你要对齐矛盾补全上下文给出统一结论与可复现验证步骤避免机械拼接
- **漏洞**有效漏洞应通过 ` + builtin.ToolRecordVulnerability + ` 记录 POC 与严重性critical / high / medium / low / info - **事实与漏洞**环境认知用 ` + builtin.ToolUpsertProjectFact + `正式漏洞用 ` + builtin.ToolRecordVulnerability + `查询用 ` + builtin.ToolListVulnerabilities + ` / ` + builtin.ToolGetVulnerability + `索引摘要不足时必须 ` + builtin.ToolGetProjectFact + ` 取详情
## transfer 交接与防重复劳动 ## transfer 交接与防重复劳动
+47 -7
View File
@@ -17,6 +17,7 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/project"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai" einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
@@ -64,6 +65,7 @@ func RunDeepAgent(
agentsMarkdownDir string, agentsMarkdownDir string,
orchestrationOverride string, orchestrationOverride string,
reasoningClient *reasoning.ClientIntent, reasoningClient *reasoning.ClientIntent,
systemPromptExtra string,
) (*RunResult, error) { ) (*RunResult, error) {
if appCfg == nil || ma == nil || ag == nil { if appCfg == nil || ma == nil || ag == nil {
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空") return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
@@ -339,6 +341,7 @@ func RunDeepAgent(
return nil, err return nil, err
} }
orchInstruction = project.AppendSystemPromptBlock(orchInstruction, systemPromptExtra)
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive) orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
if logger != nil { if logger != nil {
mainNames := collectToolNames(ctx, mainTools) mainNames := collectToolNames(ctx, mainTools)
@@ -387,7 +390,8 @@ func RunDeepAgent(
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。 // noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()} deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes); mw != nil { taskEnrichExtra := systemPromptExtra
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil {
deepHandlers = append(deepHandlers, mw) deepHandlers = append(deepHandlers, mw)
} }
if len(mainOrchestratorPre) > 0 { if len(mainOrchestratorPre) > 0 {
@@ -538,7 +542,7 @@ func RunDeepAgent(
} }
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware) baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage)) baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
streamsMainAssistant := func(agent string) bool { streamsMainAssistant := func(agent string) bool {
if orchMode == "plan_execute" { if orchMode == "plan_execute" {
@@ -566,6 +570,8 @@ func RunDeepAgent(
StreamsMainAssistant: streamsMainAssistant, StreamsMainAssistant: streamsMainAssistant,
EinoRoleTag: einoRoleTag, EinoRoleTag: einoRoleTag,
CheckpointDir: ma.EinoMiddleware.CheckpointDir, CheckpointDir: ma.EinoMiddleware.CheckpointDir,
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
McpIDsMu: &mcpIDsMu, McpIDsMu: &mcpIDsMu,
McpIDs: &mcpIDs, McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag, FilesystemMonitorAgent: ag,
@@ -595,6 +601,13 @@ func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
argsStr = string(b) argsStr = string(b)
} }
} }
// Some OpenAI-compatible gateways require `function.arguments` to exist
// on every assistant tool_call message. When args are empty, omitempty may
// drop the field during serialization and cause "missing field arguments"
// on the next turn history replay.
if strings.TrimSpace(argsStr) == "" {
argsStr = "{}"
}
typ := tc.Type typ := tc.Type
if typ == "" { if typ == "" {
typ = "function" typ = "function"
@@ -737,12 +750,23 @@ func toolCallsRichSignature(msg *schema.Message) string {
return base + "|" + strings.Join(parts, ";") return base + "|" + strings.Join(parts, ";")
} }
func einoMainIterationKey(agentName, orchestratorName string) string {
key := strings.TrimSpace(agentName)
if key == "" {
key = strings.TrimSpace(orchestratorName)
}
if key == "" {
return "_main"
}
return key
}
func tryEmitToolCallsOnce( func tryEmitToolCallsOnce(
msg *schema.Message, msg *schema.Message,
agentName, orchestratorName, conversationID string, agentName, orchestratorName, conversationID, orchMode string,
progress func(string, string, interface{}), progress func(string, string, interface{}),
seen map[string]struct{}, seen map[string]struct{},
subAgentToolStep map[string]int, subAgentToolStep, mainAgentToolStep map[string]int,
markPending func(toolCallPendingInfo), markPending func(toolCallPendingInfo),
) { ) {
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil { if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
@@ -756,14 +780,14 @@ func tryEmitToolCallsOnce(
return return
} }
seen[sig] = struct{}{} seen[sig] = struct{}{}
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep, markPending) emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, orchMode, progress, subAgentToolStep, mainAgentToolStep, markPending)
} }
func emitToolCallsFromMessage( func emitToolCallsFromMessage(
msg *schema.Message, msg *schema.Message,
agentName, orchestratorName, conversationID string, agentName, orchestratorName, conversationID, orchMode string,
progress func(string, string, interface{}), progress func(string, string, interface{}),
subAgentToolStep map[string]int, subAgentToolStep, mainAgentToolStep map[string]int,
markPending func(toolCallPendingInfo), markPending func(toolCallPendingInfo),
) { ) {
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil { if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
@@ -784,6 +808,22 @@ func emitToolCallsFromMessage(
"conversationId": conversationID, "conversationId": conversationID,
"source": "eino", "source": "eino",
}) })
} else if mainAgentToolStep != nil {
key := einoMainIterationKey(agentName, orchestratorName)
mainAgentToolStep[key]++
n := mainAgentToolStep[key]
// 第 1 轮已在主代理进入时发出;此后每次工具批次对应新一轮 ReAct(与子代理按工具计步一致)。
if n > 1 {
progress("iteration", "", map[string]interface{}{
"iteration": n,
"einoScope": "main",
"einoRole": "orchestrator",
"einoAgent": agentName,
"orchestration": orchMode,
"conversationId": conversationID,
"source": "eino",
})
}
} }
role := "orchestrator" role := "orchestrator"
if isSubToolRound { if isSubToolRound {
+8 -1
View File
@@ -30,8 +30,15 @@ type taskContextEnrichMiddleware struct {
// newTaskContextEnrichMiddleware returns a middleware that enriches task // newTaskContextEnrichMiddleware returns a middleware that enriches task
// descriptions with user conversation context. Returns nil if disabled // descriptions with user conversation context. Returns nil if disabled
// (maxRunes < 0) or no user messages exist. // (maxRunes < 0) or no user messages exist.
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int) adk.ChatModelAgentMiddleware { func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
supplement := buildUserContextSupplement(userMessage, history, maxRunes) supplement := buildUserContextSupplement(userMessage, history, maxRunes)
if bb := strings.TrimSpace(projectBlackboard); bb != "" {
if supplement != "" {
supplement += "\n\n## 项目黑板索引\n" + bb
} else {
supplement = "\n\n## 项目黑板索引\n" + bb
}
}
if supplement == "" { if supplement == "" {
return nil return nil
} }
@@ -105,6 +105,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) {
"继续测试", "继续测试",
[]agent.ChatMessage{{Role: "user", Content: "http://8.163.32.73:8081 pikachu靶场"}}, []agent.ChatMessage{{Role: "user", Content: "http://8.163.32.73:8081 pikachu靶场"}},
0, 0,
"",
) )
if mw == nil { if mw == nil {
t.Fatal("expected non-nil middleware") t.Fatal("expected non-nil middleware")
@@ -149,7 +150,7 @@ func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) {
} }
func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) { func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) {
mw := newTaskContextEnrichMiddleware("test", nil, 0) mw := newTaskContextEnrichMiddleware("test", nil, 0, "")
if mw == nil { if mw == nil {
t.Fatal("expected non-nil middleware") t.Fatal("expected non-nil middleware")
} }
@@ -175,7 +176,7 @@ func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) {
} }
func TestTaskContextEnrichMiddleware_NilWhenDisabled(t *testing.T) { func TestTaskContextEnrichMiddleware_NilWhenDisabled(t *testing.T) {
mw := newTaskContextEnrichMiddleware("test", nil, -1) mw := newTaskContextEnrichMiddleware("test", nil, -1, "")
if mw != nil { if mw != nil {
t.Error("middleware should be nil when disabled") t.Error("middleware should be nil when disabled")
} }
+20
View File
@@ -0,0 +1,20 @@
package openai
// SSEAccumulatedKey 为 SSE progress 事件 data 中的服务端权威流式全文快照字段。
// 前端应优先用该字段更新 buffer,避免对 delta 二次 normalize 导致叠字。
const SSEAccumulatedKey = "accumulated"
// WithSSEAccumulated 在 progress data 中附带当前流式累计全文(权威快照)。
func WithSSEAccumulated(data map[string]interface{}, accumulated string) map[string]interface{} {
if data == nil {
data = make(map[string]interface{}, 1)
}
data[SSEAccumulatedKey] = accumulated
return data
}
// NormalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
// 与 unexported normalizeStreamingDelta 相同,供 agent / multiagent 等包在发 SSE 前累计正文。
func NormalizeStreamingDelta(current, incoming string) (next, delta string) {
return normalizeStreamingDelta(current, incoming)
}
+77
View File
@@ -0,0 +1,77 @@
package project
import (
"fmt"
"sort"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
)
// AppendSystemPromptBlock 将附加块追加到 system prompt。
func AppendSystemPromptBlock(base, block string) string {
base = strings.TrimSpace(base)
block = strings.TrimSpace(block)
if block == "" {
return base
}
if base == "" {
return block
}
return base + "\n\n" + block
}
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
if db == nil || !cfg.Enabled {
return "", nil
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return "", nil
}
proj, err := db.GetProject(projectID)
if err != nil {
return "", err
}
facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated)
if err != nil {
return "", err
}
if len(facts) == 0 {
return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil
}
sort.SliceStable(facts, func(i, j int) bool {
if facts[i].Pinned != facts[j].Pinned {
return facts[i].Pinned
}
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
})
maxRunes := cfg.FactIndexMaxRunesEffective()
var b strings.Builder
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s\n", proj.Name, proj.ID))
used := len([]rune(b.String()))
omitted := 0
for _, f := range facts {
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
lineRunes := len([]rune(line))
if used+lineRunes > maxRunes {
omitted++
continue
}
b.WriteString(line)
used += lineRunes
}
if omitted > 0 {
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
}
b.WriteString("需要完整内容(POC、长文本等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n")
return b.String(), nil
}
+9 -4
View File
@@ -149,13 +149,18 @@ func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, all
func normalizeEffort(s string) string { func normalizeEffort(s string) string {
e := strings.ToLower(strings.TrimSpace(s)) e := strings.ToLower(strings.TrimSpace(s))
switch e { switch e {
case "low", "medium", "high", "max": case "low", "medium", "high", "max", "xhigh":
return e return e
default: default:
return "" return ""
} }
} }
// usesExtraFieldsReasoningEffort 为 Eino 无枚举的最高档 effort,经 ExtraFields 原样下发(max / xhigh 由网关自行识别,不做互转)。
func usesExtraFieldsReasoningEffort(e string) bool {
return e == "max" || e == "xhigh"
}
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile { func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") { if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
return wireClaude return wireClaude
@@ -210,11 +215,11 @@ func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) {
if e == "" { if e == "" {
return return
} }
if e == "max" { if usesExtraFieldsReasoningEffort(e) {
if cfg.ExtraFields == nil { if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any) cfg.ExtraFields = make(map[string]any)
} }
cfg.ExtraFields["reasoning_effort"] = "max" cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(e)
return return
} }
switch e { switch e {
@@ -245,6 +250,6 @@ func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort strin
} }
func effortStringForAPI(e string) string { func effortStringForAPI(e string) string {
// Gateways expect lowercase strings; "max" kept as max. // 原样透传:OpenAI 官方多为 xhigh,部分兼容网关为 max,由配置/对话 effort 选择。
return strings.ToLower(strings.TrimSpace(e)) return strings.ToLower(strings.TrimSpace(e))
} }
+66
View File
@@ -0,0 +1,66 @@
package reasoning
import (
"testing"
"cyberstrike-ai/internal/config"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
)
func TestEffortStringForAPI_passthrough(t *testing.T) {
cases := map[string]string{
"max": "max",
"xhigh": "xhigh",
"HIGH": "high",
"Medium": "medium",
}
for in, want := range cases {
if got := effortStringForAPI(in); got != want {
t.Fatalf("%q -> %q, want %q", in, got, want)
}
}
}
func TestNormalizeEffort_maxAndXhigh(t *testing.T) {
if normalizeEffort("xhigh") != "xhigh" {
t.Fatal("xhigh not accepted")
}
if normalizeEffort("max") != "max" {
t.Fatal("max not accepted")
}
}
func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
Reasoning: config.OpenAIReasoningConfig{
Profile: "openai_compat",
Mode: "on",
Effort: "xhigh",
},
}
ApplyToEinoChatModelConfig(cfg, oa, nil)
if cfg.ExtraFields == nil {
t.Fatal("expected ExtraFields")
}
if got, _ := cfg.ExtraFields["reasoning_effort"].(string); got != "xhigh" {
t.Fatalf("reasoning_effort=%q", got)
}
}
func TestApplyOpenAICompat_maxPassthrough(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
Reasoning: config.OpenAIReasoningConfig{
Profile: "openai_compat",
Mode: "on",
Effort: "max",
},
}
ApplyToEinoChatModelConfig(cfg, oa, nil)
got, _ := cfg.ExtraFields["reasoning_effort"].(string)
if got != "max" {
t.Fatalf("max effort wire=%q, want max", got)
}
}
+316
View File
@@ -0,0 +1,316 @@
package ilink
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
const (
DefaultBaseURL = "https://ilinkai.weixin.qq.com"
DefaultBotType = "3"
DefaultBotAgent = "CyberStrikeAI/1.0"
ILinkAppID = "bot"
QRLongPollTimeout = 35 * time.Second
APIDefaultTimeout = 15 * time.Second
GetUpdatesTimeout = 35 * time.Second
)
// Client 微信 iLink Bot HTTP 客户端(与 @tencent-weixin/openclaw-weixin 协议兼容)
type Client struct {
BaseURL string
BotToken string
BotAgent string
ClientVersion uint32
HTTP *http.Client
}
func NewClient(baseURL, botToken, botAgent string, clientVersion uint32) *Client {
base := strings.TrimSpace(baseURL)
if base == "" {
base = DefaultBaseURL
}
agent := strings.TrimSpace(botAgent)
if agent == "" {
agent = DefaultBotAgent
}
return &Client{
BaseURL: strings.TrimRight(base, "/"),
BotToken: strings.TrimSpace(botToken),
BotAgent: sanitizeBotAgent(agent),
ClientVersion: clientVersion,
HTTP: &http.Client{Timeout: 0},
}
}
// BuildClientVersion 将 semver 编码为 iLink-App-ClientVersion0x00MMNNPP
func BuildClientVersion(version string) uint32 {
parts := strings.Split(version, ".")
parse := func(i int) int {
if i >= len(parts) {
return 0
}
n, _ := strconv.Atoi(strings.TrimSpace(parts[i]))
if n < 0 {
return 0
}
return n
}
major := parse(0) & 0xff
minor := parse(1) & 0xff
patch := parse(2) & 0xff
return uint32((major << 16) | (minor << 8) | patch)
}
type baseInfo struct {
ChannelVersion string `json:"channel_version"`
BotAgent string `json:"bot_agent"`
}
func (c *Client) buildBaseInfo() baseInfo {
return baseInfo{
ChannelVersion: "1.0.0",
BotAgent: c.BotAgent,
}
}
func randomWechatUIN() string {
var b [4]byte
_, _ = rand.Read(b[:])
u := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
return base64.StdEncoding.EncodeToString([]byte(strconv.FormatUint(uint64(u), 10)))
}
func (c *Client) commonHeaders() http.Header {
h := http.Header{}
h.Set("iLink-App-Id", ILinkAppID)
h.Set("iLink-App-ClientVersion", strconv.FormatUint(uint64(c.ClientVersion), 10))
return h
}
func (c *Client) authHeaders() http.Header {
h := c.commonHeaders()
h.Set("Content-Type", "application/json")
h.Set("AuthorizationType", "ilink_bot_token")
h.Set("X-WECHAT-UIN", randomWechatUIN())
if c.BotToken != "" {
h.Set("Authorization", "Bearer "+c.BotToken)
}
return h
}
func (c *Client) endpointURL(path string) (string, error) {
u, err := url.Parse(c.BaseURL + "/")
if err != nil {
return "", err
}
ref, err := url.Parse(path)
if err != nil {
return "", err
}
return u.ResolveReference(ref).String(), nil
}
func (c *Client) doRequest(ctx context.Context, method, path string, body []byte, headers http.Header, timeout time.Duration) ([]byte, error) {
reqURL, err := c.endpointURL(path)
if err != nil {
return nil, err
}
var bodyReader io.Reader
if len(body) > 0 {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader)
if err != nil {
return nil, err
}
for k, vs := range headers {
for _, v := range vs {
req.Header.Add(k, v)
}
}
client := c.HTTP
if client == nil {
client = http.DefaultClient
}
if timeout > 0 {
ctx2, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req = req.WithContext(ctx2)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("ilink %s %s: %d %s", method, path, resp.StatusCode, string(raw))
}
return raw, nil
}
// QRCodeResponse 获取二维码响应
type QRCodeResponse struct {
QRCode string `json:"qrcode"`
QRCodeImgContent string `json:"qrcode_img_content"`
}
// GetBotQRCode 获取绑定二维码
func (c *Client) GetBotQRCode(ctx context.Context, botType string, localTokenList []string) (*QRCodeResponse, error) {
if strings.TrimSpace(botType) == "" {
botType = DefaultBotType
}
body, _ := json.Marshal(map[string]interface{}{
"local_token_list": localTokenList,
})
path := "ilink/bot/get_bot_qrcode?bot_type=" + url.QueryEscape(botType)
raw, err := c.doRequest(ctx, http.MethodPost, path, body, c.authHeaders(), APIDefaultTimeout)
if err != nil {
return nil, err
}
var out QRCodeResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
return &out, nil
}
// QRStatusResponse 二维码状态轮询响应
type QRStatusResponse struct {
Status string `json:"status"`
BotToken string `json:"bot_token"`
ILinkBotID string `json:"ilink_bot_id"`
ILinkUserID string `json:"ilink_user_id"`
BaseURL string `json:"baseurl"`
RedirectHost string `json:"redirect_host"`
}
// GetQRCodeStatus 长轮询二维码扫码状态
func (c *Client) GetQRCodeStatus(ctx context.Context, qrcode, verifyCode string) (*QRStatusResponse, error) {
path := "ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(qrcode)
if verifyCode != "" {
path += "&verify_code=" + url.QueryEscape(verifyCode)
}
raw, err := c.doRequest(ctx, http.MethodGet, path, nil, c.commonHeaders(), QRLongPollTimeout)
if err != nil {
if ctx.Err() != nil {
return &QRStatusResponse{Status: "wait"}, nil
}
return &QRStatusResponse{Status: "wait"}, nil
}
var out QRStatusResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
return &out, nil
}
// MessageItem 消息内容项
type MessageItem struct {
Type int `json:"type"`
TextItem *struct {
Text string `json:"text"`
} `json:"text_item,omitempty"`
}
// WeixinMessage 入站消息
type WeixinMessage struct {
FromUserID string `json:"from_user_id"`
MessageType int `json:"message_type"`
MessageState int `json:"message_state"`
ItemList []MessageItem `json:"item_list"`
ContextToken string `json:"context_token"`
}
// GetUpdatesResponse 长轮询消息响应
type GetUpdatesResponse struct {
Ret int `json:"ret"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
Msgs []WeixinMessage `json:"msgs"`
GetUpdatesBuf string `json:"get_updates_buf"`
LongPollingTimeoutMs int `json:"longpolling_timeout_ms"`
}
// GetUpdates 长轮询获取新消息
func (c *Client) GetUpdates(ctx context.Context, getUpdatesBuf string) (*GetUpdatesResponse, error) {
body, _ := json.Marshal(map[string]interface{}{
"get_updates_buf": getUpdatesBuf,
"base_info": c.buildBaseInfo(),
})
raw, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/getupdates", body, c.authHeaders(), GetUpdatesTimeout)
if err != nil {
if ctx.Err() != nil {
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
}
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
}
var out GetUpdatesResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
return &out, nil
}
// SendTextMessage 发送文本回复
func (c *Client) SendTextMessage(ctx context.Context, toUserID, contextToken, text, clientID string) error {
if clientID == "" {
clientID = randomClientID()
}
payload := map[string]interface{}{
"msg": map[string]interface{}{
"to_user_id": toUserID,
"client_id": clientID,
"message_type": 2,
"message_state": 2,
"context_token": contextToken,
"item_list": []map[string]interface{}{
{"type": 1, "text_item": map[string]string{"text": text}},
},
},
"base_info": c.buildBaseInfo(),
}
body, _ := json.Marshal(payload)
_, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/sendmessage", body, c.authHeaders(), APIDefaultTimeout)
return err
}
func randomClientID() string {
var b [8]byte
_, _ = rand.Read(b[:])
return fmt.Sprintf("%x", b)
}
func sanitizeBotAgent(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return DefaultBotAgent
}
if len(raw) > 256 {
return raw[:256]
}
return raw
}
// ExtractText 从消息中提取首条文本
func ExtractText(msg WeixinMessage) string {
for _, item := range msg.ItemList {
if item.Type == 1 && item.TextItem != nil {
return strings.TrimSpace(item.TextItem.Text)
}
}
return ""
}
+26
View File
@@ -0,0 +1,26 @@
package ilink
import (
"encoding/base64"
"fmt"
"strings"
"github.com/skip2/go-qrcode"
)
// QRCodeDataURL 将扫码内容(一般为 liteapp 链接)编码为 PNG data URL,供 Web 端展示。
// qrcode_img_content 不是图片直链,不能用作 <img src>。
func QRCodeDataURL(content string, size int) (string, error) {
content = strings.TrimSpace(content)
if content == "" {
return "", fmt.Errorf("empty qr content")
}
if size <= 0 {
size = 256
}
png, err := qrcode.Encode(content, qrcode.Medium, size)
if err != nil {
return "", err
}
return "data:image/png;base64," + base64.StdEncoding.EncodeToString(png), nil
}
+96
View File
@@ -0,0 +1,96 @@
package robot
import (
"context"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/robot/ilink"
"go.uber.org/zap"
)
const (
wechatReconnectInitial = 5 * time.Second
wechatReconnectMax = 60 * time.Second
wechatPlatform = "wechat"
)
// StartWechat 启动微信 iLink 长轮询(无需公网回调),收到消息后调用 handler 并回复。
func StartWechat(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
cfg := robotsCfg.Wechat
if !cfg.Enabled || cfg.BotToken == "" {
return
}
go runWechatLoop(ctx, cfg, h, appVersion, logger)
}
func runWechatLoop(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
backoff := wechatReconnectInitial
for {
err := runWechatPoll(ctx, cfg, h, appVersion, logger)
if ctx.Err() != nil {
logger.Info("微信 iLink 长轮询已按配置关闭")
return
}
if err != nil {
logger.Warn("微信 iLink 长轮询异常,将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
if backoff < wechatReconnectMax {
backoff *= 2
if backoff > wechatReconnectMax {
backoff = wechatReconnectMax
}
}
}
}
}
func runWechatPoll(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) error {
client := ilink.NewClient(cfg.BaseURL, cfg.BotToken, cfg.BotAgent, ilink.BuildClientVersion(appVersion))
buf := cfg.GetUpdatesBuf
logger.Info("微信 iLink 长轮询已启动", zap.String("ilink_bot_id", cfg.ILinkBotID))
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
resp, err := client.GetUpdates(ctx, buf)
if err != nil {
return err
}
if resp.ErrCode != 0 && resp.Ret != 0 {
logger.Warn("微信 getUpdates 返回错误", zap.Int("errcode", resp.ErrCode), zap.String("errmsg", resp.ErrMsg))
}
if resp.GetUpdatesBuf != "" {
buf = resp.GetUpdatesBuf
}
for _, msg := range resp.Msgs {
if msg.MessageType != 1 {
continue
}
text := ilink.ExtractText(msg)
if text == "" {
continue
}
userID := strings.TrimSpace(msg.FromUserID)
if userID == "" {
continue
}
logger.Info("微信收到消息", zap.String("from", userID), zap.String("content", text))
reply := h.HandleMessage(wechatPlatform, userID, text)
if strings.TrimSpace(reply) == "" {
continue
}
if err := client.SendTextMessage(ctx, userID, msg.ContextToken, reply, ""); err != nil {
logger.Warn("微信发送回复失败", zap.String("to", userID), zap.Error(err))
}
}
}
}
+2020 -112
View File
File diff suppressed because it is too large Load Diff
+146 -3
View File
@@ -820,6 +820,7 @@
"robots": "Bots", "robots": "Bots",
"terminal": "Terminal", "terminal": "Terminal",
"security": "Security", "security": "Security",
"audit": "Audit logs",
"infocollect": "Recon" "infocollect": "Recon"
}, },
"infocollect": { "infocollect": {
@@ -836,7 +837,32 @@
}, },
"robots": { "robots": {
"title": "Bot settings", "title": "Bot settings",
"description": "Configure WeCom, DingTalk and Lark bots so you can chat with CyberStrikeAI on your phone without opening the web UI.", "description": "Configure WeChat (iLink), WeCom, DingTalk and Lark bots so you can chat with CyberStrikeAI on your phone without opening the web UI.",
"wechat": {
"title": "WeChat / iLink",
"subtitle": "Bind personal WeChat via QR code and chat with CyberStrikeAI on your phone",
"statusIdle": "Not bound",
"statusBound": "Connected",
"statusScanning": "Binding…",
"step1": "Generate QR",
"step2": "Scan in WeChat",
"step3": "Confirm",
"enabled": "Enable WeChat bot",
"bindButton": "Generate QR code and bind",
"bindHint": "Scan with WeChat to confirm; settings are saved automatically.",
"qrLoading": "Generating QR code…",
"verifyCodeLabel": "Code on your phone (only if WeChat asks)",
"rebindButton": "Re-bind",
"boundBotId": "Bound Bot ID: ",
"verifyCodeSubmit": "Submit",
"advanced": "Advanced settings",
"baseUrl": "API Base URL",
"botType": "Bot Type",
"botAgent": "Bot Agent",
"ilinkBotId": "iLink Bot ID (filled after bind)",
"boundSuccess": "Binding successful. WeChat bot is enabled.",
"openLink": "QR not showing? Open link in WeChat on your phone"
},
"wecom": { "wecom": {
"title": "WeCom", "title": "WeCom",
"enabled": "Enable WeCom bot", "enabled": "Enable WeCom bot",
@@ -1505,6 +1531,7 @@
"confirmDelete": "Delete this file?", "confirmDelete": "Delete this file?",
"editTitle": "Edit file", "editTitle": "Edit file",
"renameTitle": "Rename", "renameTitle": "Rename",
"renameCurrentFile": "Current file",
"newFileName": "New file name", "newFileName": "New file name",
"empty": "No chat uploads yet", "empty": "No chat uploads yet",
"errorLoad": "Failed to load", "errorLoad": "Failed to load",
@@ -1518,6 +1545,10 @@
"statClickAll": "View all (clear severity filter)", "statClickAll": "View all (clear severity filter)",
"statClickFilter": "Click to filter by this severity; click again to clear", "statClickFilter": "Click to filter by this severity; click again to clear",
"advancedFilters": "Advanced filters", "advancedFilters": "Advanced filters",
"moreFilters": "More filters",
"applyFilters": "Apply",
"clearAdvanced": "Clear",
"clearAll": "Reset all",
"activeFilters": "Active filters", "activeFilters": "Active filters",
"chipRemove": "Remove", "chipRemove": "Remove",
"filter": "Filter", "filter": "Filter",
@@ -1537,6 +1568,10 @@
"statusFixed": "Fixed", "statusFixed": "Fixed",
"statusFalsePositive": "False positive", "statusFalsePositive": "False positive",
"searchVulnId": "Search vuln ID", "searchVulnId": "Search vuln ID",
"searchKeyword": "Search title, description, type, target…",
"searchKeywordShort": "Keyword",
"filterExactId": "Exact vuln ID",
"filterEnterHint": "Press Enter to filter",
"filterConversation": "Filter by conversation", "filterConversation": "Filter by conversation",
"loading": "Loading...", "loading": "Loading...",
"loadListFailed": "Failed to load", "loadListFailed": "Failed to load",
@@ -1553,6 +1588,11 @@
"detailVulnId": "Vuln ID", "detailVulnId": "Vuln ID",
"detailType": "Type", "detailType": "Type",
"detailTarget": "Target", "detailTarget": "Target",
"detailProject": "Project",
"projectUnbound": "No project",
"projectBindHint": "Once bound, agents can list this finding under the project scope.",
"projectBindFailed": "Failed to update project binding",
"projectBindOk": "Project binding updated",
"detailConversationId": "Conversation ID", "detailConversationId": "Conversation ID",
"detailTaskId": "Task ID", "detailTaskId": "Task ID",
"detailTaskQueueId": "Task queue ID", "detailTaskQueueId": "Task queue ID",
@@ -1664,8 +1704,8 @@
"multiAgentPeLoop": "plan_execute outer loop limit", "multiAgentPeLoop": "plan_execute outer loop limit",
"multiAgentPeLoopPlaceholder": "0 uses Eino default (10)", "multiAgentPeLoopPlaceholder": "0 uses Eino default (10)",
"multiAgentPeLoopHint": "Only for plan_execute; max execute↔replan rounds.", "multiAgentPeLoopHint": "Only for plan_execute; max execute↔replan rounds.",
"multiAgentRobotUse": "Use multi-agent for WeCom / DingTalk / Lark bots", "multiAgentRobotMode": "Default conversation mode for bots",
"multiAgentRobotUseHint": "Requires 'Enable multi-agent' to be checked; usage and cost will be higher.", "multiAgentRobotModeHint": "Execution mode for WeCom / DingTalk / Lark bot messages. Deep / Plan-Execute / Supervisor require multi-agent to be enabled.",
"multiAgentBatchUse": "Use multi-agent for batch task queues", "multiAgentBatchUse": "Use multi-agent for batch task queues",
"multiAgentBatchUseHint": "When enabled, each sub-task executed by queue in Task Management will run through Eino DeepAgent (requires multi-agent).", "multiAgentBatchUseHint": "When enabled, each sub-task executed by queue in Task Management will run through Eino DeepAgent (requires multi-agent).",
"enableKnowledge": "Enable knowledge retrieval", "enableKnowledge": "Enable knowledge retrieval",
@@ -1757,6 +1797,106 @@
"close": "×", "close": "×",
"newTerminal": "+" "newTerminal": "+"
}, },
"settingsAudit": {
"title": "Audit logs",
"description": "Platform admin actions (login, config, deletes). Does not log chat content, per-command terminal/WebShell runs, or per-tool invocations.",
"filterCategory": "Category",
"filterAction": "Action",
"filterEvent": "Event type",
"filterAllCategories": "All categories",
"filterAllActions": "All actions",
"filterCascadeHint": "Select a category to filter by action",
"filterResult": "Result",
"pageSize": "Per page",
"statTotal": "Filtered total",
"statFailures": "Failures",
"statRecent7d": "Last 7 days",
"retentionHint": "Audit records are kept for {{days}} days, then purged automatically.",
"disabledHint": "Audit logging is disabled; new actions are not written.",
"filterSince": "From",
"filterUntil": "Until",
"filterQuery": "Keyword",
"filterQueryPlaceholder": "Message / resource ID / action",
"cat": {
"auth": "Auth",
"config": "Config",
"terminal": "Terminal",
"c2": "C2",
"webshell": "WebShell",
"knowledge": "Knowledge",
"conversation": "Conversation",
"vulnerability": "Vulnerability",
"externalMcp": "External MCP",
"task": "Tasks",
"tool": "Tools",
"file": "Files",
"hitl": "HITL",
"role": "Roles",
"skill": "Skills",
"agent": "Sub-agents"
},
"act": {
"login": "Login",
"logout": "Logout",
"login_failed": "Login failed",
"password_change": "Password change",
"change_password": "Change password",
"apply": "Apply config",
"update": "Update",
"exec": "Terminal exec",
"exec_stream": "Terminal stream",
"listener_create": "Create listener",
"listener_delete": "Delete listener",
"listener_start": "Start listener",
"listener_stop": "Stop listener",
"session_delete": "Delete session",
"task_create": "Create task",
"task_cancel": "Cancel task",
"task_delete": "Delete task",
"connection_create": "Create connection",
"connection_delete": "Delete connection",
"item_delete": "Delete knowledge item",
"index_rebuild": "Rebuild index",
"delete": "Delete",
"delete_turn": "Delete turn",
"create": "Create",
"upsert": "Upsert external MCP",
"create_queue": "Create batch queue",
"start_queue": "Start batch queue",
"delete_queue": "Delete batch queue",
"pause_queue": "Pause batch queue",
"rerun_queue": "Rerun batch queue",
"delete_batch_task": "Delete batch subtask",
"execution_delete": "Delete execution",
"execution_delete_batch": "Batch delete executions",
"upload": "Upload",
"decision": "HITL decision",
"markdown_create": "Create sub-agent",
"markdown_update": "Update sub-agent",
"markdown_delete": "Delete sub-agent"
},
"openResource": "Open linked resource",
"openResourceChat": "Open linked resource (chat)",
"resourceIdLabel": "Resource ID",
"resourceRemoved": "(resource no longer exists)",
"filterAll": "All",
"filterBtn": "Filter",
"resetBtn": "Reset",
"exportBtn": "Export",
"exportJson": "Export JSON",
"export": "Export JSON",
"exportCsv": "Export CSV",
"exportDone": "Export complete",
"loading": "Loading...",
"empty": "No audit records",
"paginationShow": "{{start}}-{{end}} of {{total}}",
"detailTitle": "Audit detail",
"detailTime": "Time",
"detailCategory": "Category",
"detailResult": "Result",
"detailMessage": "Message",
"detailSession": "Session"
},
"settingsSecurity": { "settingsSecurity": {
"changePasswordTitle": "Change password", "changePasswordTitle": "Change password",
"changePasswordDesc": "After changing password, sign in again with the new password.", "changePasswordDesc": "After changing password, sign in again with the new password.",
@@ -2026,6 +2166,9 @@
"add": "Add" "add": "Add"
}, },
"vulnerabilityModal": { "vulnerabilityModal": {
"project": "Project",
"projectNone": "(Unbound)",
"projectHint": "Bound findings appear in list_vulnerabilities for that project; leave empty to infer from the conversation when possible.",
"conversationId": "Conversation ID", "conversationId": "Conversation ID",
"conversationIdPlaceholder": "Enter conversation ID", "conversationIdPlaceholder": "Enter conversation ID",
"conversationTag": "Conversation tag", "conversationTag": "Conversation tag",
+147 -4
View File
@@ -809,6 +809,7 @@
"robots": "机器人设置", "robots": "机器人设置",
"terminal": "终端", "terminal": "终端",
"security": "安全设置", "security": "安全设置",
"audit": "日志审计",
"infocollect": "信息收集" "infocollect": "信息收集"
}, },
"infocollect": { "infocollect": {
@@ -825,7 +826,32 @@
}, },
"robots": { "robots": {
"title": "机器人设置", "title": "机器人设置",
"description": "配置企业微信、钉钉、飞书等机器人,在手机端直接与 CyberStrikeAI 对话,无需在服务器上打开网页。", "description": "配置微信、企业微信、钉钉、飞书等机器人,在手机端直接与 CyberStrikeAI 对话,无需在服务器上打开网页。",
"wechat": {
"title": "微信 / iLink",
"subtitle": "扫码绑定个人微信,在手机端直接与 CyberStrikeAI 对话",
"statusIdle": "未绑定",
"statusBound": "已连接",
"statusScanning": "绑定中…",
"step1": "生成二维码",
"step2": "微信扫码",
"step3": "确认绑定",
"enabled": "启用微信机器人",
"bindButton": "生成二维码并绑定",
"bindHint": "用微信扫码确认后会自动保存并启用。",
"qrLoading": "正在生成二维码…",
"verifyCodeLabel": "手机显示的数字(仅部分账号需要)",
"rebindButton": "重新绑定",
"boundBotId": "已绑定 Bot ID",
"verifyCodeSubmit": "提交",
"advanced": "高级设置",
"baseUrl": "API Base URL",
"botType": "Bot Type",
"botAgent": "Bot Agent",
"ilinkBotId": "iLink Bot ID(绑定后自动填充)",
"boundSuccess": "绑定成功,微信机器人已启用。",
"openLink": "无法显示二维码?点击用手机微信打开链接"
},
"wecom": { "wecom": {
"title": "企业微信", "title": "企业微信",
"enabled": "启用企业微信机器人", "enabled": "启用企业微信机器人",
@@ -1494,6 +1520,7 @@
"confirmDelete": "确定删除该文件?", "confirmDelete": "确定删除该文件?",
"editTitle": "编辑文件", "editTitle": "编辑文件",
"renameTitle": "重命名", "renameTitle": "重命名",
"renameCurrentFile": "当前文件",
"newFileName": "新文件名", "newFileName": "新文件名",
"empty": "暂无对话附件", "empty": "暂无对话附件",
"errorLoad": "加载失败", "errorLoad": "加载失败",
@@ -1507,6 +1534,10 @@
"statClickAll": "查看全部(清除严重度筛选)", "statClickAll": "查看全部(清除严重度筛选)",
"statClickFilter": "点击按此严重度筛选;再次点击清除", "statClickFilter": "点击按此严重度筛选;再次点击清除",
"advancedFilters": "高级筛选", "advancedFilters": "高级筛选",
"moreFilters": "更多筛选",
"applyFilters": "应用",
"clearAdvanced": "清空",
"clearAll": "重置全部",
"activeFilters": "已选条件", "activeFilters": "已选条件",
"chipRemove": "移除", "chipRemove": "移除",
"filter": "筛选", "filter": "筛选",
@@ -1525,7 +1556,11 @@
"statusConfirmed": "已确认", "statusConfirmed": "已确认",
"statusFixed": "已修复", "statusFixed": "已修复",
"statusFalsePositive": "误报", "statusFalsePositive": "误报",
"searchVulnId": "搜索漏洞ID", "searchVulnId": "搜索漏洞 ID",
"searchKeyword": "搜索标题、描述、类型、目标…",
"searchKeywordShort": "关键词",
"filterExactId": "精确匹配漏洞 ID",
"filterEnterHint": "回车筛选",
"filterConversation": "筛选特定会话", "filterConversation": "筛选特定会话",
"loading": "加载中...", "loading": "加载中...",
"loadListFailed": "加载失败", "loadListFailed": "加载失败",
@@ -1542,6 +1577,11 @@
"detailVulnId": "漏洞ID", "detailVulnId": "漏洞ID",
"detailType": "类型", "detailType": "类型",
"detailTarget": "目标", "detailTarget": "目标",
"detailProject": "所属项目",
"projectUnbound": "未绑定项目",
"projectBindHint": "绑定后 Agent 可在项目范围内查询到该漏洞",
"projectBindFailed": "绑定项目失败",
"projectBindOk": "已更新项目绑定",
"detailConversationId": "会话ID", "detailConversationId": "会话ID",
"detailTaskId": "任务ID", "detailTaskId": "任务ID",
"detailTaskQueueId": "任务队列ID", "detailTaskQueueId": "任务队列ID",
@@ -1653,8 +1693,8 @@
"multiAgentPeLoop": "plan_execute 外层循环上限", "multiAgentPeLoop": "plan_execute 外层循环上限",
"multiAgentPeLoopPlaceholder": "0 表示 Eino 默认 10", "multiAgentPeLoopPlaceholder": "0 表示 Eino 默认 10",
"multiAgentPeLoopHint": "仅 plan_execute 有效;execute 与 replan 之间的最大轮次。", "multiAgentPeLoopHint": "仅 plan_execute 有效;execute 与 replan 之间的最大轮次。",
"multiAgentRobotUse": "企业微信 / 钉钉 / 飞书机器人也使用多代理", "multiAgentRobotMode": "机器人默认对话模式",
"multiAgentRobotUseHint": "需同时勾选「启用多代理」;调用量与成本更高。", "multiAgentRobotModeHint": "企业微信 / 钉钉 / 飞书机器人每条消息使用的执行模式;Deep / Plan-Execute / Supervisor 需启用多代理。",
"multiAgentBatchUse": "批量任务队列也使用多代理", "multiAgentBatchUse": "批量任务队列也使用多代理",
"multiAgentBatchUseHint": "开启后,任务管理中按队列执行的每个子任务将走 Eino DeepAgent(需启用多代理)。", "multiAgentBatchUseHint": "开启后,任务管理中按队列执行的每个子任务将走 Eino DeepAgent(需启用多代理)。",
"enableKnowledge": "启用知识检索功能", "enableKnowledge": "启用知识检索功能",
@@ -1746,6 +1786,106 @@
"close": "×", "close": "×",
"newTerminal": "+" "newTerminal": "+"
}, },
"settingsAudit": {
"title": "日志审计",
"description": "记录平台管理类操作(登录、配置、删除等),不记录对话正文、终端/WebShell 每次命令与工具调用明细。",
"filterCategory": "类别",
"filterAction": "操作",
"filterEvent": "事件类型",
"filterAllCategories": "全部类别",
"filterAllActions": "全部操作",
"filterCascadeHint": "选择类别后可筛选具体操作",
"filterResult": "结果",
"pageSize": "每页",
"statTotal": "当前筛选",
"statFailures": "失败",
"statRecent7d": "近 7 天",
"retentionHint": "审计记录保留 {{days}} 天,超期自动清理。",
"disabledHint": "审计功能已关闭,新操作不会写入审计表。",
"filterSince": "开始时间",
"filterUntil": "结束时间",
"filterQuery": "关键词",
"filterQueryPlaceholder": "消息 / 资源 ID / 操作名",
"cat": {
"auth": "认证",
"config": "配置",
"terminal": "终端",
"c2": "C2",
"webshell": "WebShell",
"knowledge": "知识库",
"conversation": "对话",
"vulnerability": "漏洞",
"externalMcp": "外部 MCP",
"task": "任务",
"tool": "工具",
"file": "文件",
"hitl": "人机协同",
"role": "角色",
"skill": "Skill",
"agent": "子代理"
},
"act": {
"login": "登录",
"logout": "登出",
"login_failed": "登录失败",
"password_change": "修改密码",
"change_password": "修改密码",
"apply": "应用配置",
"update": "更新",
"exec": "终端执行",
"exec_stream": "终端流式执行",
"listener_create": "创建监听器",
"listener_delete": "删除监听器",
"listener_start": "启动监听器",
"listener_stop": "停止监听器",
"session_delete": "删除会话",
"task_create": "创建任务",
"task_cancel": "取消任务",
"task_delete": "删除任务",
"connection_create": "创建连接",
"connection_delete": "删除连接",
"item_delete": "删除知识项",
"index_rebuild": "重建索引",
"delete": "删除",
"delete_turn": "删除轮次",
"create": "创建",
"upsert": "保存外部 MCP",
"create_queue": "创建批量队列",
"start_queue": "启动批量队列",
"delete_queue": "删除批量队列",
"pause_queue": "暂停批量队列",
"rerun_queue": "重跑批量队列",
"delete_batch_task": "删除批量子任务",
"execution_delete": "删除执行记录",
"execution_delete_batch": "批量删除执行",
"upload": "上传",
"decision": "HITL 决策",
"markdown_create": "创建子代理",
"markdown_update": "更新子代理",
"markdown_delete": "删除子代理"
},
"openResource": "打开关联资源",
"openResourceChat": "打开关联资源(chat",
"resourceIdLabel": "资源 ID",
"resourceRemoved": "(关联对象已删除)",
"filterAll": "全部",
"filterBtn": "筛选",
"resetBtn": "重置",
"exportBtn": "导出",
"exportJson": "导出 JSON",
"export": "导出 JSON",
"exportCsv": "导出 CSV",
"exportDone": "导出完成",
"loading": "加载中...",
"empty": "暂无审计记录",
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}} 条",
"detailTitle": "审计详情",
"detailTime": "时间",
"detailCategory": "类别",
"detailResult": "结果",
"detailMessage": "说明",
"detailSession": "会话"
},
"settingsSecurity": { "settingsSecurity": {
"changePasswordTitle": "修改密码", "changePasswordTitle": "修改密码",
"changePasswordDesc": "修改登录密码后,需要使用新密码重新登录。", "changePasswordDesc": "修改登录密码后,需要使用新密码重新登录。",
@@ -2015,6 +2155,9 @@
"add": "添加" "add": "添加"
}, },
"vulnerabilityModal": { "vulnerabilityModal": {
"project": "所属项目",
"projectNone": "(未绑定)",
"projectHint": "绑定后 Agent 在项目范围内可通过 list_vulnerabilities 看到本条记录;留空则尝试从会话自动关联。",
"conversationId": "会话ID", "conversationId": "会话ID",
"conversationIdPlaceholder": "输入会话ID", "conversationIdPlaceholder": "输入会话ID",
"conversationTag": "对话标签", "conversationTag": "对话标签",
+598
View File
@@ -0,0 +1,598 @@
/**
* 系统设置 - 平台操作审计日志
*/
let auditLogsPage = 1;
let auditLogsPageSize = 20;
let auditLogsTotal = 0;
const AUDIT_PAGE_SIZE_KEY = 'cyberstrike_audit_page_size';
/** 按类别列出的操作(用于 datalist 提示,避免超长下拉) */
const AUDIT_ACTIONS_BY_CATEGORY = {
auth: ['login', 'logout', 'change_password'],
config: ['apply', 'update'],
c2: ['listener_create', 'listener_delete', 'listener_start', 'listener_stop',
'session_delete', 'task_create', 'task_cancel', 'task_delete'],
webshell: ['connection_create', 'connection_delete'],
knowledge: ['item_delete', 'index_rebuild'],
conversation: ['create', 'delete', 'delete_turn'],
vulnerability: ['create', 'update', 'delete'],
external_mcp: ['upsert', 'delete'],
task: ['create_queue', 'start_queue', 'delete_queue', 'pause_queue', 'rerun_queue', 'delete_batch_task'],
tool: ['execution_delete', 'execution_delete_batch'],
file: ['upload', 'delete'],
hitl: ['decision'],
role: ['create', 'update', 'delete'],
skill: ['create', 'update', 'delete'],
agent: ['markdown_create', 'markdown_update', 'markdown_delete']
};
function auditT(key, opts, fallback) {
if (typeof t === 'function') {
const v = t(key, opts);
if (v && v !== key) return v;
}
return fallback != null ? fallback : key;
}
function auditCategoryI18nKey(category) {
if (!category) return '';
if (category === 'external_mcp') return 'externalMcp';
return category;
}
function auditCategoryLabel(category) {
if (!category) return '';
const key = 'settingsAudit.cat.' + auditCategoryI18nKey(category);
return auditT(key, null, category);
}
function auditActionLabel(action) {
if (!action) return '';
return auditT('settingsAudit.act.' + action, null, action);
}
function formatAuditTime(iso) {
if (!iso) return '';
try {
const d = new Date(iso);
if (Number.isNaN(d.getTime())) return iso;
return d.toLocaleString();
} catch (_) {
return iso;
}
}
function auditDatetimeLocalToRFC3339(value) {
if (!value || !value.trim()) return '';
const d = new Date(value);
if (Number.isNaN(d.getTime())) return '';
return d.toISOString();
}
function initAuditPageSizeFromStorage() {
try {
const saved = parseInt(localStorage.getItem(AUDIT_PAGE_SIZE_KEY), 10);
if ([10, 20, 50, 100].indexOf(saved) >= 0) {
auditLogsPageSize = saved;
}
} catch (_) { /* ignore */ }
const sel = document.getElementById('audit-page-size');
if (sel) sel.value = String(auditLogsPageSize);
}
function onAuditPageSizeChange() {
const sel = document.getElementById('audit-page-size');
if (!sel) return;
const n = parseInt(sel.value, 10);
if ([10, 20, 50, 100].indexOf(n) < 0) return;
auditLogsPageSize = n;
try {
localStorage.setItem(AUDIT_PAGE_SIZE_KEY, String(n));
} catch (_) { /* ignore */ }
auditLogsPage = 1;
loadAuditLogs(1);
}
function rebuildAuditActionSelect() {
const catEl = document.getElementById('audit-filter-category');
const actEl = document.getElementById('audit-filter-action');
if (!actEl) return;
const category = catEl ? catEl.value : '';
const prev = actEl.value;
const allLabel = auditT('settingsAudit.filterAllActions', null, '全部操作');
const hint = auditT('settingsAudit.filterCascadeHint', null, '选择类别后可筛选具体操作');
actEl.innerHTML = '';
const allOpt = document.createElement('option');
allOpt.value = '';
allOpt.textContent = allLabel;
actEl.appendChild(allOpt);
if (!category) {
actEl.disabled = true;
actEl.value = '';
actEl.title = hint;
return;
}
actEl.disabled = false;
actEl.title = '';
const actions = AUDIT_ACTIONS_BY_CATEGORY[category] || [];
actions.forEach(function (action) {
const opt = document.createElement('option');
opt.value = action;
opt.textContent = auditActionLabel(action);
actEl.appendChild(opt);
});
if (prev && Array.prototype.some.call(actEl.options, function (o) { return o.value === prev; })) {
actEl.value = prev;
}
}
function onAuditCategoryFilterChange() {
rebuildAuditActionSelect();
}
function buildAuditQueryParams(forExport) {
const params = new URLSearchParams();
if (!forExport) {
params.set('page', String(auditLogsPage));
params.set('page_size', String(auditLogsPageSize));
}
const cat = document.getElementById('audit-filter-category');
const act = document.getElementById('audit-filter-action');
const res = document.getElementById('audit-filter-result');
const q = document.getElementById('audit-filter-q');
const since = document.getElementById('audit-filter-since');
const until = document.getElementById('audit-filter-until');
if (cat && cat.value) params.set('category', cat.value);
if (act && !act.disabled && act.value) params.set('action', act.value);
if (res && res.value) params.set('result', res.value);
if (q && q.value.trim()) params.set('q', q.value.trim());
const sinceISO = since ? auditDatetimeLocalToRFC3339(since.value) : '';
const untilISO = until ? auditDatetimeLocalToRFC3339(until.value) : '';
if (sinceISO) params.set('since', sinceISO);
if (untilISO) params.set('until', untilISO);
return params.toString();
}
async function loadAuditMeta() {
if (typeof apiFetch !== 'function') return;
const hint = document.getElementById('audit-retention-hint');
try {
const r = await apiFetch('/api/audit/meta');
if (!r.ok) return;
const data = await r.json();
if (!hint) return;
if (!data.enabled) {
hint.hidden = false;
hint.textContent = auditT('settingsAudit.disabledHint', null, '审计功能已关闭,新操作不会写入审计表。');
return;
}
const days = data.retention_days;
if (days > 0) {
hint.hidden = false;
hint.textContent = auditT('settingsAudit.retentionHint', { days: days },
'审计记录保留 ' + days + ' 天,超期自动清理。');
} else {
hint.hidden = true;
}
} catch (_) { /* ignore */ }
}
async function loadAuditSummary() {
if (typeof apiFetch !== 'function') return;
const wrap = document.getElementById('audit-summary-stats');
try {
const r = await apiFetch('/api/audit/summary?' + buildAuditQueryParams(true));
if (!r.ok) return;
const data = await r.json();
if (wrap) wrap.hidden = false;
const elTotal = document.getElementById('audit-stat-total');
const elFail = document.getElementById('audit-stat-failures');
const elRecent = document.getElementById('audit-stat-recent');
if (elTotal) elTotal.textContent = String(data.total != null ? data.total : 0);
if (elFail) elFail.textContent = String(data.failures != null ? data.failures : 0);
if (elRecent) elRecent.textContent = String(data.recent_7d != null ? data.recent_7d : 0);
} catch (_) { /* ignore */ }
}
async function loadAuditLogs(page) {
if (typeof apiFetch !== 'function') return;
auditLogsPage = page != null ? page : auditLogsPage;
const listEl = document.getElementById('audit-log-list');
if (listEl) {
listEl.innerHTML = '<div class="loading-spinner">' + (typeof escapeHtml === 'function' ? escapeHtml(auditT('settingsAudit.loading', null, '加载中...')) : '加载中...') + '</div>';
}
try {
const qs = buildAuditQueryParams(false);
const r = await apiFetch('/api/audit/logs?' + qs);
if (!r.ok) {
const err = await r.json().catch(function () { return {}; });
throw new Error(err.error || r.statusText);
}
const data = await r.json();
renderAuditLogs(data.logs || []);
auditLogsTotal = typeof data.total === 'number' ? data.total : 0;
const maxPage = Math.max(1, Math.ceil(auditLogsTotal / auditLogsPageSize));
if (auditLogsPage > maxPage) {
loadAuditLogs(maxPage);
return;
}
renderAuditLogsPagination();
loadAuditSummary();
} catch (e) {
if (listEl) {
const msg = typeof escapeHtml === 'function' ? escapeHtml(e.message || String(e)) : (e.message || String(e));
listEl.innerHTML = '<div class="monitor-empty">' + msg + '</div>';
}
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
function renderAuditLogs(logs) {
const listEl = document.getElementById('audit-log-list');
if (!listEl) return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
if (!logs.length) {
listEl.innerHTML = '<div class="c2-empty">' + esc(auditT('settingsAudit.empty', null, '暂无审计记录')) + '</div>';
return;
}
listEl.innerHTML = logs.map(function (log) {
const lvl = log.result === 'failure' ? 'warn' : (log.level || 'info');
const catLabel = esc(auditCategoryLabel(log.category || ''));
const actionLabel = esc(auditActionLabel(log.action || ''));
const msg = esc(log.message || '');
const ip = esc(log.clientIp || '');
const when = esc(formatAuditTime(log.createdAt));
const res = esc(log.result || '');
const rid = log.resourceId || '';
const meta = rid ? (' · ' + esc(rid)) : '';
const eid = esc(log.id || '');
return (
'<div class="c2-event-item audit-log-item" role="button" tabindex="0" ' +
'onclick="showAuditLogDetail(\'' + eid + '\')" ' +
'onkeydown="if(event.key===\'Enter\'||event.key===\' \'){event.preventDefault();showAuditLogDetail(\'' + eid + '\')}">' +
'<div class="c2-event-level ' + esc(lvl) + '"></div>' +
'<div class="c2-event-content">' +
'<div class="c2-event-message">' + msg + '</div>' +
'<div class="c2-event-meta">' + when + ' · ' + catLabel + '/' + actionLabel + ' · ' + res + meta +
(ip ? ' · IP ' + ip : '') +
'</div></div></div>'
);
}).join('');
if (typeof applyTranslations === 'function') {
applyTranslations(listEl);
}
}
function renderAuditLogsPagination() {
const container = document.getElementById('audit-logs-pagination');
if (!container) return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
const total = auditLogsTotal || 0;
const currentPage = auditLogsPage || 1;
const pageSize = auditLogsPageSize || 20;
const totalPages = Math.max(1, Math.ceil(total / pageSize));
const start = total === 0 ? 0 : (currentPage - 1) * pageSize + 1;
const end = total === 0 ? 0 : Math.min(currentPage * pageSize, total);
const infoText = auditT('mcpMonitor.paginationInfo', { start: start, end: end, total: total },
'显示 ' + start + '-' + end + ' / 共 ' + total + ' 条记录');
const perPageLabel = auditT('mcpMonitor.perPageLabel', null, '每页显示');
const firstPageLabel = auditT('mcp.firstPage', null, '首页');
const prevPageLabel = auditT('mcp.prevPage', null, '上一页');
const pageInfoText = auditT('mcp.pageInfo', { page: currentPage, total: totalPages },
'第 ' + currentPage + ' / ' + totalPages + ' 页');
const nextPageLabel = auditT('mcp.nextPage', null, '下一页');
const lastPageLabel = auditT('mcp.lastPage', null, '末页');
const disabledFirst = currentPage === 1 || total === 0;
const disabledLast = currentPage >= totalPages || total === 0;
let html = '<div class="monitor-pagination">';
html += '<div class="pagination-info">';
html += '<span>' + esc(infoText) + '</span>';
html += '<label class="pagination-page-size">' + esc(perPageLabel);
html += '<select id="audit-page-size" onchange="onAuditPageSizeChange()">';
[10, 20, 50, 100].forEach(function (n) {
html += '<option value="' + n + '"' + (pageSize === n ? ' selected' : '') + '>' + n + '</option>';
});
html += '</select></label></div>';
html += '<div class="pagination-controls">';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(1)"' + (disabledFirst ? ' disabled' : '') + '>' + esc(firstPageLabel) + '</button>';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(' + (currentPage - 1) + ')"' + (disabledFirst ? ' disabled' : '') + '>' + esc(prevPageLabel) + '</button>';
html += '<span class="pagination-page">' + esc(pageInfoText) + '</span>';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(' + (currentPage + 1) + ')"' + (disabledLast ? ' disabled' : '') + '>' + esc(nextPageLabel) + '</button>';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(' + totalPages + ')"' + (disabledLast ? ' disabled' : '') + '>' + esc(lastPageLabel) + '</button>';
html += '</div></div>';
container.innerHTML = html;
}
function goAuditLogsPage(p) {
const totalPages = Math.max(1, Math.ceil((auditLogsTotal || 0) / (auditLogsPageSize || 20)));
if (p < 1 || p > totalPages) return;
loadAuditLogs(p);
}
function filterAuditLogs() {
auditLogsPage = 1;
loadAuditLogs(1);
}
function resetAuditLogFilters() {
const cat = document.getElementById('audit-filter-category');
const act = document.getElementById('audit-filter-action');
const res = document.getElementById('audit-filter-result');
const q = document.getElementById('audit-filter-q');
const since = document.getElementById('audit-filter-since');
const until = document.getElementById('audit-filter-until');
if (cat) cat.value = '';
if (res) res.value = '';
if (q) q.value = '';
if (since) since.value = '';
if (until) until.value = '';
rebuildAuditActionSelect();
filterAuditLogs();
}
/** 资源已被删除/移除的审计操作,不再提供「打开关联资源」 */
const AUDIT_ACTIONS_RESOURCE_REMOVED = {
delete: true,
item_delete: true,
connection_delete: true,
listener_delete: true,
session_delete: true,
task_delete: true,
execution_delete: true,
execution_delete_batch: true,
delete_queue: true,
delete_batch_task: true,
markdown_delete: true
};
function auditResourceWasRemoved(log) {
if (!log || !log.action) return false;
return !!AUDIT_ACTIONS_RESOURCE_REMOVED[log.action];
}
/** 删除类操作,或关联资源已不存在(由详情 API resourceAvailable 判定) */
function auditResourceUnavailable(log) {
if (!log) return false;
if (auditResourceWasRemoved(log)) return true;
return log.resourceAvailable === false;
}
function auditResourceMeta(log) {
if (!log || !log.resourceId) return '';
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
const id = esc(log.resourceId);
if (auditResourceUnavailable(log)) {
const idLabel = esc(auditT('settingsAudit.resourceIdLabel', null, '资源 ID'));
const removed = esc(auditT('settingsAudit.resourceRemoved', null, '(关联对象已删除)'));
return '<p class="audit-resource-meta"><strong>' + idLabel + ':</strong> <code>' + id +
'</code> <span class="audit-resource-removed">' + removed + '</span></p>';
}
const link = auditResourceLink(log);
return link || ('<p><strong>ID:</strong> ' + id + '</p>');
}
async function auditOpenConversationChat(conversationId) {
const id = String(conversationId || '').trim();
if (!id) return;
if (typeof apiFetch === 'function') {
try {
const r = await apiFetch('/api/conversations/' + encodeURIComponent(id));
if (!r.ok) {
if (typeof showToast === 'function') {
showToast(auditT('settingsAudit.resourceRemoved', null, '(关联对象已删除)'), 'warning');
}
return;
}
} catch (_) {
return;
}
}
closeAuditDetailModal();
if (typeof switchPage === 'function') {
switchPage('chat');
}
if (typeof loadConversation === 'function') {
void loadConversation(id);
}
}
window.auditOpenConversationChat = auditOpenConversationChat;
function auditResourceLink(log) {
if (!log || auditResourceUnavailable(log)) return '';
const type = log.resourceType || '';
const id = log.resourceId || '';
if (!id) return '';
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
const label = esc(auditT('settingsAudit.openResource', null, '打开关联资源'));
if (type === 'conversation' || (type === '' && id.length > 8 && !id.startsWith('c2_'))) {
const chatLabel = esc(auditT('settingsAudit.openResourceChat', null, '打开关联资源(chat'));
return '<p><button type="button" class="btn-secondary btn-small audit-open-chat-btn" data-conversation-id="' +
esc(id) + '">' + chatLabel + '</button></p>';
}
if (type === 'vulnerability' || type === 'batch_queue') {
const page = type === 'batch_queue' ? 'tasks' : 'vulnerabilities';
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'' + page + '\');}">' + label + '</button></p>';
}
if (type === 'c2_listener' || type === 'c2_session' || type === 'c2_task') {
const page = type === 'c2_listener' ? 'c2-listeners' : (type === 'c2_session' ? 'c2-sessions' : 'c2-tasks');
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'' + page + '\');}">' + label + '</button></p>';
}
if (type === 'webshell_connection') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'webshell\');}">' + label + '</button></p>';
}
if (type === 'knowledge_item') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'knowledge-management\');}">' + label + '</button></p>';
}
if (type === 'chat_upload') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'chat-files\');}">' + label + '</button></p>';
}
if (type === 'tool_execution') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'mcp-monitor\');}">' + label + '</button></p>';
}
if (type === 'role' || type === 'skill' || type === 'markdown_agent') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchSettingsSection===\'function\'){switchPage(\'settings\');switchSettingsSection(\'roles\');}">' + label + '</button></p>';
}
return '';
}
function refreshAuditLogs() {
loadAuditLogs(auditLogsPage);
}
async function downloadAuditExport(url, filename) {
const r = await apiFetch(url);
if (!r.ok) {
const err = await r.json().catch(function () { return {}; });
throw new Error(err.error || r.statusText);
}
const blob = await r.blob();
const objectUrl = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = objectUrl;
a.download = filename;
a.click();
URL.revokeObjectURL(objectUrl);
}
function closeAuditExportMenu() {
const menu = document.getElementById('audit-export-menu');
const trigger = document.getElementById('audit-export-trigger');
if (menu) menu.hidden = true;
if (trigger) trigger.setAttribute('aria-expanded', 'false');
}
function toggleAuditExportMenu(ev) {
if (ev && ev.stopPropagation) ev.stopPropagation();
const menu = document.getElementById('audit-export-menu');
const trigger = document.getElementById('audit-export-trigger');
if (!menu) return;
const willOpen = menu.hidden;
if (willOpen) {
menu.hidden = false;
if (trigger) trigger.setAttribute('aria-expanded', 'true');
if (!window._auditExportMenuDocBound) {
window._auditExportMenuDocBound = true;
document.addEventListener('click', function () {
closeAuditExportMenu();
});
}
} else {
closeAuditExportMenu();
}
}
async function runAuditExport(format) {
closeAuditExportMenu();
if (format === 'csv') {
await exportAuditLogsCsv();
} else {
await exportAuditLogs();
}
}
async function exportAuditLogs() {
if (typeof apiFetch !== 'function') return;
try {
await downloadAuditExport(
'/api/audit/logs/export?' + buildAuditQueryParams(true),
'audit-logs-' + new Date().toISOString().slice(0, 10) + '.json'
);
if (typeof showToast === 'function') {
showToast(auditT('settingsAudit.exportDone', null, '导出完成'), 'success');
}
} catch (e) {
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
async function exportAuditLogsCsv() {
if (typeof apiFetch !== 'function') return;
try {
const qs = buildAuditQueryParams(true);
await downloadAuditExport(
'/api/audit/logs/export?' + (qs ? qs + '&' : '') + 'format=csv',
'audit-logs-' + new Date().toISOString().slice(0, 10) + '.csv'
);
if (typeof showToast === 'function') {
showToast(auditT('settingsAudit.exportDone', null, '导出完成'), 'success');
}
} catch (e) {
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
function closeAuditDetailModal() {
const el = document.getElementById('audit-detail-modal');
if (el) el.remove();
}
async function showAuditLogDetail(id) {
if (!id || typeof apiFetch !== 'function') return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
try {
const r = await apiFetch('/api/audit/logs/' + encodeURIComponent(id));
if (!r.ok) throw new Error('not found');
const data = await r.json();
const log = data.log || {};
const detail = log.detail ? JSON.stringify(log.detail, null, 2) : '';
closeAuditDetailModal();
const overlay = document.createElement('div');
overlay.id = 'audit-detail-modal';
overlay.className = 'modal';
overlay.style.display = 'block';
const catAction = esc(auditCategoryLabel(log.category || '')) + ' / ' + esc(auditActionLabel(log.action || ''));
overlay.innerHTML =
'<div class="modal-content" style="max-width: 720px;">' +
'<div class="modal-header">' +
'<h2>' + esc(auditT('settingsAudit.detailTitle', null, '审计详情')) + '</h2>' +
'<span class="modal-close" onclick="closeAuditDetailModal()">&times;</span>' +
'</div>' +
'<div class="modal-body audit-detail-body">' +
'<p><strong>' + esc(auditT('settingsAudit.detailTime', null, '时间')) + ':</strong> ' + esc(formatAuditTime(log.createdAt)) + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailCategory', null, '类别')) + ':</strong> ' + catAction + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailResult', null, '结果')) + ':</strong> ' + esc(log.result || '') + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailMessage', null, '说明')) + ':</strong> ' + esc(log.message || '') + '</p>' +
(log.clientIp ? '<p><strong>IP:</strong> ' + esc(log.clientIp) + '</p>' : '') +
(log.sessionHint ? '<p><strong>' + esc(auditT('settingsAudit.detailSession', null, '会话')) + ':</strong> ' + esc(log.sessionHint) + '</p>' : '') +
(log.userAgent ? '<p><strong>UA:</strong> ' + esc(log.userAgent) + '</p>' : '') +
auditResourceMeta(log) +
(detail ? '<pre class="audit-detail-pre">' + esc(detail) + '</pre>' : '') +
'</div>' +
'<div class="modal-footer"><button type="button" class="btn-secondary" onclick="closeAuditDetailModal()">' +
esc(auditT('common.close', null, '关闭')) + '</button></div>' +
'</div>';
document.body.appendChild(overlay);
const chatBtn = overlay.querySelector('.audit-open-chat-btn');
if (chatBtn) {
chatBtn.addEventListener('click', function () {
auditOpenConversationChat(chatBtn.getAttribute('data-conversation-id'));
});
}
overlay.addEventListener('click', function (ev) {
if (ev.target === overlay) closeAuditDetailModal();
});
} catch (e) {
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
function initAuditLogsSection() {
if (!document.getElementById('audit-log-list')) return;
initAuditPageSizeFromStorage();
rebuildAuditActionSelect();
loadAuditMeta();
loadAuditLogs(1);
}
+7
View File
@@ -282,6 +282,13 @@ async function submitLogin(event) {
} }
async function refreshAppData(showTaskErrors = false) { async function refreshAppData(showTaskErrors = false) {
if (typeof initChatAgentModeFromConfig === 'function') {
try {
await initChatAgentModeFromConfig();
} catch (error) {
console.warn('刷新对话模式配置失败:', error);
}
}
await Promise.allSettled([ await Promise.allSettled([
loadConversations(), loadConversations(),
loadActiveTasks(showTaskErrors), loadActiveTasks(showTaskErrors),
+15 -4
View File
@@ -1052,15 +1052,26 @@ async function saveChatFilesEdit() {
function openChatFilesRename(relativePath, currentName) { function openChatFilesRename(relativePath, currentName) {
chatFilesRenameRelativePath = relativePath; chatFilesRenameRelativePath = relativePath;
const input = document.getElementById('chat-files-rename-input'); const input = document.getElementById('chat-files-rename-input');
const hint = document.getElementById('chat-files-rename-path-hint');
const modal = document.getElementById('chat-files-rename-modal'); const modal = document.getElementById('chat-files-rename-modal');
if (input) input.value = currentName || ''; const pathText = relativePath ? ('chat_uploads/' + String(relativePath).replace(/\\/g, '/')) : 'chat_uploads';
if (modal) modal.style.display = 'block'; if (hint) hint.textContent = pathText;
setTimeout(() => { if (input) input.focus(); }, 100); if (input) {
input.value = currentName || '';
input.select();
}
if (modal) modal.style.display = 'flex';
if (modal && typeof window.applyTranslations === 'function') {
window.applyTranslations(modal);
}
setTimeout(() => { if (input) { input.focus(); input.select(); } }, 100);
} }
function closeChatFilesRenameModal() { function closeChatFilesRenameModal() {
const modal = document.getElementById('chat-files-rename-modal'); const modal = document.getElementById('chat-files-rename-modal');
if (modal) modal.style.display = 'none'; if (modal) modal.style.display = 'none';
const hint = document.getElementById('chat-files-rename-path-hint');
if (hint) hint.textContent = '';
chatFilesRenameRelativePath = ''; chatFilesRenameRelativePath = '';
} }
@@ -1095,7 +1106,7 @@ function openChatFilesMkdirModal() {
const p = chatFilesBrowsePath.join('/'); const p = chatFilesBrowsePath.join('/');
if (hint) hint.textContent = p ? ('chat_uploads/' + p) : 'chat_uploads'; if (hint) hint.textContent = p ? ('chat_uploads/' + p) : 'chat_uploads';
if (input) input.value = ''; if (input) input.value = '';
if (modal) modal.style.display = 'block'; if (modal) modal.style.display = 'flex';
if (modal && typeof window.applyTranslations === 'function') { if (modal && typeof window.applyTranslations === 'function') {
window.applyTranslations(modal); window.applyTranslations(modal);
} }
+78 -9
View File
@@ -574,7 +574,7 @@ function restoreChatReasoningControlsFromStorage() {
} }
if (e) { if (e) {
const v = localStorage.getItem(REASONING_EFFORT_LS); const v = localStorage.getItem(REASONING_EFFORT_LS);
if (v !== null && ['', 'low', 'medium', 'high', 'max'].indexOf(v) !== -1) { if (v !== null && ['', 'low', 'medium', 'high', 'max', 'xhigh'].indexOf(v) !== -1) {
e.value = v; e.value = v;
} }
} }
@@ -646,6 +646,9 @@ function toggleAgentModePanel() {
if (typeof closeRoleSelectionPanel === 'function') { if (typeof closeRoleSelectionPanel === 'function') {
closeRoleSelectionPanel(); closeRoleSelectionPanel();
} }
if (typeof closeChatProjectPanel === 'function') {
closeChatProjectPanel();
}
panel.style.display = 'flex'; panel.style.display = 'flex';
btn.classList.add('active'); btn.classList.add('active');
btn.setAttribute('aria-expanded', 'true'); btn.setAttribute('aria-expanded', 'true');
@@ -662,6 +665,29 @@ function selectAgentMode(mode) {
} }
async function initChatAgentModeFromConfig() { async function initChatAgentModeFromConfig() {
const wrap = document.getElementById('agent-mode-wrapper');
const sel = document.getElementById('agent-mode-select');
if (!wrap || !sel) return;
// 先展示基础模式,避免首次登录时配置接口短暂失败导致入口被隐藏。
wrap.style.display = '';
let stored = localStorage.getItem(AGENT_MODE_STORAGE_KEY);
if (!(stored === CHAT_AGENT_MODE_REACT || chatAgentModeIsEinoSingle(stored) || chatAgentModeIsEino(stored))) {
stored = CHAT_AGENT_MODE_REACT;
}
sel.value = stored;
syncAgentModeFromValue(stored);
document.querySelectorAll('.agent-mode-option').forEach(function (el) {
const v = el.getAttribute('data-value');
if (v === 'deep' || v === 'plan_execute' || v === 'supervisor') {
el.style.display = 'none';
} else {
el.style.display = '';
}
});
restoreChatReasoningControlsFromStorage();
syncReasoningRowVisibility(stored);
try { try {
const r = await apiFetch('/api/config'); const r = await apiFetch('/api/config');
if (!r.ok) return; if (!r.ok) return;
@@ -674,10 +700,6 @@ async function initChatAgentModeFromConfig() {
window.csaiHitlGlobalToolWhitelist = tw.slice(); window.csaiHitlGlobalToolWhitelist = tw.slice();
} }
} }
const wrap = document.getElementById('agent-mode-wrapper');
const sel = document.getElementById('agent-mode-select');
if (!wrap || !sel) return;
wrap.style.display = '';
document.querySelectorAll('.agent-mode-option').forEach(function (el) { document.querySelectorAll('.agent-mode-option').forEach(function (el) {
const v = el.getAttribute('data-value'); const v = el.getAttribute('data-value');
if (v === 'deep' || v === 'plan_execute' || v === 'supervisor') { if (v === 'deep' || v === 'plan_execute' || v === 'supervisor') {
@@ -686,7 +708,6 @@ async function initChatAgentModeFromConfig() {
el.style.display = ''; el.style.display = '';
} }
}); });
let stored = localStorage.getItem(AGENT_MODE_STORAGE_KEY);
stored = chatAgentModeNormalizeStored(stored, cfg); stored = chatAgentModeNormalizeStored(stored, cfg);
try { try {
localStorage.setItem(AGENT_MODE_STORAGE_KEY, stored); localStorage.setItem(AGENT_MODE_STORAGE_KEY, stored);
@@ -879,6 +900,10 @@ async function sendMessage() {
conversationId: currentConversationId, conversationId: currentConversationId,
role: typeof getCurrentRole === 'function' ? getCurrentRole() : '' role: typeof getCurrentRole === 'function' ? getCurrentRole() : ''
}; };
if (!currentConversationId && typeof getActiveProjectId === 'function') {
const pid = getActiveProjectId();
if (pid) body.projectId = pid;
}
const hitlCfg = readHitlConfigFromForm(); const hitlCfg = readHitlConfigFromForm();
if (normalizeHitlMode(hitlCfg.mode) !== HITL_MODE_OFF) { if (normalizeHitlMode(hitlCfg.mode) !== HITL_MODE_OFF) {
const sensitiveTools = hitlToolsSplitToArray(hitlCfg.sensitiveTools || ''); const sensitiveTools = hitlToolsSplitToArray(hitlCfg.sensitiveTools || '');
@@ -2357,6 +2382,9 @@ function renderProcessDetails(messageId, processDetails) {
detailsContainer.dataset.lazyNotLoaded = '0'; detailsContainer.dataset.lazyNotLoaded = '0';
detailsContainer.dataset.loaded = '1'; detailsContainer.dataset.loaded = '1';
processDetails = dedupeConsecutiveProcessDetailRows(processDetails); processDetails = dedupeConsecutiveProcessDetailRows(processDetails);
if (typeof window.coalesceProcessDetailsToolPairs === 'function') {
processDetails = window.coalesceProcessDetailsToolPairs(processDetails);
}
// 如果没有processDetails或为空,显示空状态 // 如果没有processDetails或为空,显示空状态
if (!processDetails || processDetails.length === 0) { if (!processDetails || processDetails.length === 0) {
// 显示空状态提示 // 显示空状态提示
@@ -2421,7 +2449,10 @@ function renderProcessDetails(messageId, processDetails) {
const toolName = data.toolName || (typeof window.t === 'function' ? window.t('chat.unknownTool') : '未知工具'); const toolName = data.toolName || (typeof window.t === 'function' ? window.t('chat.unknownTool') : '未知工具');
const index = data.index || 0; const index = data.index || 0;
const total = data.total || 0; const total = data.total || 0;
itemTitle = agPx + '🔧 ' + (typeof window.t === 'function' ? window.t('chat.callTool', { name: escapeHtml(toolName), index: index, total: total }) : '调用工具: ' + escapeHtml(toolName) + ' (' + index + '/' + total + ')'); const callTitle = typeof window.formatToolCallTimelineTitle === 'function'
? window.formatToolCallTimelineTitle(toolName, index, total)
: (typeof window.t === 'function' ? window.t('chat.callTool', { name: escapeHtml(toolName), index: index, total: total }) : '调用工具: ' + escapeHtml(toolName) + ' (' + index + '/' + total + ')');
itemTitle = agPx + '🔧 ' + callTitle;
} else if (eventType === 'tool_result') { } else if (eventType === 'tool_result') {
const toolName = data.toolName || (typeof window.t === 'function' ? window.t('chat.unknownTool') : '未知工具'); const toolName = data.toolName || (typeof window.t === 'function' ? window.t('chat.unknownTool') : '未知工具');
const success = data.success !== false; const success = data.success !== false;
@@ -2451,12 +2482,16 @@ function renderProcessDetails(messageId, processDetails) {
: '⏸️ 用户中断并继续'; : '⏸️ 用户中断并继续';
} }
addTimelineItem(timeline, eventType, { const timelineOpts = {
title: itemTitle, title: itemTitle,
message: detail.message || '', message: detail.message || '',
data: data, data: data,
createdAt: detail.createdAt // 传递实际的事件创建时间 createdAt: detail.createdAt // 传递实际的事件创建时间
}); };
if (eventType === 'tool_call' && data._mergedResult) {
timelineOpts.mergedResult = data._mergedResult;
}
addTimelineItem(timeline, eventType, timelineOpts);
}); });
// 检查是否有错误或取消事件,如果有,确保详情默认折叠(但仍有待审批 HITL 时保持展开,由 restoreHitlInlineForConversation 处理) // 检查是否有错误或取消事件,如果有,确保详情默认折叠(但仍有待审批 HITL 时保持展开,由 restoreHitlInlineForConversation 处理)
@@ -2872,10 +2907,14 @@ async function startNewConversation() {
} }
currentConversationId = null; currentConversationId = null;
window._loadedConversationProjectId = '';
try { try {
window.currentConversationId = ''; window.currentConversationId = '';
} catch (e) { /* ignore */ } } catch (e) { /* ignore */ }
currentConversationGroupId = null; // 新对话不属于任何分组 currentConversationGroupId = null; // 新对话不属于任何分组
if (typeof refreshChatProjectSelector === 'function') {
refreshChatProjectSelector();
}
document.getElementById('chat-messages').innerHTML = ''; document.getElementById('chat-messages').innerHTML = '';
const readyMsgNew = typeof window.t === 'function' ? window.t('chat.systemReadyMessage') : '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。'; const readyMsgNew = typeof window.t === 'function' ? window.t('chat.systemReadyMessage') : '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。';
addMessage('assistant', readyMsgNew, null, null, null, { systemReadyMessage: true }); addMessage('assistant', readyMsgNew, null, null, null, { systemReadyMessage: true });
@@ -3054,6 +3093,27 @@ function getConversationGroup(dateObj, todayStart, sevenDaysCutoff, yesterdaySta
} }
// 加载对话 // 加载对话
/** 轻量加载会话后,拉取最后一条助手消息的 process_details(机器人等无 SSE 场景) */
async function prefetchLastAssistantProcessDetails() {
const nodes = document.querySelectorAll('#chat-messages .message.assistant');
if (!nodes.length) return;
const last = nodes[nodes.length - 1];
if (!last || !last.id) return;
const container = document.getElementById('process-details-' + last.id);
if (!container || container.dataset.lazyNotLoaded !== '1') return;
const backendId = last.dataset && last.dataset.backendMessageId;
if (!backendId || typeof apiFetch !== 'function') return;
const res = await apiFetch('/api/messages/' + encodeURIComponent(String(backendId)) + '/process-details');
const j = await res.json().catch(() => ({}));
if (!res.ok || !Array.isArray(j.processDetails) || j.processDetails.length === 0) return;
if (typeof renderProcessDetails === 'function') {
renderProcessDetails(last.id, j.processDetails);
}
if (typeof window.expandProcessDetailsTimeline === 'function') {
window.expandProcessDetailsTimeline(last.id);
}
}
async function loadConversation(conversationId) { async function loadConversation(conversationId) {
const seq = ++loadConversationRequestSeq; const seq = ++loadConversationRequestSeq;
try { try {
@@ -3109,9 +3169,13 @@ async function loadConversation(conversationId) {
// 更新当前对话ID // 更新当前对话ID
currentConversationId = conversationId; currentConversationId = conversationId;
window._loadedConversationProjectId = conversation.projectId || conversation.project_id || '';
try { try {
window.currentConversationId = conversationId; window.currentConversationId = conversationId;
} catch (e) { /* ignore */ } } catch (e) { /* ignore */ }
if (typeof refreshChatProjectSelector === 'function') {
refreshChatProjectSelector();
}
if (typeof window.syncHitlConfigFromServer === 'function') { if (typeof window.syncHitlConfigFromServer === 'function') {
await window.syncHitlConfigFromServer(conversationId); await window.syncHitlConfigFromServer(conversationId);
} else { } else {
@@ -3273,6 +3337,11 @@ async function loadConversation(conversationId) {
.catch((e) => { .catch((e) => {
console.warn('attachRunningTaskEventStream on loadConversation failed', e); console.warn('attachRunningTaskEventStream on loadConversation failed', e);
}); });
} else if (seq === loadConversationRequestSeq && currentConversationId === conversationId) {
// 机器人等非 Web 流式来源:会话已结束或未注册任务时,按需拉取最后一条助手消息的过程详情
prefetchLastAssistantProcessDetails().catch((e) => {
console.warn('prefetchLastAssistantProcessDetails failed', e);
});
} }
} catch (error) { } catch (error) {
console.error('加载对话失败:', error); console.error('加载对话失败:', error);
+6 -1
View File
@@ -1028,7 +1028,12 @@ async function batchScanSelectedFofaRows() {
const resp = await apiFetch('/api/batch-tasks', { const resp = await apiFetch('/api/batch-tasks', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ title, tasks, role }) body: JSON.stringify({
title,
tasks,
role,
projectId: typeof getActiveProjectId === 'function' ? getActiveProjectId() || '' : '',
})
}); });
const result = await resp.json().catch(() => ({})); const result = await resp.json().catch(() => ({}));
if (!resp.ok) { if (!resp.ok) {
+292 -71
View File
@@ -410,6 +410,34 @@ if (typeof window !== 'undefined') {
window.normalizeStreamingDeltaJs = normalizeStreamingDeltaJs; window.normalizeStreamingDeltaJs = normalizeStreamingDeltaJs;
} }
/**
* SSE data.accumulated服务端权威流式全文有则直接用作 buffer避免双端 normalize 叠字
* @param {object|null|undefined} data
* @returns {string|null} 有快照时返回全文否则 null回退 delta 归一化
*/
function streamBufferFromAccumulated(data) {
if (!data || data.accumulated == null) {
return null;
}
return String(data.accumulated);
}
/**
* @returns {string} 合并后的 buffer
*/
function mergeStreamBuffer(current, delta, data) {
const acc = streamBufferFromAccumulated(data);
if (acc !== null) {
return acc;
}
return normalizeStreamingDeltaJs(current, delta)[0];
}
if (typeof window !== 'undefined') {
window.streamBufferFromAccumulated = streamBufferFromAccumulated;
window.mergeStreamBuffer = mergeStreamBuffer;
}
/** 流式 delta:纯文本,避免每条全量 marked + DOMPurify */ /** 流式 delta:纯文本,避免每条全量 marked + DOMPurify */
function setTimelineItemContentStreamPlain(contentEl, text) { function setTimelineItemContentStreamPlain(contentEl, text) {
if (!contentEl) return; if (!contentEl) return;
@@ -1411,8 +1439,7 @@ function handleStreamEvent(event, progressElement, progressId,
const s = state.get(streamId); const s = state.get(streamId);
const delta = event.message || ''; const delta = event.message || '';
const merged = normalizeStreamingDeltaJs(s.buffer, delta); s.buffer = mergeStreamBuffer(s.buffer, delta, d);
s.buffer = merged[0];
const item = document.getElementById(s.itemId); const item = document.getElementById(s.itemId);
if (item) { if (item) {
@@ -1565,7 +1592,7 @@ function handleStreamEvent(event, progressElement, progressId,
const index = toolInfo.index || 0; const index = toolInfo.index || 0;
const total = toolInfo.total || 0; const total = toolInfo.total || 0;
const toolCallId = toolInfo.toolCallId || null; const toolCallId = toolInfo.toolCallId || null;
const toolCallTitle = typeof window.t === 'function' ? window.t('chat.callTool', { name: escapeHtml(toolName), index: index, total: total }) : '调用工具: ' + escapeHtml(toolName) + ' (' + index + '/' + total + ')'; const toolCallTitle = formatToolCallTimelineTitle(toolName, index, total);
const toolCallItemId = addTimelineItem(timeline, 'tool_call', { const toolCallItemId = addTimelineItem(timeline, 'tool_call', {
title: timelineAgentBracketPrefix(toolInfo) + '🔧 ' + toolCallTitle, title: timelineAgentBracketPrefix(toolInfo) + '🔧 ' + toolCallTitle,
message: event.message, message: event.message,
@@ -1593,44 +1620,33 @@ function handleStreamEvent(event, progressElement, progressId,
const key = toolResultStreamKey(progressId, toolCallId); const key = toolResultStreamKey(progressId, toolCallId);
let state = toolResultStreamStateByKey.get(key); let state = toolResultStreamStateByKey.get(key);
const toolNameDelta = deltaInfo.toolName || (typeof window.t === 'function' ? window.t('chat.unknownTool') : '未知工具');
const deltaText = event.message || ''; const deltaText = event.message || '';
if (!deltaText) break; if (!deltaText) break;
if (!state) { if (!state) {
// 首次增量:创建一个 tool_result 占位条目,后续不断更新 pre 内容 const mapping = toolCallStatusMap.get(toolCallId);
const runningLabel = typeof window.t === 'function' ? window.t('timeline.running') : '执行中...'; let callItemId = mapping && mapping.itemId ? mapping.itemId : null;
const title = timelineAgentBracketPrefix(deltaInfo) + '⏳ ' + (typeof window.t === 'function' if (callItemId) {
? window.t('timeline.running') const callItem = document.getElementById(callItemId);
: runningLabel) + ' ' + (typeof window.t === 'function' ? window.t('chat.callTool', { name: escapeHtmlLocal(toolNameDelta), index: deltaInfo.index || 0, total: deltaInfo.total || 0 }) : toolNameDelta); if (callItem) {
ensureToolCallResultSlot(callItem);
const itemId = addTimelineItem(timeline, 'tool_result', { const section = callItem.querySelector('.tool-result-section');
title: title, if (section) {
message: '', section.classList.remove('pending');
data: { section.className = 'tool-result-section success';
toolName: toolNameDelta, }
success: true, }
isError: false, }
result: deltaText, state = { itemId: callItemId, buffer: '', onCallItem: !!callItemId };
toolCallId: toolCallId,
index: deltaInfo.index,
total: deltaInfo.total,
iteration: deltaInfo.iteration,
einoAgent: deltaInfo.einoAgent,
source: deltaInfo.source
},
expanded: false
});
state = { itemId, buffer: '' };
toolResultStreamStateByKey.set(key, state); toolResultStreamStateByKey.set(key, state);
} }
state.buffer += deltaText; state.buffer += deltaText;
const item = document.getElementById(state.itemId); const item = state.itemId ? document.getElementById(state.itemId) : null;
if (item) { if (item) {
const pre = item.querySelector('pre.tool-result'); const pre = item.querySelector('pre.tool-result');
if (pre) { if (pre) {
pre.classList.remove('tool-result-pending');
pre.textContent = state.buffer; pre.textContent = state.buffer;
} }
} }
@@ -1645,34 +1661,23 @@ function handleStreamEvent(event, progressElement, progressId,
const resultToolCallId = resultInfo.toolCallId || null; const resultToolCallId = resultInfo.toolCallId || null;
const resultExecText = success ? (typeof window.t === 'function' ? window.t('chat.toolExecComplete', { name: escapeHtml(resultToolName) }) : '工具 ' + escapeHtml(resultToolName) + ' 执行完成') : (typeof window.t === 'function' ? window.t('chat.toolExecFailed', { name: escapeHtml(resultToolName) }) : '工具 ' + escapeHtml(resultToolName) + ' 执行失败'); const resultExecText = success ? (typeof window.t === 'function' ? window.t('chat.toolExecComplete', { name: escapeHtml(resultToolName) }) : '工具 ' + escapeHtml(resultToolName) + ' 执行完成') : (typeof window.t === 'function' ? window.t('chat.toolExecFailed', { name: escapeHtml(resultToolName) }) : '工具 ' + escapeHtml(resultToolName) + ' 执行失败');
// 若此 tool 已经流式推送过增量,则复用占位条目并更新最终结果,避免重复添加一条
if (resultToolCallId) { if (resultToolCallId) {
const key = toolResultStreamKey(progressId, resultToolCallId); const key = toolResultStreamKey(progressId, resultToolCallId);
const state = toolResultStreamStateByKey.get(key); const streamState = toolResultStreamStateByKey.get(key);
if (state && state.itemId) { if (streamState && streamState.itemId) {
const item = document.getElementById(state.itemId); const streamCallItem = document.getElementById(streamState.itemId);
if (item) { if (streamCallItem) {
const pre = item.querySelector('pre.tool-result'); mergeToolResultIntoCallItem(streamCallItem, resultInfo);
const resultVal = resultInfo.result || resultInfo.error || '';
if (pre) pre.textContent = typeof resultVal === 'string' ? resultVal : JSON.stringify(resultVal);
const section = item.querySelector('.tool-result-section');
if (section) {
section.className = 'tool-result-section ' + (success ? 'success' : 'error');
}
const titleEl = item.querySelector('.timeline-item-title');
if (titleEl) {
if (resultInfo.einoAgent != null && String(resultInfo.einoAgent).trim() !== '') {
item.dataset.einoAgent = String(resultInfo.einoAgent).trim();
}
titleEl.textContent = timelineAgentBracketPrefix(resultInfo) + statusIcon + ' ' + resultExecText;
}
} }
toolResultStreamStateByKey.delete(key); toolResultStreamStateByKey.delete(key);
if (toolCallStatusMap.has(resultToolCallId)) {
// 同时更新 tool_call 的状态 updateToolCallStatus(resultToolCallId, success ? 'completed' : 'failed');
if (resultToolCallId && toolCallStatusMap.has(resultToolCallId)) { toolCallStatusMap.delete(resultToolCallId);
}
break;
}
if (attachToolResultToCall(resultToolCallId, resultInfo)) {
if (toolCallStatusMap.has(resultToolCallId)) {
updateToolCallStatus(resultToolCallId, success ? 'completed' : 'failed'); updateToolCallStatus(resultToolCallId, success ? 'completed' : 'failed');
toolCallStatusMap.delete(resultToolCallId); toolCallStatusMap.delete(resultToolCallId);
} }
@@ -1730,12 +1735,11 @@ function handleStreamEvent(event, progressElement, progressId,
const streamId = d.streamId || null; const streamId = d.streamId || null;
if (!streamId) break; if (!streamId) break;
const delta = event.message || ''; const delta = event.message || '';
if (!delta) break; if (!delta && streamBufferFromAccumulated(d) === null) break;
const stateMap = einoAgentReplyStreamStateByProgressId.get(progressId); const stateMap = einoAgentReplyStreamStateByProgressId.get(progressId);
if (!stateMap || !stateMap.has(streamId)) break; if (!stateMap || !stateMap.has(streamId)) break;
const s = stateMap.get(streamId); const s = stateMap.get(streamId);
const merged = normalizeStreamingDeltaJs(s.buffer, delta); s.buffer = mergeStreamBuffer(s.buffer, delta, d);
s.buffer = merged[0];
const item = document.getElementById(s.itemId); const item = document.getElementById(s.itemId);
if (item) { if (item) {
let contentEl = item.querySelector('.timeline-item-content'); let contentEl = item.querySelector('.timeline-item-content');
@@ -1921,8 +1925,8 @@ function handleStreamEvent(event, progressElement, progressId,
} }
const deltaContent = event.message || ''; const deltaContent = event.message || '';
const mergedResp = normalizeStreamingDeltaJs(state.buffer, deltaContent); if (!deltaContent && streamBufferFromAccumulated(responseData) === null) break;
state.buffer = mergedResp[0]; state.buffer = mergeStreamBuffer(state.buffer, deltaContent, responseData);
// 更新时间线条目内容 // 更新时间线条目内容
if (state.itemId) { if (state.itemId) {
@@ -2507,6 +2511,212 @@ async function attachRunningTaskEventStream(conversationId) {
window.attachRunningTaskEventStream = attachRunningTaskEventStream; window.attachRunningTaskEventStream = attachRunningTaskEventStream;
window.taskReplayProgressId = taskReplayProgressId; window.taskReplayProgressId = taskReplayProgressId;
window.expandProcessDetailsTimeline = expandProcessDetailsTimeline;
/** 从工具参数提取短摘要(URL/命令等),便于同名工具批量调用时区分 */
function parseToolCallArgsFromData(data) {
if (!data) return {};
let args = data.argumentsObj;
if (args == null && data.arguments != null && String(data.arguments).trim() !== '') {
try {
args = JSON.parse(String(data.arguments));
} catch (e) {
args = { _raw: String(data.arguments) };
}
}
if (args == null || typeof args !== 'object') {
return {};
}
return args;
}
function formatToolCallTimelineTitle(toolName, index, total) {
const name = toolName || (typeof window.t === 'function' ? window.t('chat.unknownTool') : '未知工具');
const idx = index || 0;
const tot = total || 0;
if (typeof window.t === 'function') {
return window.t('chat.callTool', { name: name, index: idx, total: tot });
}
return '调用工具: ' + name + (tot ? ' (' + idx + '/' + tot + ')' : '');
}
function buildToolResultSectionHtml(data, opts) {
opts = opts || {};
const _t = function (k, o) {
return typeof window.t === 'function' ? window.t(k, o) : k;
};
const execResultLabel = _t('timeline.executionResult');
const execIdLabel = _t('timeline.executionId');
const waitingLabel = _t('timeline.running');
if (opts.pending) {
return (
'<div class="tool-result-section pending">' +
'<strong data-i18n="timeline.executionResult">' + escapeHtml(execResultLabel) + '</strong>' +
'<pre class="tool-result tool-result-pending">' + escapeHtml(waitingLabel) + '</pre>' +
'</div>'
);
}
const isError = data.isError || data.success === false;
const noResultText = _t('timeline.noResult');
const result = data.result != null ? data.result : (data.error != null ? data.error : noResultText);
const resultStr = typeof result === 'string' ? result : JSON.stringify(result);
const rawText = opts.rawText != null ? String(opts.rawText) : resultStr;
return (
'<div class="tool-result-section ' + (isError ? 'error' : 'success') + '">' +
'<strong data-i18n="timeline.executionResult">' + escapeHtml(execResultLabel) + '</strong>' +
'<pre class="tool-result">' + escapeHtml(rawText) + '</pre>' +
(data.executionId ? '<div class="tool-execution-id"><span data-i18n="timeline.executionId">' +
escapeHtml(execIdLabel) + '</span> <code>' + escapeHtml(String(data.executionId)) + '</code></div>' : '') +
'</div>'
);
}
function ensureToolCallResultSlot(item) {
if (!item) return null;
let section = item.querySelector('.tool-result-section');
if (section) return section;
const content = item.querySelector('.timeline-item-content');
if (!content) return null;
const wrap = document.createElement('div');
wrap.className = 'tool-details tool-result-slot';
wrap.innerHTML = buildToolResultSectionHtml({}, { pending: true });
content.appendChild(wrap);
return wrap.querySelector('.tool-result-section');
}
function mergeToolResultIntoCallItem(item, data, options) {
if (!item || !data) return false;
options = options || {};
const isError = data.isError || data.success === false;
const noResultText = typeof window.t === 'function' ? window.t('timeline.noResult') : '无结果';
const result = data.result != null ? data.result : (data.error != null ? data.error : noResultText);
const resultStr = typeof result === 'string' ? result : JSON.stringify(result);
const text = options.rawText != null ? String(options.rawText) : resultStr;
let section = item.querySelector('.tool-result-section');
if (!section) {
ensureToolCallResultSlot(item);
section = item.querySelector('.tool-result-section');
}
if (!section) return false;
section.classList.remove('pending');
section.className = 'tool-result-section ' + (isError ? 'error' : 'success');
const pre = section.querySelector('pre.tool-result');
if (pre) {
pre.classList.remove('tool-result-pending');
pre.textContent = text;
}
if (data.executionId) {
let execIdEl = section.querySelector('.tool-execution-id');
if (!execIdEl) {
const execIdLabel = typeof window.t === 'function' ? window.t('timeline.executionId') : '执行ID:';
execIdEl = document.createElement('div');
execIdEl.className = 'tool-execution-id';
execIdEl.innerHTML = '<span data-i18n="timeline.executionId">' + escapeHtml(execIdLabel) +
'</span> <code></code>';
section.appendChild(execIdEl);
}
const code = execIdEl.querySelector('code');
if (code) code.textContent = String(data.executionId);
}
item.dataset.toolResultMerged = '1';
item.dataset.toolSuccess = data.success !== false ? '1' : '0';
item.classList.remove('tool-call-running');
item.classList.add(data.success !== false ? 'tool-call-completed' : 'tool-call-failed');
return true;
}
function findToolCallItemById(root, toolCallId) {
if (!root || !toolCallId) return null;
const id = String(toolCallId).trim();
if (!id) return null;
try {
return root.querySelector('[data-tool-call-id="' + CSS.escape(id) + '"]');
} catch (e) {
return root.querySelector('[data-tool-call-id="' + id.replace(/"/g, '\\"') + '"]');
}
}
function attachToolResultToCall(toolCallId, data, options) {
if (!toolCallId || !data) return false;
const mapping = toolCallStatusMap.get(toolCallId);
let item = null;
if (mapping && mapping.itemId) {
item = document.getElementById(mapping.itemId);
}
if (!item && mapping && mapping.timeline) {
item = findToolCallItemById(mapping.timeline, toolCallId);
}
if (!item) return false;
mergeToolResultIntoCallItem(item, data, options);
return true;
}
function coalesceProcessDetailsToolPairs(details) {
if (!Array.isArray(details) || details.length === 0) return details;
const callsById = new Map();
const fifoCalls = [];
const out = [];
function absorbResult(targetDetail, resultDetail) {
const rd = resultDetail.data || {};
targetDetail.data = targetDetail.data || {};
targetDetail.data._mergedResult = Object.assign({}, rd);
if (resultDetail.createdAt) {
targetDetail.data._mergedResultAt = resultDetail.createdAt;
}
}
for (let i = 0; i < details.length; i++) {
const detail = details[i];
const et = detail.eventType || '';
const data = detail.data || {};
const id = data.toolCallId != null ? String(data.toolCallId).trim() : '';
if (et === 'tool_call') {
const copy = {
eventType: detail.eventType,
message: detail.message,
createdAt: detail.createdAt,
data: Object.assign({}, data)
};
if (id) callsById.set(id, copy);
fifoCalls.push(copy);
out.push(copy);
} else if (et === 'tool_result') {
let target = null;
if (id && callsById.has(id)) {
target = callsById.get(id);
} else {
for (let j = 0; j < fifoCalls.length; j++) {
const c = fifoCalls[j];
if (c && c.data && !c.data._mergedResult) {
target = c;
break;
}
}
}
if (target) {
absorbResult(target, detail);
continue;
}
out.push(detail);
} else {
out.push(detail);
}
}
return out;
}
window.coalesceProcessDetailsToolPairs = coalesceProcessDetailsToolPairs;
window.attachToolResultToCall = attachToolResultToCall;
window.mergeToolResultIntoCallItem = mergeToolResultIntoCallItem;
window.formatToolCallTimelineTitle = formatToolCallTimelineTitle;
window.parseToolCallArgsFromData = parseToolCallArgsFromData;
window.buildToolResultSectionHtml = buildToolResultSectionHtml;
// 更新工具调用状态 // 更新工具调用状态
function updateToolCallStatus(toolCallId, status) { function updateToolCallStatus(toolCallId, status) {
@@ -2574,6 +2784,11 @@ function addTimelineItem(timeline, type, options) {
if (d.toolCallId != null && String(d.toolCallId).trim() !== '') { if (d.toolCallId != null && String(d.toolCallId).trim() !== '') {
item.dataset.toolCallId = String(d.toolCallId).trim(); item.dataset.toolCallId = String(d.toolCallId).trim();
} }
const merged = options.mergedResult || d._mergedResult;
if (merged) {
item.dataset.toolResultMerged = '1';
item.dataset.toolSuccess = merged.success !== false ? '1' : '0';
}
} }
if (type === 'hitl_interrupt' && options.data && options.data.interruptId != null && String(options.data.interruptId).trim() !== '') { if (type === 'hitl_interrupt' && options.data && options.data.interruptId != null && String(options.data.interruptId).trim() !== '') {
item.dataset.hitlInterruptId = String(options.data.interruptId).trim(); item.dataset.hitlInterruptId = String(options.data.interruptId).trim();
@@ -2635,18 +2850,20 @@ function addTimelineItem(timeline, type, options) {
content += `<div class="timeline-item-content">${formatMarkdown(streamBody)}</div>`; content += `<div class="timeline-item-content">${formatMarkdown(streamBody)}</div>`;
} else if (type === 'tool_call' && options.data) { } else if (type === 'tool_call' && options.data) {
const data = options.data; const data = options.data;
let args = data.argumentsObj; const args = parseToolCallArgsFromData(data);
if (args == null && data.arguments != null && String(data.arguments).trim() !== '') { const merged = options.mergedResult || data._mergedResult;
try {
args = JSON.parse(String(data.arguments));
} catch (e) {
args = { _raw: String(data.arguments) };
}
}
if (args == null || typeof args !== 'object') {
args = {};
}
const paramsLabel = typeof window.t === 'function' ? window.t('timeline.params') : '参数:'; const paramsLabel = typeof window.t === 'function' ? window.t('timeline.params') : '参数:';
let resultBlock = '';
if (merged) {
resultBlock = '<div class="tool-details tool-result-slot">' + buildToolResultSectionHtml(merged) + '</div>';
if (merged.success !== false) {
item.classList.add('tool-call-completed');
} else {
item.classList.add('tool-call-failed');
}
} else if (!options.skipPendingResult) {
resultBlock = '<div class="tool-details tool-result-slot">' + buildToolResultSectionHtml({}, { pending: true }) + '</div>';
}
content += ` content += `
<div class="timeline-item-content"> <div class="timeline-item-content">
<div class="tool-details"> <div class="tool-details">
@@ -2654,6 +2871,7 @@ function addTimelineItem(timeline, type, options) {
<strong data-i18n="timeline.params">${escapeHtml(paramsLabel)}</strong> <strong data-i18n="timeline.params">${escapeHtml(paramsLabel)}</strong>
<pre class="tool-args">${escapeHtml(JSON.stringify(args, null, 2))}</pre> <pre class="tool-args">${escapeHtml(JSON.stringify(args, null, 2))}</pre>
</div> </div>
${resultBlock}
</div> </div>
</div> </div>
`; `;
@@ -4071,7 +4289,10 @@ function refreshProgressAndTimelineI18n() {
const name = (item.dataset.toolName != null && item.dataset.toolName !== '') ? item.dataset.toolName : _t('chat.unknownTool'); const name = (item.dataset.toolName != null && item.dataset.toolName !== '') ? item.dataset.toolName : _t('chat.unknownTool');
const index = parseInt(item.dataset.toolIndex, 10) || 0; const index = parseInt(item.dataset.toolIndex, 10) || 0;
const total = parseInt(item.dataset.toolTotal, 10) || 0; const total = parseInt(item.dataset.toolTotal, 10) || 0;
titleSpan.textContent = ap + '\uD83D\uDD27 ' + _t('chat.callTool', { name: name, index: index, total: total }); const callTitle = typeof formatToolCallTimelineTitle === 'function'
? formatToolCallTimelineTitle(name, index, total)
: _t('chat.callTool', { name: name, index: index, total: total });
titleSpan.textContent = ap + '\uD83D\uDD27 ' + callTitle;
} else if (type === 'tool_result' && (item.dataset.toolName !== undefined || item.dataset.toolSuccess !== undefined)) { } else if (type === 'tool_result' && (item.dataset.toolName !== undefined || item.dataset.toolSuccess !== undefined)) {
const name = (item.dataset.toolName != null && item.dataset.toolName !== '') ? item.dataset.toolName : _t('chat.unknownTool'); const name = (item.dataset.toolName != null && item.dataset.toolName !== '') ? item.dataset.toolName : _t('chat.unknownTool');
const success = item.dataset.toolSuccess === '1'; const success = item.dataset.toolSuccess === '1';
+907
View File
@@ -0,0 +1,907 @@
/**
* 项目管理与事实黑板
*/
let projectsCache = [];
let projectsCacheAll = [];
let currentProjectId = null;
let currentProjectTab = 'facts';
const projectNameById = {};
let _projectsListReady = false;
let _projectsFetchPromise = null;
const PROJECT_ACTIVE_KEY = 'cyberstrike.activeProjectId';
function getActiveProjectId() {
try {
return localStorage.getItem(PROJECT_ACTIVE_KEY) || '';
} catch (e) {
return '';
}
}
function setActiveProjectId(id) {
try {
if (id) localStorage.setItem(PROJECT_ACTIVE_KEY, id);
else localStorage.removeItem(PROJECT_ACTIVE_KEY);
} catch (e) { /* ignore */ }
}
function rebuildProjectNameMap(list) {
Object.keys(projectNameById).forEach((k) => delete projectNameById[k]);
(list || []).forEach((p) => {
if (p && p.id) projectNameById[p.id] = p.name || p.id;
});
}
async function fetchProjectsList(includeArchived) {
const showArchived = includeArchived || document.getElementById('projects-show-archived')?.checked;
const url = showArchived ? '/api/projects?limit=200' : '/api/projects?status=active&limit=200';
const res = await apiFetch(url);
if (!res.ok) throw new Error('加载项目失败');
const data = await res.json();
projectsCache = Array.isArray(data) ? data : [];
rebuildProjectNameMap(projectsCache);
_projectsListReady = true;
return projectsCache;
}
/** 对话页等项目选择器:确保列表已拉取(去重并发请求) */
async function ensureProjectsLoaded(force) {
if (!force && _projectsListReady) return projectsCache;
if (!force && _projectsFetchPromise) return _projectsFetchPromise;
_projectsFetchPromise = fetchProjectsList(false)
.catch((e) => {
_projectsListReady = false;
throw e;
})
.finally(() => {
_projectsFetchPromise = null;
});
return _projectsFetchPromise;
}
function prefetchProjectsForChat() {
ensureProjectsLoaded().catch(() => {});
}
function getProjectName(id) {
return projectNameById[id] || id || '';
}
function initProjectsModalEscape() {
if (window._projectsModalEscapeBound) return;
window._projectsModalEscapeBound = true;
document.addEventListener('keydown', (e) => {
if (e.key !== 'Escape') return;
if (document.getElementById('project-modal')?.style.display === 'flex') closeProjectModal();
else if (document.getElementById('fact-modal')?.style.display === 'flex') closeFactModal();
else if (document.getElementById('fact-detail-modal')?.style.display === 'flex') closeFactDetailModal();
});
}
async function initProjectsPage() {
const page = document.getElementById('page-projects');
if (!page || page.style.display === 'none') return;
initProjectsModalEscape();
updateProjectsDetailVisibility();
await loadProjectsList();
if (!currentProjectId && projectsCache.length) {
const fromHash = new URLSearchParams(window.location.hash.split('?')[1] || '').get('id');
currentProjectId = fromHash || projectsCache[0].id;
}
renderProjectsSidebar();
if (currentProjectId) {
await selectProject(currentProjectId);
}
}
async function loadProjectsList() {
await fetchProjectsList();
renderProjectsSidebar();
if (typeof refreshChatProjectSelector === 'function') {
refreshChatProjectSelector();
}
if (typeof refreshVulnerabilityProjectFilter === 'function') {
refreshVulnerabilityProjectFilter();
}
}
function projectInitial(name) {
const s = (name || 'P').trim();
return s ? s.charAt(0).toUpperCase() : 'P';
}
function updateProjectsDetailVisibility() {
const main = document.getElementById('projects-detail-main');
const placeholder = document.getElementById('projects-detail-placeholder');
const inner = document.getElementById('projects-detail-inner');
const show = !!currentProjectId;
if (main) main.classList.toggle('has-project', show);
if (placeholder) placeholder.hidden = show;
if (inner) inner.hidden = !show;
}
function updateProjectsListCount() {
const el = document.getElementById('projects-list-count');
if (el) el.textContent = String(projectsCache.length);
}
function formatConfidenceBadge(confidence) {
const c = (confidence || '').toLowerCase();
let cls = 'projects-confidence--tentative';
let label = c || '—';
if (c === 'confirmed') {
cls = 'projects-confidence--confirmed';
label = '已确认';
} else if (c === 'deprecated') {
cls = 'projects-confidence--deprecated';
label = '已废弃';
} else if (c === 'tentative') {
label = '待确认';
}
return `<span class="projects-confidence ${cls}">${escapeHtml(label)}</span>`;
}
function renderProjectFactActions(keyEsc, idEsc) {
return `<div class="projects-table-actions">
<button type="button" class="projects-action-btn projects-action-btn--edit" data-fact-key="${keyEsc}" onclick="showEditFactModal(this.dataset.factKey)" title="编辑各字段">编辑</button>
<button type="button" class="projects-action-btn projects-action-btn--view" data-fact-key="${keyEsc}" onclick="viewProjectFactBody(this.dataset.factKey)" title="查看完整 body">详情</button>
<button type="button" class="projects-action-btn projects-action-btn--mute" data-fact-key="${keyEsc}" onclick="deprecateProjectFactByKey(this.dataset.factKey)" title="标记为已废弃">废弃</button>
<button type="button" class="projects-action-btn projects-action-btn--danger" data-fact-id="${idEsc}" onclick="deleteProjectFact(this.dataset.factId)" title="永久删除">删除</button>
</div>`;
}
function formatSeverityBadge(severity) {
const s = (severity || 'info').toLowerCase();
const cls = 'projects-severity--' + (['critical', 'high', 'medium', 'low', 'info'].includes(s) ? s : 'info');
return `<span class="projects-severity ${cls}">${escapeHtml(severity || '—')}</span>`;
}
function getProjectsListFilter() {
return (document.getElementById('projects-list-search')?.value || '').trim().toLowerCase();
}
function filterProjectsList() {
renderProjectsSidebar();
}
function renderProjectsSidebar() {
const el = document.getElementById('projects-list');
if (!el) return;
updateProjectsListCount();
const q = getProjectsListFilter();
const list = q
? projectsCache.filter((p) => (p.name || '').toLowerCase().includes(q) || (p.description || '').toLowerCase().includes(q))
: projectsCache;
if (!projectsCache.length) {
el.innerHTML =
'<div class="projects-empty">暂无项目<br><button type="button" class="btn-primary btn-small projects-empty-btn" onclick="showNewProjectModal()">新建项目</button></div>';
updateProjectsDetailVisibility();
return;
}
if (!list.length) {
el.innerHTML = '<div class="projects-empty">无匹配项目</div>';
updateProjectsDetailVisibility();
return;
}
el.innerHTML = list.map((p) => {
const active = p.id === currentProjectId ? ' is-active' : '';
const archived = p.status === 'archived' ? ' is-archived' : '';
const badges = [
p.pinned ? '<span class="projects-list-item-badge">置顶</span>' : '',
p.status === 'archived' ? '<span class="projects-list-item-badge">归档</span>' : '',
].join('');
return `<div class="projects-list-item${active}${archived}" data-id="${escapeHtml(p.id)}" onclick="selectProject('${escapeHtml(p.id)}')">
<div class="projects-list-item-body">
<div class="projects-list-item-name">${escapeHtml(p.name)}${badges}</div>
<div class="projects-list-item-meta">${formatProjectTime(p.updated_at)}</div>
</div>
</div>`;
}).join('');
updateProjectsDetailVisibility();
}
function updateProjectStatusPill(status) {
const el = document.getElementById('projects-detail-status');
if (!el) return;
const archived = status === 'archived';
el.textContent = archived ? '已归档' : '进行中';
el.className = 'projects-status-pill ' + (archived ? 'projects-status-pill--archived' : 'projects-status-pill--active');
}
function updateProjectStats(factCount, vulnCount) {
const f = document.getElementById('project-stat-facts');
const v = document.getElementById('project-stat-vulns');
if (f) f.textContent = `${factCount ?? 0} 条事实`;
if (v) v.textContent = `${vulnCount ?? 0} 个漏洞`;
}
async function selectProject(id) {
currentProjectId = id;
renderProjectsSidebar();
updateProjectsDetailVisibility();
try {
const res = await apiFetch(`/api/projects/${id}`);
if (!res.ok) throw new Error('项目不存在');
const p = await res.json();
const titleEl = document.getElementById('projects-detail-title');
if (titleEl) titleEl.textContent = p.name || '项目';
document.getElementById('project-edit-name').value = p.name || '';
document.getElementById('project-edit-description').value = p.description || '';
document.getElementById('project-edit-scope').value = p.scope_json || '';
const statusEl = document.getElementById('project-edit-status');
if (statusEl) statusEl.value = p.status || 'active';
updateProjectStatusPill(p.status || 'active');
const metaEl = document.getElementById('projects-detail-meta');
if (metaEl) metaEl.textContent = `更新于 ${formatProjectTime(p.updated_at)}`;
const descEl = document.getElementById('projects-detail-desc');
if (descEl) {
const desc = (p.description || '').trim();
if (desc) {
descEl.textContent = desc;
descEl.hidden = false;
} else {
descEl.textContent = '';
descEl.hidden = true;
}
}
projectNameById[p.id] = p.name || p.id;
} catch (e) {
console.warn(e);
}
refreshProjectHeaderStats();
switchProjectTab(currentProjectTab);
}
function switchProjectTab(tab) {
currentProjectTab = tab;
['facts', 'vulns', 'settings'].forEach((t) => {
const btn = document.getElementById(`project-tab-${t}`);
const panel = document.getElementById(`project-panel-${t}`);
if (btn) btn.classList.toggle('is-active', t === tab);
if (panel) panel.hidden = t !== tab;
});
if (tab === 'facts') loadProjectFacts();
if (tab === 'vulns') loadProjectVulnerabilities();
}
async function loadProjectFacts() {
const tbody = document.getElementById('project-facts-tbody');
if (!tbody || !currentProjectId) return;
tbody.innerHTML = '<tr class="is-empty-row"><td colspan="6">加载中…</td></tr>';
const res = await apiFetch(`/api/projects/${currentProjectId}/facts?limit=200`);
if (!res.ok) {
tbody.innerHTML = '<tr class="is-empty-row"><td colspan="6">加载失败</td></tr>';
return;
}
const facts = await res.json();
if (!facts.length) {
tbody.innerHTML = '<tr class="is-empty-row"><td colspan="6">暂无事实,点击「添加事实」或由 Agent 自动写入</td></tr>';
refreshProjectHeaderStats();
return;
}
tbody.innerHTML = facts.map((f) => {
const keyEsc = escapeHtml(f.fact_key);
const idEsc = escapeHtml(f.id);
return `<tr>
<td><code>${keyEsc}</code></td>
<td>${escapeHtml(f.category)}</td>
<td class="cell-summary" title="${escapeHtml(f.summary)}">${escapeHtml(f.summary)}</td>
<td>${formatConfidenceBadge(f.confidence)}</td>
<td>${formatProjectTime(f.updated_at, f.created_at)}</td>
<td class="col-actions">${renderProjectFactActions(keyEsc, idEsc)}</td>
</tr>`;
}).join('');
refreshProjectHeaderStats();
}
async function refreshProjectHeaderStats() {
if (!currentProjectId) return;
try {
const [factsRes, vulnRes] = await Promise.all([
apiFetch(`/api/projects/${currentProjectId}/facts?limit=500`),
apiFetch(`/api/vulnerabilities?project_id=${encodeURIComponent(currentProjectId)}&limit=100`),
]);
let fc = 0;
let vc = 0;
if (factsRes.ok) {
const f = await factsRes.json();
fc = Array.isArray(f) ? f.length : 0;
}
if (vulnRes.ok) {
const d = await vulnRes.json();
const items = d.Vulnerabilities || d.vulnerabilities || d.items || [];
vc = items.length;
}
updateProjectStats(fc, vc);
} catch (e) {
console.warn(e);
}
}
let _factDetailKey = null;
async function viewProjectFactBody(factKey) {
const res = await apiFetch(`/api/projects/${currentProjectId}/facts?fact_key=${encodeURIComponent(factKey)}`);
if (!res.ok) return alert('加载失败');
const f = await res.json();
_factDetailKey = f.fact_key;
document.getElementById('fact-detail-title').textContent = `[${f.fact_key}]`;
document.getElementById('fact-detail-meta').textContent =
`分类: ${f.category} · 置信度: ${f.confidence} · 更新: ${formatProjectTime(f.updated_at, f.created_at)}` +
(f.related_vulnerability_id ? ` · 关联漏洞: ${f.related_vulnerability_id}` : '');
document.getElementById('fact-detail-body').textContent = f.body || '(无 body)';
openProjectsOverlay('fact-detail-modal');
}
function editFactFromDetail() {
const key = _factDetailKey;
closeFactDetailModal();
if (key) showEditFactModal(key);
}
function closeFactDetailModal() {
closeProjectsOverlay('fact-detail-modal');
_factDetailKey = null;
}
async function deprecateProjectFactByKey(factKey) {
if (!confirm(`将事实 ${factKey} 标记为 deprecated`)) return;
const res = await apiFetch(`/api/projects/${currentProjectId}/facts/deprecate`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ fact_key: factKey }),
});
if (!res.ok) return alert('操作失败');
loadProjectFacts();
}
function openVulnerabilitiesForProject(projectId) {
const pid = projectId || currentProjectId;
if (!pid) return;
if (typeof switchPage === 'function') {
switchPage('vulnerabilities');
}
if (typeof window.setVulnerabilityProjectFilter === 'function') {
window.setVulnerabilityProjectFilter(pid);
} else {
window.location.hash = `vulnerabilities?project_id=${encodeURIComponent(pid)}`;
}
}
async function loadProjectVulnerabilities() {
const tbody = document.getElementById('project-vulns-tbody');
if (!tbody || !currentProjectId) return;
tbody.innerHTML = '<tr class="is-empty-row"><td colspan="4">加载中…</td></tr>';
const res = await apiFetch(`/api/vulnerabilities?project_id=${encodeURIComponent(currentProjectId)}&limit=100`);
if (!res.ok) {
tbody.innerHTML = '<tr class="is-empty-row"><td colspan="4">加载失败</td></tr>';
return;
}
const data = await res.json();
const items = data.Vulnerabilities || data.vulnerabilities || data.items || [];
if (!items.length) {
tbody.innerHTML = '<tr class="is-empty-row"><td colspan="4">本项目暂无漏洞记录</td></tr>';
refreshProjectHeaderStats();
return;
}
tbody.innerHTML = items.map((v) => {
const idEsc = escapeHtml(v.id);
return `<tr>
<td class="cell-summary" title="${escapeHtml(v.title)}">${escapeHtml(v.title)}</td>
<td>${formatSeverityBadge(v.severity)}</td>
<td>${escapeHtml(v.status)}</td>
<td class="col-actions">
<div class="projects-table-actions">
<button type="button" class="projects-action-btn projects-action-btn--view" data-vuln-id="${idEsc}" onclick="openVulnerabilityDetail(this.dataset.vulnId)">查看</button>
</div>
</td>
</tr>`;
}).join('');
refreshProjectHeaderStats();
}
function openVulnerabilityDetail(vulnId) {
openVulnerabilitiesForProject(currentProjectId);
if (typeof window.setVulnerabilityIdFilter === 'function') {
setTimeout(() => window.setVulnerabilityIdFilter(vulnId), 300);
}
}
function openProjectsOverlay(id) {
const el = document.getElementById(id);
if (!el) return;
el.style.display = 'flex';
document.body.classList.add('projects-modal-open');
const focusTarget = el.querySelector('input.form-input, textarea.form-input, select.form-input');
if (focusTarget) {
setTimeout(() => focusTarget.focus(), 80);
}
}
function closeProjectsOverlay(id) {
const el = document.getElementById(id);
if (el) el.style.display = 'none';
const anyOpen = document.querySelector('.projects-modal-overlay[style*="flex"]');
if (!anyOpen) document.body.classList.remove('projects-modal-open');
}
function showNewProjectModal() {
document.getElementById('project-modal-title').textContent = '新建项目';
const sub = document.getElementById('project-modal-subtitle');
if (sub) sub.textContent = '创建后可绑定对话,跨会话共享事实黑板';
const submitBtn = document.getElementById('project-modal-submit-btn');
if (submitBtn) submitBtn.textContent = '创建项目';
document.getElementById('project-modal-name').value = '';
document.getElementById('project-modal-description').value = '';
window._projectModalEditId = null;
openProjectsOverlay('project-modal');
}
async function saveProjectModal() {
const name = document.getElementById('project-modal-name').value.trim();
if (!name) return alert('请输入项目名称');
const body = {
name,
description: document.getElementById('project-modal-description').value.trim(),
};
const editId = window._projectModalEditId;
const res = editId
? await apiFetch(`/api/projects/${editId}`, { method: 'PUT', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(body) })
: await apiFetch('/api/projects', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(body) });
if (!res.ok) {
const err = await res.json().catch(() => ({}));
alert(err.error || '保存失败');
return;
}
closeProjectModal();
const saved = await res.json();
await loadProjectsList();
if (saved.id) await selectProject(saved.id);
}
function closeProjectModal() {
closeProjectsOverlay('project-modal');
}
function formatProjectScopeJson() {
const el = document.getElementById('project-edit-scope');
if (!el) return;
const raw = el.value.trim();
if (!raw) return;
try {
el.value = JSON.stringify(JSON.parse(raw), null, 2);
} catch (e) {
alert('JSON 格式无效:' + (e.message || String(e)));
}
}
function insertProjectScopeExample() {
const el = document.getElementById('project-edit-scope');
if (!el) return;
const example = {
targets: ['https://example.com'],
exclude: ['*.cdn.example.com'],
notes: '仅授权 Web 应用层测试',
};
el.value = JSON.stringify(example, null, 2);
el.focus();
}
async function saveProjectSettings() {
if (!currentProjectId) return;
const scopeRaw = document.getElementById('project-edit-scope').value.trim();
if (scopeRaw) {
try {
JSON.parse(scopeRaw);
} catch (e) {
alert('测试范围 JSON 无效,请先修正或点击「格式化」:' + (e.message || String(e)));
return;
}
}
const body = {
name: document.getElementById('project-edit-name').value.trim(),
description: document.getElementById('project-edit-description').value.trim(),
scope_json: scopeRaw,
status: document.getElementById('project-edit-status')?.value || 'active',
};
const res = await apiFetch(`/api/projects/${currentProjectId}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(body),
});
if (!res.ok) return alert('保存失败');
await loadProjectsList();
await selectProject(currentProjectId);
alert('已保存');
}
async function archiveCurrentProject() {
if (!currentProjectId) return;
const statusEl = document.getElementById('project-edit-status');
const cur = statusEl?.value || 'active';
const next = cur === 'archived' ? 'active' : 'archived';
if (!confirm(next === 'archived' ? '归档后默认不再出现在活跃列表,是否继续?' : '恢复为 active')) return;
const res = await apiFetch(`/api/projects/${currentProjectId}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ status: next }),
});
if (!res.ok) return alert('操作失败');
await loadProjectsList();
await selectProject(currentProjectId);
}
async function deleteCurrentProject() {
if (!currentProjectId || !confirm('确定删除该项目?事实将一并删除,对话将解除绑定。')) return;
const deletedId = currentProjectId;
const deletedIndex = projectsCache.findIndex((p) => p.id === deletedId);
const res = await apiFetch(`/api/projects/${deletedId}`, { method: 'DELETE' });
if (!res.ok) return alert('删除失败');
if (getActiveProjectId() === deletedId) setActiveProjectId('');
currentProjectId = null;
await loadProjectsList();
if (projectsCache.length) {
const nextIndex = Math.min(deletedIndex >= 0 ? deletedIndex : 0, projectsCache.length - 1);
await selectProject(projectsCache[nextIndex].id);
} else {
updateProjectsDetailVisibility();
}
}
function resetFactModalForm() {
window._factModalEditId = null;
const keyEl = document.getElementById('fact-modal-key');
if (keyEl) keyEl.disabled = false;
document.getElementById('fact-modal-title').textContent = '添加事实';
document.getElementById('fact-modal-submit-btn').textContent = '保存事实';
document.getElementById('fact-modal-key').value = '';
document.getElementById('fact-modal-category').value = 'note';
document.getElementById('fact-modal-summary').value = '';
document.getElementById('fact-modal-body').value = '';
document.getElementById('fact-modal-confidence').value = 'tentative';
const rel = document.getElementById('fact-modal-related-vuln');
if (rel) rel.value = '';
}
function fillFactModalForm(f) {
window._factModalEditId = f.id;
document.getElementById('fact-modal-title').textContent = '编辑事实';
document.getElementById('fact-modal-submit-btn').textContent = '保存修改';
document.getElementById('fact-modal-key').value = f.fact_key || '';
document.getElementById('fact-modal-category').value = f.category || 'note';
document.getElementById('fact-modal-summary').value = f.summary || '';
document.getElementById('fact-modal-body').value = f.body || '';
const conf = (f.confidence || 'tentative').toLowerCase();
const confEl = document.getElementById('fact-modal-confidence');
if (confEl) {
const allowed = ['tentative', 'confirmed', 'deprecated'];
confEl.value = allowed.includes(conf) ? conf : 'tentative';
}
const rel = document.getElementById('fact-modal-related-vuln');
if (rel) rel.value = f.related_vulnerability_id || '';
}
function showAddFactModal() {
if (!currentProjectId) return alert('请先选择项目');
resetFactModalForm();
openProjectsOverlay('fact-modal');
}
async function showEditFactModal(factKey) {
if (!currentProjectId) return alert('请先选择项目');
const res = await apiFetch(
`/api/projects/${currentProjectId}/facts?fact_key=${encodeURIComponent(factKey)}`,
);
if (!res.ok) return alert('加载事实失败');
const f = await res.json();
resetFactModalForm();
fillFactModalForm(f);
openProjectsOverlay('fact-modal');
}
function closeFactModal() {
closeProjectsOverlay('fact-modal');
resetFactModalForm();
}
async function saveFactModal() {
const fact_key = document.getElementById('fact-modal-key').value.trim();
const summary = document.getElementById('fact-modal-summary').value.trim();
if (!fact_key || !summary) return alert('fact_key 与 summary 必填');
const payload = {
fact_key,
category: document.getElementById('fact-modal-category').value.trim() || 'note',
summary,
body: document.getElementById('fact-modal-body').value,
confidence: document.getElementById('fact-modal-confidence').value,
related_vulnerability_id: document.getElementById('fact-modal-related-vuln')?.value?.trim() || '',
};
const editId = window._factModalEditId;
const res = editId
? await apiFetch(`/api/projects/${currentProjectId}/facts/${editId}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
})
: await apiFetch(`/api/projects/${currentProjectId}/facts`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
return alert(err.error || '保存失败');
}
closeFactModal();
loadProjectFacts();
}
async function deleteProjectFact(id) {
if (!confirm('删除该事实?')) return;
await apiFetch(`/api/projects/${currentProjectId}/facts/${id}`, { method: 'DELETE' });
loadProjectFacts();
}
function parseProjectDate(t) {
if (t == null || t === '') return null;
if (typeof t === 'number' && Number.isFinite(t)) {
const d = new Date(t);
return isNaN(d.getTime()) || d.getFullYear() < 2000 ? null : d;
}
let s = String(t).trim();
if (!s || s.startsWith('0001-01-01')) return null;
let d = new Date(s);
if (!isNaN(d.getTime()) && d.getFullYear() >= 2000) return d;
const m = s.match(
/^(\d{4})-(\d{2})-(\d{2})[T\s](\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:([Zz]|([+-])(\d{2}):?(\d{2}))?)?$/,
);
if (m) {
const ms = m[7] ? parseInt(String(m[7]).slice(0, 3).padEnd(3, '0'), 10) : 0;
let offMin = 0;
if (m[8] && m[9] && m[10]) {
offMin = parseInt(m[10], 10) * 60 + parseInt(m[11] || '0', 10);
if (m[9] === '-') offMin = -offMin;
}
d = new Date(
Date.UTC(
parseInt(m[1], 10),
parseInt(m[2], 10) - 1,
parseInt(m[3], 10),
parseInt(m[4], 10),
parseInt(m[5], 10),
parseInt(m[6], 10),
ms,
) - offMin * 60 * 1000,
);
if (!isNaN(d.getTime()) && d.getFullYear() >= 2000) return d;
}
return null;
}
function formatProjectTime(t, fallback) {
const d = parseProjectDate(t) || (fallback != null ? parseProjectDate(fallback) : null);
if (!d) return '尚未更新';
const now = Date.now();
const diff = now - d.getTime();
if (diff < 60000) return '刚刚';
if (diff < 3600000) return `${Math.floor(diff / 60000)} 分钟前`;
if (diff < 86400000) return `${Math.floor(diff / 3600000)} 小时前`;
if (diff < 604800000) return `${Math.floor(diff / 86400000)} 天前`;
return d.toLocaleString(undefined, { year: 'numeric', month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' });
}
function escapeHtml(s) {
if (s == null) return '';
return String(s)
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;');
}
function getChatProjectSelection() {
const convId = window.currentConversationId;
if (convId) {
return window._loadedConversationProjectId || '';
}
return getActiveProjectId();
}
function updateChatProjectButtonLabel() {
const textEl = document.getElementById('chat-project-text');
if (!textEl) return;
const id = getChatProjectSelection();
textEl.textContent = id ? getProjectName(id) || id : '无项目';
}
function renderChatProjectPanelList() {
const list = document.getElementById('chat-project-list');
if (!list) return;
const selected = getChatProjectSelection();
const activeProjects = projectsCache.filter((p) => p.status !== 'archived');
const items = [{ id: '', name: '无项目', description: '不绑定项目黑板' }, ...activeProjects];
if (!items.length) {
list.innerHTML = '<div class="chat-project-panel-empty">暂无项目,可在「项目管理」中创建</div>';
return;
}
list.innerHTML = '';
items.forEach((p) => {
const isNone = !p.id;
const isSelected = isNone ? !selected : selected === p.id;
const desc = isNone
? (p.description || '')
: (p.description || '').trim().slice(0, 80) || '共享事实黑板';
const projectId = p.id || '';
const btn = document.createElement('button');
btn.type = 'button';
btn.className = 'role-selection-item-main' + (isSelected ? ' selected' : '');
btn.setAttribute('role', 'option');
btn.onclick = () => {
selectChatProject(projectId);
};
btn.innerHTML = `
<div class="role-selection-item-icon-main">${isNone ? '—' : '📁'}</div>
<div class="role-selection-item-content-main">
<div class="role-selection-item-name-main">${escapeHtml(p.name || '未命名')}</div>
<div class="role-selection-item-description-main">${escapeHtml(desc)}</div>
</div>
${isSelected ? '<div class="role-selection-checkmark-main">✓</div>' : ''}
`;
list.appendChild(btn);
});
}
async function renderChatProjectPanel() {
const list = document.getElementById('chat-project-list');
if (!list) return;
list.innerHTML = '<div class="chat-project-panel-loading">加载中…</div>';
try {
await ensureProjectsLoaded();
} catch (e) {
console.warn(e);
list.innerHTML = '<div class="chat-project-panel-empty">加载失败,请稍后重试</div>';
return;
}
renderChatProjectPanelList();
}
function closeChatProjectPanel() {
const panel = document.getElementById('chat-project-panel');
const btn = document.getElementById('chat-project-btn');
if (panel) panel.style.display = 'none';
if (btn) {
btn.classList.remove('active');
btn.setAttribute('aria-expanded', 'false');
}
}
async function toggleChatProjectPanel() {
const panel = document.getElementById('chat-project-panel');
const btn = document.getElementById('chat-project-btn');
if (!panel) return;
const isHidden = panel.style.display === 'none' || !panel.style.display;
if (!isHidden) {
closeChatProjectPanel();
return;
}
if (typeof closeRoleSelectionPanel === 'function') closeRoleSelectionPanel();
if (typeof closeAgentModePanel === 'function') closeAgentModePanel();
if (typeof closeChatReasoningPanel === 'function') closeChatReasoningPanel();
panel.style.display = 'flex';
if (btn) {
btn.classList.add('active');
btn.setAttribute('aria-expanded', 'true');
}
await renderChatProjectPanel();
}
async function selectChatProject(projectId) {
closeChatProjectPanel();
await applyChatProjectSelection(projectId || '');
}
async function applyChatProjectSelection(projectId) {
const prev = getChatProjectSelection();
if (projectId === prev) {
updateChatProjectButtonLabel();
return;
}
if (window.currentConversationId) {
try {
const res = await apiFetch(`/api/conversations/${encodeURIComponent(window.currentConversationId)}/project`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ projectId }),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
throw new Error(err.error || res.statusText);
}
window._loadedConversationProjectId = projectId;
if (typeof showNotification === 'function') {
showNotification(projectId ? '已绑定项目' : '已解除项目绑定', 'success');
}
} catch (e) {
console.error(e);
alert('更新项目绑定失败: ' + (e.message || e));
updateChatProjectButtonLabel();
return;
}
} else {
setActiveProjectId(projectId);
}
updateChatProjectButtonLabel();
}
/** 对话页项目选择器:同步按钮文案;若浮层已打开则刷新列表 */
async function refreshChatProjectSelector() {
if (!document.getElementById('chat-project-btn')) return;
try {
await ensureProjectsLoaded();
} catch (e) {
console.warn(e);
}
updateChatProjectButtonLabel();
const panel = document.getElementById('chat-project-panel');
if (panel && panel.style.display === 'flex') {
renderChatProjectPanelList();
}
}
async function onChatProjectChange() {
/* 兼容旧调用;新 UI 使用 selectChatProject */
await applyChatProjectSelection(getChatProjectSelection());
}
function initChatProjectSelector() {
if (window._chatProjectSelectorInited) return;
window._chatProjectSelectorInited = true;
prefetchProjectsForChat();
updateChatProjectButtonLabel();
document.addEventListener('click', (e) => {
const panel = document.getElementById('chat-project-panel');
const wrapper = document.querySelector('.project-selector-wrapper');
if (!panel || panel.style.display === 'none' || !panel.style.display) return;
if (!wrapper?.contains(e.target)) {
closeChatProjectPanel();
}
});
}
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', initChatProjectSelector);
} else {
initChatProjectSelector();
}
window.initProjectsPage = initProjectsPage;
window.showNewProjectModal = showNewProjectModal;
window.saveProjectModal = saveProjectModal;
window.closeProjectModal = closeProjectModal;
window.selectProject = selectProject;
window.switchProjectTab = switchProjectTab;
window.showAddFactModal = showAddFactModal;
window.showEditFactModal = showEditFactModal;
window.editFactFromDetail = editFactFromDetail;
window.saveFactModal = saveFactModal;
window.closeFactModal = closeFactModal;
window.closeFactDetailModal = closeFactDetailModal;
window.saveProjectSettings = saveProjectSettings;
window.archiveCurrentProject = archiveCurrentProject;
window.deleteCurrentProject = deleteCurrentProject;
window.refreshChatProjectSelector = refreshChatProjectSelector;
window.onChatProjectChange = onChatProjectChange;
window.toggleChatProjectPanel = toggleChatProjectPanel;
window.closeChatProjectPanel = closeChatProjectPanel;
window.selectChatProject = selectChatProject;
window.prefetchProjectsForChat = prefetchProjectsForChat;
window.getActiveProjectId = getActiveProjectId;
window.getProjectName = getProjectName;
window.viewProjectFactBody = viewProjectFactBody;
window.deprecateProjectFactByKey = deprecateProjectFactByKey;
window.openVulnerabilitiesForProject = openVulnerabilitiesForProject;
window.openVulnerabilityDetail = openVulnerabilityDetail;
window.filterProjectsList = filterProjectsList;
window.rebuildProjectNameMap = rebuildProjectNameMap;
window.projectNameById = projectNameById;
+26 -15
View File
@@ -244,30 +244,46 @@ function selectRole(roleName) {
renderRoleSelectionSidebar(); // 重新渲染以更新选中状态 renderRoleSelectionSidebar(); // 重新渲染以更新选中状态
} }
function getChatRoleSelectorWrapper() {
return document.getElementById('role-selector-wrapper')
|| document.getElementById('role-selector-btn')?.closest('.role-selector-wrapper:not(.project-selector-wrapper)');
}
function isRoleSelectionPanelOpen() {
const panel = document.getElementById('role-selection-panel');
if (!panel) return false;
return panel.style.display !== 'none' && panel.style.display !== '';
}
// 切换角色选择面板显示/隐藏 // 切换角色选择面板显示/隐藏
function toggleRoleSelectionPanel() { function toggleRoleSelectionPanel() {
const panel = document.getElementById('role-selection-panel'); const panel = document.getElementById('role-selection-panel');
const roleSelectorBtn = document.getElementById('role-selector-btn'); const roleSelectorBtn = document.getElementById('role-selector-btn');
if (!panel) return; if (!panel) return;
const isHidden = panel.style.display === 'none' || !panel.style.display; const isHidden = !isRoleSelectionPanelOpen();
if (isHidden) { if (isHidden) {
if (typeof closeAgentModePanel === 'function') { if (typeof closeAgentModePanel === 'function') {
closeAgentModePanel(); closeAgentModePanel();
} }
if (typeof closeChatProjectPanel === 'function') {
closeChatProjectPanel();
}
if (typeof closeChatReasoningPanel === 'function') { if (typeof closeChatReasoningPanel === 'function') {
closeChatReasoningPanel(); closeChatReasoningPanel();
} }
renderRoleSelectionSidebar();
panel.style.display = 'flex'; // 使用flex布局 panel.style.display = 'flex'; // 使用flex布局
// 添加打开状态的视觉反馈 // 添加打开状态的视觉反馈
if (roleSelectorBtn) { if (roleSelectorBtn) {
roleSelectorBtn.classList.add('active'); roleSelectorBtn.classList.add('active');
roleSelectorBtn.setAttribute('aria-expanded', 'true');
} }
// 确保面板渲染后再检查位置 // 确保面板渲染后再检查位置
setTimeout(() => { setTimeout(() => {
const wrapper = document.querySelector('.role-selector-wrapper'); const wrapper = getChatRoleSelectorWrapper();
if (wrapper) { if (wrapper) {
const rect = wrapper.getBoundingClientRect(); const rect = wrapper.getBoundingClientRect();
const panelHeight = panel.offsetHeight || 400; const panelHeight = panel.offsetHeight || 400;
@@ -281,11 +297,7 @@ function toggleRoleSelectionPanel() {
} }
}, 10); }, 10);
} else { } else {
panel.style.display = 'none'; closeRoleSelectionPanel();
// 移除打开状态的视觉反馈
if (roleSelectorBtn) {
roleSelectorBtn.classList.remove('active');
}
} }
} }
@@ -298,6 +310,7 @@ function closeRoleSelectionPanel() {
} }
if (roleSelectorBtn) { if (roleSelectorBtn) {
roleSelectorBtn.classList.remove('active'); roleSelectorBtn.classList.remove('active');
roleSelectorBtn.setAttribute('aria-expanded', 'false');
} }
} }
@@ -1568,9 +1581,9 @@ async function deleteRole(roleName) {
} }
// 在页面切换时初始化角色列表 // 在页面切换时初始化角色列表
if (typeof switchPage === 'function') { if (typeof window.switchPage === 'function') {
const originalSwitchPage = switchPage; const originalSwitchPage = window.switchPage;
switchPage = function(page) { window.switchPage = function(page) {
originalSwitchPage(page); originalSwitchPage(page);
if (page === 'roles-management') { if (page === 'roles-management') {
loadRoles().then(() => renderRolesList()); loadRoles().then(() => renderRolesList());
@@ -1590,11 +1603,9 @@ document.addEventListener('click', (e) => {
closeRoleModal(); closeRoleModal();
} }
// 点击角色选择面板外部关闭面板(但不包括角色选择按钮和面板本身 // 点击角色选择面板外部关闭(须用 #role-selector-wrapper,勿用 .role-selector-wrapper:项目选择器也带该类
const roleSelectionPanel = document.getElementById('role-selection-panel'); if (isRoleSelectionPanelOpen()) {
const roleSelectorWrapper = document.querySelector('.role-selector-wrapper'); const roleSelectorWrapper = getChatRoleSelectorWrapper();
if (roleSelectionPanel && roleSelectionPanel.style.display !== 'none' && roleSelectionPanel.style.display) {
// 检查点击是否在面板或包装器上
if (!roleSelectorWrapper?.contains(e.target)) { if (!roleSelectorWrapper?.contains(e.target)) {
closeRoleSelectionPanel(); closeRoleSelectionPanel();
} }
+70 -7
View File
@@ -25,6 +25,13 @@ function scheduleChatConversationFromHash(delayMs) {
} }
const params = new URLSearchParams(hashParts.slice(1).join('?')); const params = new URLSearchParams(hashParts.slice(1).join('?'));
const conversationId = params.get('conversation'); const conversationId = params.get('conversation');
const projectId = params.get('project');
if (projectId && typeof setActiveProjectId === 'function') {
setActiveProjectId(projectId);
if (typeof refreshChatProjectSelector === 'function') {
refreshChatProjectSelector();
}
}
if (!conversationId) { if (!conversationId) {
return; return;
} }
@@ -50,7 +57,7 @@ function initRouter() {
if (hash) { if (hash) {
const hashParts = hash.split('?'); const hashParts = hash.split('?');
const pageId = hashParts[0]; const pageId = hashParts[0];
if (pageId && ['dashboard', 'chat', 'hitl', 'info-collect', 'vulnerabilities', 'webshell', 'chat-files', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'roles-management', 'skills-monitor', 'skills-management', 'agents-management', 'settings', 'tasks', 'c2', 'c2-listeners', 'c2-sessions', 'c2-tasks', 'c2-payloads', 'c2-events', 'c2-profiles'].includes(pageId)) { if (pageId && ['dashboard', 'chat', 'hitl', 'info-collect', 'projects', 'vulnerabilities', 'webshell', 'chat-files', 'mcp-monitor', 'mcp-management', 'knowledge-management', 'knowledge-retrieval-logs', 'roles-management', 'skills-monitor', 'skills-management', 'agents-management', 'settings', 'tasks', 'c2', 'c2-listeners', 'c2-sessions', 'c2-tasks', 'c2-payloads', 'c2-events', 'c2-profiles'].includes(pageId)) {
switchPage(pageId); switchPage(pageId);
if (pageId === 'chat') { if (pageId === 'chat') {
scheduleChatConversationFromHash(500); scheduleChatConversationFromHash(500);
@@ -187,6 +194,24 @@ function updateNavState(pageId) {
} }
} }
/** 读取侧栏子菜单项(仅 .nav-submenu 内,避免误匹配) */
function getNavSubmenuItems(navItem) {
if (!navItem) return [];
const submenu = navItem.querySelector('.nav-submenu');
if (!submenu) return [];
return Array.from(submenu.querySelectorAll('.nav-submenu-item'));
}
/** 仅一个子页时直接进入,避免展开后菜单在侧栏底部不可见 */
function navigateSingleSubmenuPage(navItem) {
const items = getNavSubmenuItems(navItem);
if (items.length !== 1) return false;
const pageId = items[0].getAttribute('data-page');
if (!pageId) return false;
switchPage(pageId);
return true;
}
// 切换子菜单 // 切换子菜单
function toggleSubmenu(menuId) { function toggleSubmenu(menuId) {
const sidebar = document.getElementById('main-sidebar'); const sidebar = document.getElementById('main-sidebar');
@@ -194,24 +219,50 @@ function toggleSubmenu(menuId) {
if (!navItem) return; if (!navItem) return;
const collapsed = sidebar && sidebar.classList.contains('collapsed');
// 检查侧边栏是否折叠 // 检查侧边栏是否折叠
if (sidebar && sidebar.classList.contains('collapsed')) { if (collapsed) {
// 折叠状态下显示弹出菜单 // 折叠状态下显示弹出菜单
showSubmenuPopup(navItem, menuId); showSubmenuPopup(navItem, menuId);
} else { return;
// 展开状态下正常切换子菜单 }
navItem.classList.toggle('expanded');
// 展开侧栏且仅一个子项(角色、Agents 等):单击直接进入,无需再点二级菜单
if (navigateSingleSubmenuPage(navItem)) {
return;
}
// 展开状态下切换子菜单,并滚入视口以便看到子项
const willExpand = !navItem.classList.contains('expanded');
navItem.classList.toggle('expanded');
if (willExpand) {
requestAnimationFrame(() => {
navItem.scrollIntoView({ block: 'nearest', behavior: 'smooth' });
const items = getNavSubmenuItems(navItem);
const last = items[items.length - 1];
if (last) {
last.scrollIntoView({ block: 'nearest', behavior: 'smooth' });
}
});
} }
} }
window.toggleSubmenu = toggleSubmenu; window.toggleSubmenu = toggleSubmenu;
// 显示子菜单弹出框 // 显示子菜单弹出框
function showSubmenuPopup(navItem, menuId) { function showSubmenuPopup(navItem, menuId) {
// 移除其他已打开的弹出菜单
const existingPopup = document.querySelector('.submenu-popup'); const existingPopup = document.querySelector('.submenu-popup');
if (existingPopup) { if (existingPopup) {
const sameMenu = existingPopup.dataset.menuId === menuId;
existingPopup.remove(); existingPopup.remove();
return; // 如果已经打开,点击时关闭 // 再次点击同一项:仅关闭;点击另一项:继续打开新菜单
if (sameMenu) {
return;
}
}
if (navigateSingleSubmenuPage(navItem)) {
return;
} }
const navItemContent = navItem.querySelector('.nav-item-content'); const navItemContent = navItem.querySelector('.nav-item-content');
@@ -225,6 +276,7 @@ function showSubmenuPopup(navItem, menuId) {
// 创建弹出菜单 // 创建弹出菜单
const popup = document.createElement('div'); const popup = document.createElement('div');
popup.className = 'submenu-popup'; popup.className = 'submenu-popup';
popup.dataset.menuId = menuId;
popup.style.position = 'fixed'; popup.style.position = 'fixed';
popup.style.left = (rect.right + 8) + 'px'; popup.style.left = (rect.right + 8) + 'px';
popup.style.top = rect.top + 'px'; popup.style.top = rect.top + 'px';
@@ -289,6 +341,12 @@ async function initPage(pageId) {
case 'chat': case 'chat':
// 恢复对话列表折叠状态(从其他页返回时保持用户选择) // 恢复对话列表折叠状态(从其他页返回时保持用户选择)
initConversationSidebarState(); initConversationSidebarState();
if (typeof prefetchProjectsForChat === 'function') {
prefetchProjectsForChat();
}
if (typeof refreshChatProjectSelector === 'function') {
refreshChatProjectSelector();
}
break; break;
case 'hitl': case 'hitl':
if (typeof refreshHitlPending === 'function') { if (typeof refreshHitlPending === 'function') {
@@ -348,6 +406,11 @@ async function initPage(pageId) {
}); });
} }
break; break;
case 'projects':
if (typeof initProjectsPage === 'function') {
initProjectsPage();
}
break;
case 'vulnerabilities': case 'vulnerabilities':
// 初始化漏洞管理页面 // 初始化漏洞管理页面
if (typeof initVulnerabilityPage === 'function') { if (typeof initVulnerabilityPage === 'function') {
+67 -7
View File
@@ -31,6 +31,19 @@ let toolsPagination = {
let c2NavSyncedOnce = false; let c2NavSyncedOnce = false;
/** 根据是否启用多代理,禁用/启用机器人模式中的 Eino 编排选项 */
function syncRobotAgentModeSelectOptions(multiEnabled) {
const sel = document.getElementById('multi-agent-robot-mode');
if (!sel) return;
['deep', 'plan_execute', 'supervisor'].forEach(function (v) {
const opt = sel.querySelector('option[value="' + v + '"]');
if (opt) opt.disabled = !multiEnabled;
});
if (!multiEnabled && ['deep', 'plan_execute', 'supervisor'].indexOf(sel.value) >= 0) {
sel.value = 'react';
}
}
/** 首次进入仪表盘等页面前拉一次配置,隐藏侧栏 C2(避免禁用后仍显示) */ /** 首次进入仪表盘等页面前拉一次配置,隐藏侧栏 C2(避免禁用后仍显示) */
window.syncC2NavOnceFromServer = async function syncC2NavOnceFromServer() { window.syncC2NavOnceFromServer = async function syncC2NavOnceFromServer() {
if (c2NavSyncedOnce || typeof apiFetch === 'undefined') { if (c2NavSyncedOnce || typeof apiFetch === 'undefined') {
@@ -87,6 +100,9 @@ function switchSettingsSection(section) {
if (section === 'terminal' && typeof initTerminal === 'function') { if (section === 'terminal' && typeof initTerminal === 'function') {
setTimeout(initTerminal, 0); setTimeout(initTerminal, 0);
} }
if (section === 'audit' && typeof initAuditLogsSection === 'function') {
setTimeout(initAuditLogsSection, 0);
}
} }
// 打开设置 // 打开设置
@@ -168,7 +184,7 @@ async function loadConfig(loadTools = true) {
const orEffEl = document.getElementById('openai-reasoning-effort'); const orEffEl = document.getElementById('openai-reasoning-effort');
if (orEffEl) { if (orEffEl) {
const ev = (orm.effort || '').toString().trim().toLowerCase(); const ev = (orm.effort || '').toString().trim().toLowerCase();
orEffEl.value = ['', 'low', 'medium', 'high', 'max'].includes(ev) ? ev : ''; orEffEl.value = ['', 'low', 'medium', 'high', 'max', 'xhigh'].includes(ev) ? ev : '';
} }
const orProfEl = document.getElementById('openai-reasoning-profile'); const orProfEl = document.getElementById('openai-reasoning-profile');
if (orProfEl) { if (orProfEl) {
@@ -195,14 +211,27 @@ async function loadConfig(loadTools = true) {
const ma = currentConfig.multi_agent || {}; const ma = currentConfig.multi_agent || {};
const maEn = document.getElementById('multi-agent-enabled'); const maEn = document.getElementById('multi-agent-enabled');
if (maEn) maEn.checked = ma.enabled === true; if (maEn) {
maEn.checked = ma.enabled === true;
if (!maEn.dataset.robotModeBound) {
maEn.dataset.robotModeBound = '1';
maEn.addEventListener('change', function () {
syncRobotAgentModeSelectOptions(maEn.checked);
});
}
}
const maPeLoop = document.getElementById('multi-agent-pe-loop'); const maPeLoop = document.getElementById('multi-agent-pe-loop');
if (maPeLoop) { if (maPeLoop) {
const v = ma.plan_execute_loop_max_iterations; const v = ma.plan_execute_loop_max_iterations;
maPeLoop.value = (v !== undefined && v !== null && !Number.isNaN(Number(v))) ? String(Number(v)) : '0'; maPeLoop.value = (v !== undefined && v !== null && !Number.isNaN(Number(v))) ? String(Number(v)) : '0';
} }
const maRobot = document.getElementById('multi-agent-robot-use'); const maRobotMode = document.getElementById('multi-agent-robot-mode');
if (maRobot) maRobot.checked = ma.robot_use_multi_agent === true; if (maRobotMode) {
let mode = (ma.robot_default_agent_mode || 'react').trim().toLowerCase();
if (mode === 'single') mode = 'react';
maRobotMode.value = mode;
syncRobotAgentModeSelectOptions(ma.enabled === true);
}
// 填充知识库配置 // 填充知识库配置
const knowledgeEnabledCheckbox = document.getElementById('knowledge-enabled'); const knowledgeEnabledCheckbox = document.getElementById('knowledge-enabled');
@@ -339,9 +368,23 @@ async function loadConfig(loadTools = true) {
// 填充机器人配置 // 填充机器人配置
const robots = currentConfig.robots || {}; const robots = currentConfig.robots || {};
const wechat = robots.wechat || {};
const wecom = robots.wecom || {}; const wecom = robots.wecom || {};
const dingtalk = robots.dingtalk || {}; const dingtalk = robots.dingtalk || {};
const lark = robots.lark || {}; const lark = robots.lark || {};
const wechatEnabled = document.getElementById('robot-wechat-enabled');
if (wechatEnabled) wechatEnabled.checked = wechat.enabled === true;
const wechatBase = document.getElementById('robot-wechat-base-url');
if (wechatBase) wechatBase.value = wechat.base_url || 'https://ilinkai.weixin.qq.com';
const wechatBotType = document.getElementById('robot-wechat-bot-type');
if (wechatBotType) wechatBotType.value = wechat.bot_type || '3';
const wechatBotAgent = document.getElementById('robot-wechat-bot-agent');
if (wechatBotAgent) wechatBotAgent.value = wechat.bot_agent || 'CyberStrikeAI/1.0';
const wechatBotId = document.getElementById('robot-wechat-ilink-bot-id');
if (wechatBotId) wechatBotId.value = wechat.ilink_bot_id || '';
if (typeof refreshWechatRobotBoundUI === 'function') {
refreshWechatRobotBoundUI({ ...wechat, bound: !!(wechat.bot_token && wechat.ilink_bot_id) });
}
const wecomEnabled = document.getElementById('robot-wecom-enabled'); const wecomEnabled = document.getElementById('robot-wecom-enabled');
if (wecomEnabled) wecomEnabled.checked = wecom.enabled === true; if (wecomEnabled) wecomEnabled.checked = wecom.enabled === true;
const wecomToken = document.getElementById('robot-wecom-token'); const wecomToken = document.getElementById('robot-wecom-token');
@@ -1116,9 +1159,14 @@ async function applySettings() {
const peRaw = document.getElementById('multi-agent-pe-loop')?.value; const peRaw = document.getElementById('multi-agent-pe-loop')?.value;
const peParsed = parseInt(peRaw, 10); const peParsed = parseInt(peRaw, 10);
const peLoop = Number.isNaN(peParsed) ? 0 : Math.max(0, peParsed); const peLoop = Number.isNaN(peParsed) ? 0 : Math.max(0, peParsed);
const maEnabled = document.getElementById('multi-agent-enabled')?.checked === true;
let robotMode = document.getElementById('multi-agent-robot-mode')?.value || 'react';
if (!maEnabled && ['deep', 'plan_execute', 'supervisor'].indexOf(robotMode) >= 0) {
robotMode = 'react';
}
return { return {
enabled: document.getElementById('multi-agent-enabled')?.checked === true, enabled: maEnabled,
robot_use_multi_agent: document.getElementById('multi-agent-robot-use')?.checked === true, robot_default_agent_mode: robotMode,
batch_use_multi_agent: currentConfig?.multi_agent?.batch_use_multi_agent === true, batch_use_multi_agent: currentConfig?.multi_agent?.batch_use_multi_agent === true,
plan_execute_loop_max_iterations: peLoop plan_execute_loop_max_iterations: peLoop
}; };
@@ -1129,6 +1177,18 @@ async function applySettings() {
}, },
robots: { robots: {
...(prevRobots.session && typeof prevRobots.session === 'object' ? { session: prevRobots.session } : {}), ...(prevRobots.session && typeof prevRobots.session === 'object' ? { session: prevRobots.session } : {}),
wechat: {
enabled: document.getElementById('robot-wechat-enabled')?.checked === true,
base_url: document.getElementById('robot-wechat-base-url')?.value.trim() || 'https://ilinkai.weixin.qq.com',
bot_type: document.getElementById('robot-wechat-bot-type')?.value.trim() || '3',
bot_agent: document.getElementById('robot-wechat-bot-agent')?.value.trim() || 'CyberStrikeAI/1.0',
ilink_bot_id: document.getElementById('robot-wechat-ilink-bot-id')?.value.trim() || (prevRobots.wechat && prevRobots.wechat.ilink_bot_id) || '',
...(prevRobots.wechat && typeof prevRobots.wechat === 'object' ? {
bot_token: prevRobots.wechat.bot_token || '',
ilink_user_id: prevRobots.wechat.ilink_user_id || '',
get_updates_buf: prevRobots.wechat.get_updates_buf || ''
} : {})
},
wecom: { wecom: {
enabled: document.getElementById('robot-wecom-enabled')?.checked === true, enabled: document.getElementById('robot-wecom-enabled')?.checked === true,
token: document.getElementById('robot-wecom-token')?.value.trim() || '', token: document.getElementById('robot-wecom-token')?.value.trim() || '',
@@ -1355,7 +1415,7 @@ async function saveToolsConfig() {
agent: currentConfig.agent || {}, agent: currentConfig.agent || {},
multi_agent: { multi_agent: {
enabled: currentConfig?.multi_agent?.enabled === true, enabled: currentConfig?.multi_agent?.enabled === true,
robot_use_multi_agent: currentConfig?.multi_agent?.robot_use_multi_agent === true, robot_default_agent_mode: currentConfig?.multi_agent?.robot_default_agent_mode || 'react',
batch_use_multi_agent: currentConfig?.multi_agent?.batch_use_multi_agent === true, batch_use_multi_agent: currentConfig?.multi_agent?.batch_use_multi_agent === true,
plan_execute_loop_max_iterations: Number(currentConfig?.multi_agent?.plan_execute_loop_max_iterations || 0), plan_execute_loop_max_iterations: Number(currentConfig?.multi_agent?.plan_execute_loop_max_iterations || 0),
tool_search_always_visible_tools: Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleBuiltinToolNames.has(name)) tool_search_always_visible_tools: Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleBuiltinToolNames.has(name))
+10 -1
View File
@@ -979,7 +979,16 @@ async function createBatchQueue() {
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
body: JSON.stringify({ title, tasks, role, agentMode, scheduleMode, cronExpr, executeNow }), body: JSON.stringify({
title,
tasks,
role,
agentMode,
scheduleMode,
cronExpr,
executeNow,
projectId: typeof getActiveProjectId === 'function' ? getActiveProjectId() || '' : '',
}),
}); });
if (!response.ok) { if (!response.ok) {
+339 -82
View File
@@ -46,7 +46,9 @@ const getVulnerabilityPageSize = () => {
let currentVulnerabilityId = null; let currentVulnerabilityId = null;
let vulnerabilityFilters = { let vulnerabilityFilters = {
q: '',
id: '', id: '',
project_id: '',
conversation_id: '', conversation_id: '',
task_id: '', task_id: '',
conversation_tag: '', conversation_tag: '',
@@ -65,14 +67,18 @@ const VULN_STAT_SEVERITIES = ['critical', 'high', 'medium', 'low', 'info'];
let vulnerabilityStatCardsBound = false; let vulnerabilityStatCardsBound = false;
let vulnerabilityFilterPanelBound = false; let vulnerabilityFilterPanelBound = false;
let vulnerabilityFilterOptionsCache = null; let vulnerabilityFilterOptionsCache = null;
const VULNERABILITY_ADVANCED_OPEN_KEY = 'vulnerabilityAdvancedFiltersOpen'; let vulnerabilityMoreFiltersPopoverOpen = false;
let vulnerabilityFilterDebounceTimer = null;
const VULNERABILITY_FILTER_DEBOUNCE_MS = 400;
const VULNERABILITY_DATALIST_MAX = 8; const VULNERABILITY_DATALIST_MAX = 8;
const VULNERABILITY_DATALIST_MIN_QUERY = 2; const VULNERABILITY_DATALIST_MIN_QUERY = 2;
const VULN_FILTER_CHIP_FIELDS = [ const VULN_FILTER_CHIP_FIELDS = [
{ key: 'q', labelKey: 'vulnerabilityPage.searchKeywordShort' },
{ key: 'id', labelKey: 'vulnerabilityPage.vulnId' }, { key: 'id', labelKey: 'vulnerabilityPage.vulnId' },
{ key: 'status', labelKey: null, format: 'status' }, { key: 'status', labelKey: null, format: 'status' },
{ key: 'severity', labelKey: null, format: 'severity' }, { key: 'severity', labelKey: null, format: 'severity' },
{ key: 'project_id', labelKey: null },
{ key: 'conversation_id', labelKey: 'vulnerabilityPage.conversationId' }, { key: 'conversation_id', labelKey: 'vulnerabilityPage.conversationId' },
{ key: 'task_id', labelKey: 'vulnerabilityPage.taskOrQueueId' }, { key: 'task_id', labelKey: 'vulnerabilityPage.taskOrQueueId' },
{ key: 'conversation_tag', labelKey: 'vulnerabilityPage.conversationTag' }, { key: 'conversation_tag', labelKey: 'vulnerabilityPage.conversationTag' },
@@ -94,35 +100,51 @@ function syncVulnerabilityFiltersFromLocationHash() {
const st = (params.get('status') || '').trim(); const st = (params.get('status') || '').trim();
const convTag = (params.get('conversation_tag') || '').trim(); const convTag = (params.get('conversation_tag') || '').trim();
const taskTag = (params.get('task_tag') || '').trim(); const taskTag = (params.get('task_tag') || '').trim();
if (!vid && !cid && !tid && !sev && !st && !convTag && !taskTag) { const pid = (params.get('project_id') || '').trim();
const q = (params.get('q') || params.get('search') || '').trim();
if (!vid && !cid && !tid && !sev && !st && !convTag && !taskTag && !q && !pid) {
return; return;
} }
vulnerabilityFilters.q = '';
vulnerabilityFilters.id = ''; vulnerabilityFilters.id = '';
vulnerabilityFilters.project_id = '';
vulnerabilityFilters.conversation_id = ''; vulnerabilityFilters.conversation_id = '';
vulnerabilityFilters.task_id = ''; vulnerabilityFilters.task_id = '';
vulnerabilityFilters.conversation_tag = ''; vulnerabilityFilters.conversation_tag = '';
vulnerabilityFilters.task_tag = ''; vulnerabilityFilters.task_tag = '';
vulnerabilityFilters.severity = ''; vulnerabilityFilters.severity = '';
vulnerabilityFilters.status = ''; vulnerabilityFilters.status = '';
const idEl = document.getElementById('vulnerability-id-filter'); const searchEl = document.getElementById('vulnerability-search-filter');
const exactIdEl = document.getElementById('vulnerability-exact-id-filter');
const convEl = document.getElementById('vulnerability-conversation-filter'); const convEl = document.getElementById('vulnerability-conversation-filter');
const taskEl = document.getElementById('vulnerability-task-filter'); const taskEl = document.getElementById('vulnerability-task-filter');
const convTagEl = document.getElementById('vulnerability-conversation-tag-filter'); const convTagEl = document.getElementById('vulnerability-conversation-tag-filter');
const taskTagEl = document.getElementById('vulnerability-task-tag-filter'); const taskTagEl = document.getElementById('vulnerability-task-tag-filter');
const projEl = document.getElementById('vulnerability-project-filter');
const sevEl = document.getElementById('vulnerability-severity-filter'); const sevEl = document.getElementById('vulnerability-severity-filter');
const stEl = document.getElementById('vulnerability-status-filter'); const stEl = document.getElementById('vulnerability-status-filter');
if (idEl) idEl.value = ''; if (searchEl) searchEl.value = '';
if (exactIdEl) exactIdEl.value = '';
if (convEl) convEl.value = ''; if (convEl) convEl.value = '';
if (taskEl) taskEl.value = ''; if (taskEl) taskEl.value = '';
if (convTagEl) convTagEl.value = ''; if (convTagEl) convTagEl.value = '';
if (taskTagEl) taskTagEl.value = ''; if (taskTagEl) taskTagEl.value = '';
if (projEl) projEl.value = '';
if (sevEl) sevEl.value = ''; if (sevEl) sevEl.value = '';
if (stEl) stEl.value = ''; if (stEl) stEl.value = '';
if (q) {
vulnerabilityFilters.q = q;
if (searchEl) searchEl.value = q;
}
if (pid) {
vulnerabilityFilters.project_id = pid;
if (projEl) projEl.value = pid;
}
if (vid) { if (vid) {
vulnerabilityFilters.id = vid; vulnerabilityFilters.id = vid;
if (idEl) idEl.value = vid; if (exactIdEl) exactIdEl.value = vid;
} }
if (cid) { if (cid) {
vulnerabilityFilters.conversation_id = cid; vulnerabilityFilters.conversation_id = cid;
@@ -149,21 +171,19 @@ function syncVulnerabilityFiltersFromLocationHash() {
if (stEl) stEl.value = st; if (stEl) stEl.value = st;
} }
vulnerabilityPagination.currentPage = 1; vulnerabilityPagination.currentPage = 1;
if (hasVulnerabilityAdvancedFiltersActive()) {
setVulnerabilityAdvancedFiltersOpen(true, false);
}
syncVulnerabilityStatCardActiveState(); syncVulnerabilityStatCardActiveState();
updateVulnerabilityFilterPanelState(); updateVulnerabilityFilterPanelState();
renderVulnerabilityFilterChips(); renderVulnerabilityFilterChips();
} }
// 初始化漏洞管理页面 // 初始化漏洞管理页面
function initVulnerabilityPage() { async function initVulnerabilityPage() {
// 从localStorage加载每页条数设置 // 从localStorage加载每页条数设置
vulnerabilityPagination.pageSize = getVulnerabilityPageSize(); vulnerabilityPagination.pageSize = getVulnerabilityPageSize();
initVulnerabilityStatCards(); initVulnerabilityStatCards();
initVulnerabilityFilterPanel(); initVulnerabilityFilterPanel();
syncVulnerabilityFiltersFromLocationHash(); syncVulnerabilityFiltersFromLocationHash();
await refreshVulnerabilityProjectFilter();
updateVulnerabilityFilterPanelState(); updateVulnerabilityFilterPanelState();
renderVulnerabilityFilterChips(); renderVulnerabilityFilterChips();
loadVulnerabilityFilterOptions(); loadVulnerabilityFilterOptions();
@@ -213,7 +233,9 @@ function applyVulnerabilitySeverityFilter(severity) {
} }
function readVulnerabilityFiltersFromForm() { function readVulnerabilityFiltersFromForm() {
vulnerabilityFilters.id = (document.getElementById('vulnerability-id-filter')?.value || '').trim(); vulnerabilityFilters.q = (document.getElementById('vulnerability-search-filter')?.value || '').trim();
vulnerabilityFilters.id = (document.getElementById('vulnerability-exact-id-filter')?.value || '').trim();
vulnerabilityFilters.project_id = (document.getElementById('vulnerability-project-filter')?.value || '').trim();
vulnerabilityFilters.conversation_id = (document.getElementById('vulnerability-conversation-filter')?.value || '').trim(); vulnerabilityFilters.conversation_id = (document.getElementById('vulnerability-conversation-filter')?.value || '').trim();
vulnerabilityFilters.task_id = (document.getElementById('vulnerability-task-filter')?.value || '').trim(); vulnerabilityFilters.task_id = (document.getElementById('vulnerability-task-filter')?.value || '').trim();
vulnerabilityFilters.conversation_tag = (document.getElementById('vulnerability-conversation-tag-filter')?.value || '').trim(); vulnerabilityFilters.conversation_tag = (document.getElementById('vulnerability-conversation-tag-filter')?.value || '').trim();
@@ -225,13 +247,13 @@ function readVulnerabilityFiltersFromForm() {
function hasVulnerabilityAdvancedFiltersActive() { function hasVulnerabilityAdvancedFiltersActive() {
const f = vulnerabilityFilters; const f = vulnerabilityFilters;
return Boolean(f.conversation_id || f.task_id || f.conversation_tag || f.task_tag); return Boolean(f.id || f.conversation_id || f.task_id || f.conversation_tag || f.task_tag);
} }
function hasAnyVulnerabilityFilterActive() { function hasAnyVulnerabilityFilterActive() {
const f = vulnerabilityFilters; const f = vulnerabilityFilters;
return Boolean( return Boolean(
f.id || f.conversation_id || f.task_id || f.conversation_tag || f.task_tag || f.severity || f.status f.q || f.id || f.project_id || f.conversation_id || f.task_id || f.conversation_tag || f.task_tag || f.severity || f.status
); );
} }
@@ -253,7 +275,9 @@ function updateVulnerabilityLocationHashFromFilters() {
const params = new URLSearchParams(hashParts.length >= 2 ? hashParts.slice(1).join('?') : ''); const params = new URLSearchParams(hashParts.length >= 2 ? hashParts.slice(1).join('?') : '');
const f = vulnerabilityFilters; const f = vulnerabilityFilters;
const pairs = [ const pairs = [
['q', f.q],
['id', f.id], ['id', f.id],
['project_id', f.project_id],
['conversation_id', f.conversation_id], ['conversation_id', f.conversation_id],
['task_id', f.task_id], ['task_id', f.task_id],
['conversation_tag', f.conversation_tag], ['conversation_tag', f.conversation_tag],
@@ -279,17 +303,82 @@ function updateVulnerabilityLocationHashFromFilters() {
} }
} }
function toggleVulnerabilityAdvancedFilters(ev) { function scheduleVulnerabilityFilterApply(immediate) {
if (vulnerabilityFilterDebounceTimer) {
clearTimeout(vulnerabilityFilterDebounceTimer);
vulnerabilityFilterDebounceTimer = null;
}
if (immediate) {
applyVulnerabilityFilters();
return;
}
vulnerabilityFilterDebounceTimer = setTimeout(function () {
vulnerabilityFilterDebounceTimer = null;
applyVulnerabilityFilters();
}, VULNERABILITY_FILTER_DEBOUNCE_MS);
}
function syncVulnerabilityAdvancedFieldsFromFilters() {
const map = {
id: 'vulnerability-exact-id-filter',
conversation_id: 'vulnerability-conversation-filter',
task_id: 'vulnerability-task-filter',
conversation_tag: 'vulnerability-conversation-tag-filter',
task_tag: 'vulnerability-task-tag-filter'
};
Object.keys(map).forEach(function (key) {
const el = document.getElementById(map[key]);
if (el) el.value = vulnerabilityFilters[key] || '';
});
}
function setVulnerabilityMoreFiltersPopoverOpen(open) {
const btn = document.getElementById('vulnerability-more-filters-btn');
const popover = document.getElementById('vulnerability-more-filters-popover');
if (!btn || !popover) return;
vulnerabilityMoreFiltersPopoverOpen = open;
btn.setAttribute('aria-expanded', open ? 'true' : 'false');
btn.classList.toggle('is-active', open);
popover.hidden = !open;
}
function toggleVulnerabilityMoreFiltersPopover(ev) {
if (ev) { if (ev) {
ev.preventDefault(); ev.preventDefault();
ev.stopPropagation(); ev.stopPropagation();
} }
const toggleBtn = document.getElementById('vulnerability-advanced-toggle'); const opening = !vulnerabilityMoreFiltersPopoverOpen;
if (!toggleBtn) return; if (opening) {
const expanded = toggleBtn.getAttribute('aria-expanded') === 'true'; readVulnerabilityFiltersFromForm();
setVulnerabilityAdvancedFiltersOpen(!expanded, true); syncVulnerabilityAdvancedFieldsFromFilters();
}
setVulnerabilityMoreFiltersPopoverOpen(opening);
}
function closeVulnerabilityMoreFiltersPopover(revertDraft) {
if (revertDraft) syncVulnerabilityAdvancedFieldsFromFilters();
setVulnerabilityMoreFiltersPopoverOpen(false);
}
function clearVulnerabilityAdvancedFilterFields() {
['vulnerability-exact-id-filter', 'vulnerability-conversation-filter', 'vulnerability-task-filter',
'vulnerability-conversation-tag-filter', 'vulnerability-task-tag-filter'].forEach(function (id) {
const el = document.getElementById(id);
if (el) el.value = '';
});
}
function applyVulnerabilityMoreFiltersFromPopover() {
closeVulnerabilityMoreFiltersPopover(false);
scheduleVulnerabilityFilterApply(true);
}
function onVulnerabilityFilterDocumentPointerDown(ev) {
if (!vulnerabilityMoreFiltersPopoverOpen) return;
const anchor = document.querySelector('.vulnerability-more-filters-anchor');
if (anchor && anchor.contains(ev.target)) return;
closeVulnerabilityMoreFiltersPopover(true);
} }
window.toggleVulnerabilityAdvancedFilters = toggleVulnerabilityAdvancedFilters;
function initVulnerabilityFilterPanel() { function initVulnerabilityFilterPanel() {
const panel = document.getElementById('vulnerability-filter-panel'); const panel = document.getElementById('vulnerability-filter-panel');
@@ -301,55 +390,69 @@ function initVulnerabilityFilterPanel() {
} }
vulnerabilityFilterPanelBound = true; vulnerabilityFilterPanelBound = true;
let savedOpen = false;
try {
savedOpen = localStorage.getItem(VULNERABILITY_ADVANCED_OPEN_KEY) === 'true';
} catch (e) { /* ignore */ }
setVulnerabilityAdvancedFiltersOpen(savedOpen, false);
const stEl = document.getElementById('vulnerability-status-filter'); const stEl = document.getElementById('vulnerability-status-filter');
if (stEl) stEl.addEventListener('change', applyVulnerabilityFilters); if (stEl) stEl.addEventListener('change', function () { scheduleVulnerabilityFilterApply(true); });
const textIds = [ const searchEl = document.getElementById('vulnerability-search-filter');
'vulnerability-id-filter', if (searchEl) {
searchEl.addEventListener('input', function () { scheduleVulnerabilityFilterApply(false); });
searchEl.addEventListener('keydown', function (ev) {
if (ev.key === 'Enter') {
ev.preventDefault();
scheduleVulnerabilityFilterApply(true);
}
});
searchEl.addEventListener('search', function () {
if (searchEl.value === '') scheduleVulnerabilityFilterApply(true);
});
}
const moreBtn = document.getElementById('vulnerability-more-filters-btn');
if (moreBtn) moreBtn.addEventListener('click', toggleVulnerabilityMoreFiltersPopover);
const applyBtn = document.getElementById('vulnerability-more-filters-apply');
if (applyBtn) applyBtn.addEventListener('click', applyVulnerabilityMoreFiltersFromPopover);
const resetBtn = document.getElementById('vulnerability-more-filters-reset');
if (resetBtn) {
resetBtn.addEventListener('click', function () {
clearVulnerabilityAdvancedFilterFields();
applyVulnerabilityMoreFiltersFromPopover();
});
}
const advancedTextIds = [
'vulnerability-exact-id-filter',
'vulnerability-conversation-filter', 'vulnerability-conversation-filter',
'vulnerability-task-filter', 'vulnerability-task-filter',
'vulnerability-conversation-tag-filter', 'vulnerability-conversation-tag-filter',
'vulnerability-task-tag-filter' 'vulnerability-task-tag-filter'
]; ];
textIds.forEach(function (id) { advancedTextIds.forEach(function (id) {
const el = document.getElementById(id); const el = document.getElementById(id);
if (!el) return; if (!el) return;
el.addEventListener('keydown', function (ev) { el.addEventListener('keydown', function (ev) {
if (ev.key === 'Enter') { if (ev.key === 'Enter') {
ev.preventDefault(); ev.preventDefault();
applyVulnerabilityFilters(); applyVulnerabilityMoreFiltersFromPopover();
} }
}); });
}); });
document.addEventListener('mousedown', onVulnerabilityFilterDocumentPointerDown);
document.addEventListener('keydown', function (ev) {
if (ev.key === 'Escape' && vulnerabilityMoreFiltersPopoverOpen) {
closeVulnerabilityMoreFiltersPopover(true);
}
});
bindVulnerabilityFilterTypeaheads(); bindVulnerabilityFilterTypeaheads();
} }
function setVulnerabilityAdvancedFiltersOpen(open, persist) {
const toggleBtn = document.getElementById('vulnerability-advanced-toggle');
const advanced = document.getElementById('vulnerability-advanced-filters');
const wrap = document.querySelector('#vulnerability-filter-panel .vulnerability-filter-advanced-wrap');
if (!toggleBtn || !advanced) return;
toggleBtn.setAttribute('aria-expanded', open ? 'true' : 'false');
advanced.hidden = !open;
advanced.classList.toggle('is-open', open);
if (wrap) wrap.classList.toggle('is-expanded', open);
if (persist) {
try {
localStorage.setItem(VULNERABILITY_ADVANCED_OPEN_KEY, open ? 'true' : 'false');
} catch (e) { /* ignore */ }
}
}
function countVulnerabilityAdvancedFiltersActive() { function countVulnerabilityAdvancedFiltersActive() {
const f = vulnerabilityFilters; const f = vulnerabilityFilters;
let n = 0; let n = 0;
if (f.id) n++;
if (f.conversation_id) n++; if (f.conversation_id) n++;
if (f.task_id) n++; if (f.task_id) n++;
if (f.conversation_tag) n++; if (f.conversation_tag) n++;
@@ -379,11 +482,17 @@ function updateVulnerabilityFilterPanelState() {
readVulnerabilityFiltersFromForm(); readVulnerabilityFiltersFromForm();
panel.classList.toggle('is-filtered', hasAnyVulnerabilityFilterActive()); panel.classList.toggle('is-filtered', hasAnyVulnerabilityFilterActive());
updateVulnerabilityAdvancedBadge(); updateVulnerabilityAdvancedBadge();
const clearBtn = document.getElementById('vulnerability-filter-clear-btn');
if (clearBtn) clearBtn.hidden = !hasAnyVulnerabilityFilterActive();
} }
function formatVulnerabilityFilterChipValue(key, value) { function formatVulnerabilityFilterChipValue(key, value) {
if (key === 'severity') return vulnSeverityLabel(value); if (key === 'severity') return vulnSeverityLabel(value);
if (key === 'status') return vulnStatusLabel(value); if (key === 'status') return vulnStatusLabel(value);
if (key === 'project_id') {
const name = typeof getProjectName === 'function' ? getProjectName(value) : '';
return name && name !== value ? name : value;
}
return value; return value;
} }
@@ -397,7 +506,7 @@ function renderVulnerabilityFilterChips() {
VULN_FILTER_CHIP_FIELDS.forEach(function (field) { VULN_FILTER_CHIP_FIELDS.forEach(function (field) {
const val = vulnerabilityFilters[field.key]; const val = vulnerabilityFilters[field.key];
if (!val) return; if (!val) return;
const label = field.labelKey ? vulnT(field.labelKey) : ''; const label = field.labelKey ? vulnT(field.labelKey) : (field.key === 'project_id' ? '项目' : '');
const displayVal = formatVulnerabilityFilterChipValue(field.key, val); const displayVal = formatVulnerabilityFilterChipValue(field.key, val);
const text = label ? label + ': ' + displayVal : displayVal; const text = label ? label + ': ' + displayVal : displayVal;
chips.push({ key: field.key, text: text }); chips.push({ key: field.key, text: text });
@@ -431,11 +540,13 @@ function renderVulnerabilityFilterChips() {
function removeVulnerabilityFilterByKey(key) { function removeVulnerabilityFilterByKey(key) {
const map = { const map = {
id: 'vulnerability-id-filter', q: 'vulnerability-search-filter',
id: 'vulnerability-exact-id-filter',
conversation_id: 'vulnerability-conversation-filter', conversation_id: 'vulnerability-conversation-filter',
task_id: 'vulnerability-task-filter', task_id: 'vulnerability-task-filter',
conversation_tag: 'vulnerability-conversation-tag-filter', conversation_tag: 'vulnerability-conversation-tag-filter',
task_tag: 'vulnerability-task-tag-filter', task_tag: 'vulnerability-task-tag-filter',
project_id: 'vulnerability-project-filter',
severity: 'vulnerability-severity-filter', severity: 'vulnerability-severity-filter',
status: 'vulnerability-status-filter' status: 'vulnerability-status-filter'
}; };
@@ -609,13 +720,7 @@ async function loadVulnerabilityStats() {
throw new Error('apiFetch未定义'); throw new Error('apiFetch未定义');
} }
const params = new URLSearchParams(); const params = buildVulnerabilityFilterParams();
if (vulnerabilityFilters.conversation_id) {
params.append('conversation_id', vulnerabilityFilters.conversation_id);
}
if (vulnerabilityFilters.task_id) {
params.append('task_id', vulnerabilityFilters.task_id);
}
const response = await apiFetch(`/api/vulnerabilities/stats?${params.toString()}`); const response = await apiFetch(`/api/vulnerabilities/stats?${params.toString()}`);
if (!response.ok) { if (!response.ok) {
@@ -689,31 +794,9 @@ async function loadVulnerabilities(page = null) {
vulnerabilityPagination.currentPage = page; vulnerabilityPagination.currentPage = page;
} }
const params = new URLSearchParams(); const params = buildVulnerabilityFilterParams();
params.append('page', vulnerabilityPagination.currentPage.toString()); params.append('page', vulnerabilityPagination.currentPage.toString());
params.append('limit', vulnerabilityPagination.pageSize.toString()); params.append('limit', vulnerabilityPagination.pageSize.toString());
if (vulnerabilityFilters.id) {
params.append('id', vulnerabilityFilters.id);
}
if (vulnerabilityFilters.conversation_id) {
params.append('conversation_id', vulnerabilityFilters.conversation_id);
}
if (vulnerabilityFilters.task_id) {
params.append('task_id', vulnerabilityFilters.task_id);
}
if (vulnerabilityFilters.conversation_tag) {
params.append('conversation_tag', vulnerabilityFilters.conversation_tag);
}
if (vulnerabilityFilters.task_tag) {
params.append('task_tag', vulnerabilityFilters.task_tag);
}
if (vulnerabilityFilters.severity) {
params.append('severity', vulnerabilityFilters.severity);
}
if (vulnerabilityFilters.status) {
params.append('status', vulnerabilityFilters.status);
}
const response = await apiFetch(`/api/vulnerabilities?${params.toString()}`); const response = await apiFetch(`/api/vulnerabilities?${params.toString()}`);
if (!response.ok) { if (!response.ok) {
@@ -785,6 +868,12 @@ function renderVulnerabilities(vulnerabilities) {
const severityText = vulnSeverityLabel(vuln.severity); const severityText = vulnSeverityLabel(vuln.severity);
const statusText = vulnStatusLabel(vuln.status); const statusText = vulnStatusLabel(vuln.status);
const createdDate = new Date(vuln.created_at).toLocaleString(vulnDateLocale()); const createdDate = new Date(vuln.created_at).toLocaleString(vulnDateLocale());
const projectLabel = vuln.project_id
? escapeHtml(typeof getProjectName === 'function' ? getProjectName(vuln.project_id) : vuln.project_id)
: escapeHtml(vulnT('vulnerabilityPage.projectUnbound'));
const projectBadge = vuln.project_id
? `<span class="vulnerability-project-badge" title="${escapeHtml(vuln.project_id)}">${escapeHtml(vulnT('vulnerabilityPage.detailProject'))}: ${projectLabel}</span>`
: `<span class="vulnerability-project-badge vulnerability-project-badge--unbound">${escapeHtml(vulnT('vulnerabilityPage.projectUnbound'))}</span>`;
const dlTitle = escapeHtml(vulnT('vulnerabilityPage.downloadMarkdownTitle')); const dlTitle = escapeHtml(vulnT('vulnerabilityPage.downloadMarkdownTitle'));
const editTitle = escapeHtml(vulnT('common.edit')); const editTitle = escapeHtml(vulnT('common.edit'));
const deleteTitle = escapeHtml(vulnT('common.delete')); const deleteTitle = escapeHtml(vulnT('common.delete'));
@@ -802,6 +891,7 @@ function renderVulnerabilities(vulnerabilities) {
<div class="vulnerability-meta"> <div class="vulnerability-meta">
<span class="severity-badge ${severityClass}">${severityText}</span> <span class="severity-badge ${severityClass}">${severityText}</span>
<span class="status-badge status-${vuln.status}">${statusText}</span> <span class="status-badge status-${vuln.status}">${statusText}</span>
${projectBadge}
<span class="vulnerability-date">${createdDate}</span> <span class="vulnerability-date">${createdDate}</span>
</div> </div>
</div> </div>
@@ -830,6 +920,7 @@ function renderVulnerabilities(vulnerabilities) {
${vuln.description ? `<div class="vulnerability-description">${escapeHtml(vuln.description)}</div>` : ''} ${vuln.description ? `<div class="vulnerability-description">${escapeHtml(vuln.description)}</div>` : ''}
<div class="vulnerability-details"> <div class="vulnerability-details">
${vulnDetailField(vulnT('vulnerabilityPage.detailVulnId'), vuln.id, true)} ${vulnDetailField(vulnT('vulnerabilityPage.detailVulnId'), vuln.id, true)}
${vulnDetailProjectField(vuln)}
${vuln.type ? vulnDetailField(vulnT('vulnerabilityPage.detailType'), vuln.type, false) : ''} ${vuln.type ? vulnDetailField(vulnT('vulnerabilityPage.detailType'), vuln.type, false) : ''}
${vuln.target ? vulnDetailField(vulnT('vulnerabilityPage.detailTarget'), vuln.target, false) : ''} ${vuln.target ? vulnDetailField(vulnT('vulnerabilityPage.detailTarget'), vuln.target, false) : ''}
${vulnDetailField(vulnT('vulnerabilityPage.detailConversationId'), vuln.conversation_id, true)} ${vulnDetailField(vulnT('vulnerabilityPage.detailConversationId'), vuln.conversation_id, true)}
@@ -940,11 +1031,50 @@ async function changeVulnerabilityPageSize() {
await loadVulnerabilities(); await loadVulnerabilities();
} }
function buildVulnerabilityProjectOptionsHtml(selectedId) {
const sel = (selectedId || '').trim();
let html = `<option value="">${escapeHtml(vulnT('vulnerabilityModal.projectNone'))}</option>`;
const entries = typeof projectNameById !== 'undefined' ? Object.entries(projectNameById) : [];
entries.sort((a, b) => (a[1] || '').localeCompare(b[1] || '', undefined, { sensitivity: 'base' }));
entries.forEach(([id, name]) => {
if (!id) return;
const selected = id === sel ? ' selected' : '';
html += `<option value="${escapeHtml(id)}"${selected}>${escapeHtml(name || id)}</option>`;
});
if (sel && !entries.some(([id]) => id === sel)) {
html += `<option value="${escapeHtml(sel)}" selected>${escapeHtml(sel)}</option>`;
}
return html;
}
async function populateVulnerabilityModalProjectSelect(selectedId) {
const sel = document.getElementById('vulnerability-project-id');
if (!sel) return;
try {
const res = await apiFetch('/api/projects?limit=200');
if (res.ok) {
const list = await res.json();
if (typeof rebuildProjectNameMap === 'function') {
rebuildProjectNameMap(list);
} else if (typeof projectNameById !== 'undefined') {
(list || []).forEach((p) => { if (p.id) projectNameById[p.id] = p.name || p.id; });
}
}
} catch (e) {
console.warn('加载项目列表失败', e);
}
sel.innerHTML = buildVulnerabilityProjectOptionsHtml(selectedId || '');
sel.value = selectedId || '';
}
// 显示添加漏洞模态框 // 显示添加漏洞模态框
function showAddVulnerabilityModal() { async function showAddVulnerabilityModal() {
currentVulnerabilityId = null; currentVulnerabilityId = null;
document.getElementById('vulnerability-modal-title').textContent = vulnT('vulnerability.addVuln'); document.getElementById('vulnerability-modal-title').textContent = vulnT('vulnerability.addVuln');
const defaultProject = vulnerabilityFilters.project_id || '';
await populateVulnerabilityModalProjectSelect(defaultProject);
// 清空表单 // 清空表单
document.getElementById('vulnerability-conversation-id').value = ''; document.getElementById('vulnerability-conversation-id').value = '';
document.getElementById('vulnerability-conversation-tag').value = ''; document.getElementById('vulnerability-conversation-tag').value = '';
@@ -986,6 +1116,8 @@ async function editVulnerability(id) {
document.getElementById('vulnerability-impact').value = vuln.impact || ''; document.getElementById('vulnerability-impact').value = vuln.impact || '';
document.getElementById('vulnerability-recommendation').value = vuln.recommendation || ''; document.getElementById('vulnerability-recommendation').value = vuln.recommendation || '';
await populateVulnerabilityModalProjectSelect(vuln.project_id || '');
document.getElementById('vulnerability-modal').style.display = 'block'; document.getElementById('vulnerability-modal').style.display = 'block';
} catch (error) { } catch (error) {
console.error('加载漏洞失败:', error); console.error('加载漏洞失败:', error);
@@ -1004,8 +1136,11 @@ async function saveVulnerability() {
return; return;
} }
const projectId = (document.getElementById('vulnerability-project-id')?.value || '').trim();
const data = { const data = {
conversation_id: conversationId, conversation_id: conversationId,
project_id: projectId,
conversation_tag: document.getElementById('vulnerability-conversation-tag').value.trim(), conversation_tag: document.getElementById('vulnerability-conversation-tag').value.trim(),
task_tag: document.getElementById('vulnerability-task-tag').value.trim(), task_tag: document.getElementById('vulnerability-task-tag').value.trim(),
title: title, title: title,
@@ -1025,12 +1160,30 @@ async function saveVulnerability() {
: '/api/vulnerabilities'; : '/api/vulnerabilities';
const method = currentVulnerabilityId ? 'PUT' : 'POST'; const method = currentVulnerabilityId ? 'PUT' : 'POST';
let body = data;
if (currentVulnerabilityId) {
body = {
project_id: projectId,
conversation_tag: data.conversation_tag,
task_tag: data.task_tag,
title: data.title,
description: data.description,
severity: data.severity,
status: data.status,
type: data.type,
target: data.target,
proof: data.proof,
impact: data.impact,
recommendation: data.recommendation,
};
}
const response = await apiFetch(url, { const response = await apiFetch(url, {
method: method, method: method,
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
body: JSON.stringify(data) body: JSON.stringify(body)
}); });
if (!response.ok) { if (!response.ok) {
@@ -1090,12 +1243,19 @@ function filterVulnerabilities() {
// 清除筛选 // 清除筛选
function clearVulnerabilityFilters() { function clearVulnerabilityFilters() {
closeVulnerabilityMoreFiltersPopover(false);
if (vulnerabilityFilterDebounceTimer) {
clearTimeout(vulnerabilityFilterDebounceTimer);
vulnerabilityFilterDebounceTimer = null;
}
const fields = [ const fields = [
'vulnerability-id-filter', 'vulnerability-search-filter',
'vulnerability-exact-id-filter',
'vulnerability-conversation-filter', 'vulnerability-conversation-filter',
'vulnerability-task-filter', 'vulnerability-task-filter',
'vulnerability-conversation-tag-filter', 'vulnerability-conversation-tag-filter',
'vulnerability-task-tag-filter', 'vulnerability-task-tag-filter',
'vulnerability-project-filter',
'vulnerability-severity-filter', 'vulnerability-severity-filter',
'vulnerability-status-filter' 'vulnerability-status-filter'
]; ];
@@ -1105,7 +1265,9 @@ function clearVulnerabilityFilters() {
}); });
vulnerabilityFilters = { vulnerabilityFilters = {
q: '',
id: '', id: '',
project_id: '',
conversation_id: '', conversation_id: '',
task_id: '', task_id: '',
conversation_tag: '', conversation_tag: '',
@@ -1200,6 +1362,21 @@ function vulnerabilityCopyEncoded(evt, encoded) {
} }
} }
function vulnDetailProjectField(vuln) {
const label = vulnT('vulnerabilityPage.detailProject');
const hint = escapeHtml(vulnT('vulnerabilityPage.projectBindHint'));
return `<div class="vuln-detail-field">
<div class="vuln-detail-field__label">${escapeHtml(label)}</div>
<div class="vuln-detail-field__row">
<select class="vuln-detail-field-select vulnerability-project-bind-select" data-vuln-id="${escapeHtml(vuln.id)}"
onchange="bindVulnerabilityProject(this.dataset.vulnId, this.value, true)" onclick="event.stopPropagation();"
title="${hint}" aria-label="${escapeHtml(label)}">
${buildVulnerabilityProjectOptionsHtml(vuln.project_id || '')}
</select>
</div>
</div>`;
}
function vulnDetailField(label, value, asCode) { function vulnDetailField(label, value, asCode) {
if (value === undefined || value === null || String(value) === '') { if (value === undefined || value === null || String(value) === '') {
return ''; return '';
@@ -1277,8 +1454,11 @@ function formatVulnerabilityAsMarkdown(vuln) {
function buildVulnerabilityFilterParams() { function buildVulnerabilityFilterParams() {
const params = new URLSearchParams(); const params = new URLSearchParams();
const keys = ['id', 'conversation_id', 'task_id', 'conversation_tag', 'task_tag', 'severity', 'status']; if (vulnerabilityFilters.q) {
keys.forEach((k) => { params.append('q', vulnerabilityFilters.q);
}
const keys = ['id', 'project_id', 'conversation_id', 'task_id', 'conversation_tag', 'task_tag', 'severity', 'status'];
keys.forEach(function (k) {
if (vulnerabilityFilters[k]) { if (vulnerabilityFilters[k]) {
params.append(k, vulnerabilityFilters[k]); params.append(k, vulnerabilityFilters[k]);
} }
@@ -1395,3 +1575,80 @@ document.addEventListener('languagechange', function () {
} }
}); });
async function bindVulnerabilityProject(vulnId, projectId, silent) {
if (!vulnId) return;
try {
const response = await apiFetch(`/api/vulnerabilities/${encodeURIComponent(vulnId)}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ project_id: projectId || '' }),
});
if (!response.ok) {
const err = await response.json().catch(() => ({}));
throw new Error(err.error || vulnT('vulnerabilityPage.projectBindFailed'));
}
if (!silent) {
alert(vulnT('vulnerabilityPage.projectBindOk'));
}
loadVulnerabilityStats();
loadVulnerabilities();
} catch (error) {
console.error('绑定项目失败:', error);
alert(vulnT('vulnerabilityPage.projectBindFailed') + ': ' + error.message);
loadVulnerabilities();
}
}
async function refreshVulnerabilityProjectFilter() {
const sel = document.getElementById('vulnerability-project-filter');
if (!sel) return;
try {
const res = await apiFetch('/api/projects?limit=200');
if (!res.ok) return;
const list = await res.json();
if (typeof rebuildProjectNameMap === 'function') {
rebuildProjectNameMap(list);
} else if (typeof projectNameById !== 'undefined') {
list.forEach((p) => { if (p.id) projectNameById[p.id] = p.name || p.id; });
}
const cur = vulnerabilityFilters.project_id || sel.value || '';
let html = '<option value="">全部项目</option>';
(list || []).forEach((p) => {
if (!p.id) return;
const selected = p.id === cur ? ' selected' : '';
const arch = p.status === 'archived' ? ' [归档]' : '';
html += `<option value="${escapeHtml(p.id)}"${selected}>${escapeHtml(p.name || p.id)}${arch}</option>`;
});
sel.innerHTML = html;
if (cur) sel.value = cur;
const modalSel = document.getElementById('vulnerability-project-id');
if (modalSel && document.getElementById('vulnerability-modal')?.style.display === 'block') {
const modalCur = modalSel.value || '';
modalSel.innerHTML = buildVulnerabilityProjectOptionsHtml(modalCur);
modalSel.value = modalCur;
}
} catch (e) {
console.warn('加载项目筛选列表失败', e);
}
}
function setVulnerabilityProjectFilter(projectId) {
vulnerabilityFilters.project_id = projectId || '';
const sel = document.getElementById('vulnerability-project-filter');
if (sel) sel.value = projectId || '';
applyVulnerabilityFilters();
}
function setVulnerabilityIdFilter(vulnId) {
vulnerabilityFilters.id = vulnId || '';
const el = document.getElementById('vulnerability-exact-id-filter');
if (el) el.value = vulnId || '';
applyVulnerabilityFilters();
}
window.refreshVulnerabilityProjectFilter = refreshVulnerabilityProjectFilter;
window.setVulnerabilityProjectFilter = setVulnerabilityProjectFilter;
window.setVulnerabilityIdFilter = setVulnerabilityIdFilter;
window.bindVulnerabilityProject = bindVulnerabilityProject;
window.buildVulnerabilityProjectOptionsHtml = buildVulnerabilityProjectOptionsHtml;

Some files were not shown because too many files have changed in this diff Show More