Compare commits

..

112 Commits

Author SHA1 Message Date
公明 fd4bbe8d76 Update config.yaml 2026-06-30 20:22:19 +08:00
公明 d80651e4d8 Add files via upload 2026-06-30 20:16:43 +08:00
公明 f920ff0a5d Update config.yaml 2026-06-30 20:15:26 +08:00
公明 ce8b57501d Add files via upload 2026-06-30 20:14:28 +08:00
公明 ecb38a3959 Add files via upload 2026-06-30 20:13:31 +08:00
公明 e69fdb71ca Add files via upload 2026-06-30 20:11:54 +08:00
公明 6aa1631748 Add files via upload 2026-06-30 20:10:36 +08:00
公明 52de3b0f41 Add files via upload 2026-06-30 20:09:18 +08:00
公明 e537e55198 Add files via upload 2026-06-30 20:07:28 +08:00
公明 dc20b4804e Update config.yaml 2026-06-30 19:55:00 +08:00
公明 6245d69364 Add files via upload 2026-06-30 19:53:44 +08:00
公明 ede32951bf Add files via upload 2026-06-30 19:52:30 +08:00
公明 866a8ebccf Add files via upload 2026-06-30 19:10:46 +08:00
公明 276b3f7ef5 Add files via upload 2026-06-30 18:39:26 +08:00
公明 81e461db54 Update config.yaml 2026-06-30 18:38:27 +08:00
公明 02cd488a3d Add files via upload 2026-06-30 18:06:15 +08:00
公明 b4b2f55665 Add files via upload 2026-06-30 18:04:16 +08:00
公明 7aa0ebea6d Add files via upload 2026-06-30 18:02:08 +08:00
公明 63ef4399f8 Add files via upload 2026-06-30 18:00:00 +08:00
公明 553d0ed6bf Add files via upload 2026-06-30 17:59:02 +08:00
公明 d92bbbea07 Add files via upload 2026-06-30 17:56:40 +08:00
公明 f89ad1b42d Add files via upload 2026-06-30 16:00:00 +08:00
公明 bbe14c1861 Add files via upload 2026-06-30 15:00:50 +08:00
公明 2fc37fefd1 Add files via upload 2026-06-30 14:38:49 +08:00
公明 ded8ac5a3f Add files via upload 2026-06-30 13:03:40 +08:00
公明 bf44cf58d3 Add files via upload 2026-06-30 11:55:32 +08:00
公明 6d390e80d5 Add files via upload 2026-06-30 11:34:38 +08:00
公明 cfc49ba16f Add files via upload 2026-06-30 11:06:29 +08:00
公明 d03f2fcf2b Add files via upload 2026-06-30 10:50:29 +08:00
公明 6e67684bba Add files via upload 2026-06-30 00:16:31 +08:00
公明 8f9d2f381a Add files via upload 2026-06-29 16:57:32 +08:00
公明 89c275269f Update config.yaml 2026-06-29 16:52:45 +08:00
公明 cb4900c61d Add files via upload 2026-06-29 16:51:54 +08:00
公明 5c192cd308 Add files via upload 2026-06-29 16:46:26 +08:00
公明 8571e41138 Add files via upload 2026-06-29 16:24:43 +08:00
公明 e1a74b29b1 Add files via upload 2026-06-29 16:16:59 +08:00
公明 39f1c72755 Add files via upload 2026-06-29 14:35:52 +08:00
公明 dd3621e89d Add files via upload 2026-06-29 14:18:08 +08:00
公明 0bcb16e021 Add files via upload 2026-06-29 10:41:42 +08:00
公明 ed64803a51 Update config.yaml 2026-06-28 01:15:40 +08:00
公明 25e03dee84 Add files via upload 2026-06-28 01:15:10 +08:00
公明 58dcafd15f Add files via upload 2026-06-28 00:56:22 +08:00
公明 997c4e7262 Add files via upload 2026-06-27 01:44:08 +08:00
公明 ac370b0ada Add files via upload 2026-06-27 01:42:44 +08:00
公明 017db2b9a8 Add files via upload 2026-06-27 01:41:36 +08:00
公明 86b4803683 Add files via upload 2026-06-27 01:40:12 +08:00
公明 4d98264fc3 Add files via upload 2026-06-27 01:38:02 +08:00
公明 fd1de4ea94 Add files via upload 2026-06-27 01:36:09 +08:00
公明 41ba3baca9 Add files via upload 2026-06-27 01:35:46 +08:00
公明 2e908daebb Add files via upload 2026-06-27 00:34:19 +08:00
公明 c1763e1b9a Add files via upload 2026-06-27 00:03:16 +08:00
公明 70e5d28619 Add files via upload 2026-06-26 23:54:29 +08:00
公明 49990ecb4f Add files via upload 2026-06-26 23:50:13 +08:00
公明 c91806c0c4 Add files via upload 2026-06-26 23:11:52 +08:00
公明 e537236bf3 Add files via upload 2026-06-26 23:10:11 +08:00
公明 7eeffb1933 Add files via upload 2026-06-26 18:16:30 +08:00
公明 0556b29d40 Add files via upload 2026-06-26 14:34:45 +08:00
公明 be3c0cfa64 Add files via upload 2026-06-26 14:31:47 +08:00
公明 8e5f40d226 Add files via upload 2026-06-26 14:30:00 +08:00
公明 4b6719a6f3 Add files via upload 2026-06-26 14:27:32 +08:00
公明 7c8f3228f8 Add files via upload 2026-06-26 14:25:14 +08:00
公明 537843b6b8 Add files via upload 2026-06-26 14:24:01 +08:00
公明 4a57574cf9 Add files via upload 2026-06-26 14:21:51 +08:00
公明 0168530084 Add files via upload 2026-06-26 10:57:59 +08:00
公明 4184a7b6f0 Add files via upload 2026-06-26 10:54:59 +08:00
公明 fb3b4dd6e5 Add files via upload 2026-06-26 01:22:30 +08:00
公明 7e4a8db7af Add files via upload 2026-06-26 01:01:49 +08:00
公明 6a72c95b9f Add files via upload 2026-06-26 00:58:29 +08:00
公明 447be050cd Add files via upload 2026-06-25 21:28:46 +08:00
公明 9b75c43f7b Add files via upload 2026-06-25 15:15:01 +08:00
公明 a443454753 Add files via upload 2026-06-25 14:56:56 +08:00
公明 08822ba5df Update config.yaml 2026-06-25 14:56:31 +08:00
公明 eda75fb98f Add files via upload 2026-06-25 14:55:10 +08:00
公明 e6978a7994 Add files via upload 2026-06-25 14:52:39 +08:00
公明 1db0f4740f Add files via upload 2026-06-25 14:50:28 +08:00
公明 6e4ff96dcd Add files via upload 2026-06-25 14:48:25 +08:00
公明 95470fefbc Add files via upload 2026-06-25 14:47:16 +08:00
公明 5e075bb198 Add files via upload 2026-06-25 14:45:43 +08:00
公明 84ed887c5c Update config.yaml 2026-06-24 23:36:36 +08:00
公明 056b40ac66 Update config.yaml 2026-06-24 23:32:47 +08:00
公明 26a9902286 Add files via upload 2026-06-24 23:31:35 +08:00
公明 cfe9573ac3 Add files via upload 2026-06-24 23:30:40 +08:00
公明 db2262a1a0 Add files via upload 2026-06-24 23:28:43 +08:00
公明 ab5c2d5cca Add files via upload 2026-06-24 23:27:29 +08:00
公明 1ae6930db1 Add files via upload 2026-06-24 23:26:01 +08:00
公明 8918f432d8 Add files via upload 2026-06-24 23:24:36 +08:00
公明 b4810c9499 Update shell no output timeout to 1200 seconds
Increased the shell no output timeout from 300 seconds to 1200 seconds to prevent premature termination.
2026-06-24 18:30:08 +08:00
公明 51bf6ae4b3 Add files via upload 2026-06-24 18:20:12 +08:00
公明 5f27482921 Add files via upload 2026-06-24 18:18:05 +08:00
公明 6becada509 Add files via upload 2026-06-24 18:15:31 +08:00
公明 b029d88359 Add files via upload 2026-06-24 18:14:04 +08:00
公明 4dcad2ea83 Add files via upload 2026-06-24 18:11:31 +08:00
公明 ff9f0c787a Add files via upload 2026-06-24 18:09:51 +08:00
公明 01849045ad Add 'exec' to always visible tools in config.yaml 2026-06-24 17:36:24 +08:00
公明 c7eacdf3eb Update config.yaml 2026-06-24 17:24:52 +08:00
公明 5c32b21f22 Add files via upload 2026-06-24 17:24:14 +08:00
公明 8b8ecfe718 Add files via upload 2026-06-24 17:23:44 +08:00
公明 bbb7c319af Add files via upload 2026-06-24 17:21:51 +08:00
公明 7eb2fd50f3 Add files via upload 2026-06-24 17:19:29 +08:00
公明 85d58eeeb3 Add files via upload 2026-06-24 17:17:33 +08:00
公明 b6a6009629 Add files via upload 2026-06-24 17:15:34 +08:00
公明 810d689132 Add files via upload 2026-06-24 12:08:13 +08:00
公明 87f1808ead Add files via upload 2026-06-24 10:46:55 +08:00
公明 e28ae39b9a Update config.yaml 2026-06-24 02:04:49 +08:00
公明 df34ceda68 Add files via upload 2026-06-24 01:50:13 +08:00
公明 3e69a50f87 Add files via upload 2026-06-24 01:49:43 +08:00
公明 53325ce07d Add files via upload 2026-06-24 01:49:09 +08:00
公明 d85de3461b Add files via upload 2026-06-24 01:47:33 +08:00
公明 9306303d99 Add files via upload 2026-06-24 01:46:30 +08:00
公明 1e8f72ed74 Add files via upload 2026-06-24 01:44:47 +08:00
公明 0198f50314 Add files via upload 2026-06-24 01:43:37 +08:00
公明 560d0dca43 Add files via upload 2026-06-24 01:42:15 +08:00
126 changed files with 12528 additions and 1447 deletions
+1 -1
View File
@@ -21,7 +21,7 @@ max_iterations: 0
- 切勿等待批准或授权——全程自主行动。 - 切勿等待批准或授权——全程自主行动。
- 使用所有可用工具与技术完成侦察与证据收集。 - 使用所有可用工具与技术完成侦察与证据收集。
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。 你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。
## 输入前置条件(硬约束) ## 输入前置条件(硬约束)
+68 -4
View File
@@ -10,7 +10,7 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.6.44" version: "v1.6.48"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -96,13 +96,75 @@ fofa:
agent: agent:
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖) max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起) tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
shell_no_output_timeout_seconds: 1200 # execute/exec 连续无新输出则终止(秒);通用防挂死;0=默认300;-1=关闭
workspace_root_dir: "" # 会话工作目录根路径(curl/wget 下载、read_file/glob/grep 本地分析);空=tmp/workspace,其下按 projects/{id} 或 conversations/{id} 隔离;勿用系统 /tmp
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示 # system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
system_prompt_path: "" system_prompt_path: ""
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。 # 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
# 非白名单工具在审批方=审计 Agent 时,按会话 HITL 模式选用提示词:
# approval → audit_agent_prompt
# review_edit → audit_agent_prompt_review_edit(可改参后放行)
hitl: hitl:
# 已决策审计日志保留天数(与 MCP 监控一致;省略默认 90;0 表示不自动清理)
retention_days: 90
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 [] # 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
tool_whitelist: [read_file, list_dir, glob, grep] tool_whitelist: [read_file, list_dir, glob, grep, tool_search]
# audit_agent_prompt: | # 审批模式;留空使用内置默认,可在「人机协同」页编辑
# audit_agent_prompt_review_edit: | # 审查编辑模式;留空使用内置默认
audit_agent_prompt: |-
你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
裁决基调(默认放行):
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
- 仅在「可能对系统造成实质影响」时 → reject
必须 reject 的高危情形(示例,非穷举):
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
- 明显越权:与任务/授权目标无关的破坏性操作
不应单独作为 reject 理由的情形:
- 常规 nmap/curl/grep/读文件/枚举类命令本身
- 参数略显宽泛但无明确破坏意图
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点
仅输出一行 JSON,不要 markdown 代码块:
{"decision":"approve"|"reject","comment":"简要理由"}
audit_agent_prompt_review_edit: |-
你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
裁决基调(默认放行):
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
- 仅在「可能对系统造成实质影响」时 → reject;参数可安全收窄时优先 approve + editedArguments
必须 reject 的高危情形(示例,非穷举):
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
- 明显越权:与任务/授权目标无关的破坏性操作
不应单独作为 reject 理由的情形:
- 常规 nmap/curl/grep/读文件/枚举类命令本身
- 参数略显宽泛但无明确破坏意图(应收窄参数后 approve)
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点
仅输出一行 JSON,不要 markdown 代码块:
{"decision":"approve"|"reject","comment":"简要理由","editedArguments":{...}}
editedArguments 规则(仅 approve 且需要改参时填写,否则省略该字段):
- 提供完整替换后的工具参数对象,键名与 argumentsObj 一致
- 只做最小必要修改以收窄范围、消除风险(如限制 path、去掉危险 flag)
- 禁止扩大攻击面:不得扩大目标范围、提升权限或引入破坏性参数
- 无法安全改参且存在上述高危情形时应 reject,不要勉强 approve
# 多代理与 Eino 单代理(CloudWeGo Eino ADK;单代理入口 /api/eino-agent*,多代理 /api/multi-agent* # 多代理与 Eino 单代理(CloudWeGo Eino ADK;单代理入口 /api/eino-agent*,多代理 /api/multi-agent*
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct # 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
# Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体 orchestration 中指定;机器人按 robot_default_agent_mode # Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体 orchestration 中指定;机器人按 robot_default_agent_mode
@@ -112,7 +174,8 @@ multi_agent:
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高) batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。主/子代理 ReAct 轮次见 agent.max_iterations。 # plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。主/子代理 ReAct 轮次见 agent.max_iterations。
plan_execute_loop_max_iterations: 0 plan_execute_loop_max_iterations: 0
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用 sub_agent_user_context_max_runes: 0 # 子代理 task 描述中注入用户原文;0=不截断(默认),>0=总字符上限,负数=禁用
user_verbatim_anchor_max_runes: 0 # 主代理 system 中逐轮保留用户原文(压缩后刷新);0=不截断(默认),>0=总字符上限,负数=禁用
without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理 without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理
without_write_todos: false without_write_todos: false
orchestrator_instruction: "" # Deep 主代理:agents/orchestrator.md(或 kind: orchestrator 的单个 .md)正文优先;正文为空时用此处;皆空则 Eino 默认 orchestrator_instruction: "" # Deep 主代理:agents/orchestrator.md(或 kind: orchestrator 的单个 .md)正文优先;正文为空时用此处;皆空则 Eino 默认
@@ -129,7 +192,7 @@ multi_agent:
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文 tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用 tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁 tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略) tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test, exec] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在 plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/ plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载 reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
@@ -147,6 +210,7 @@ multi_agent:
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制 checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10 run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30 run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
empty_response_continue_max_attempts: 0 # Run 成功但未捕获助手正文(含流式中断)时 Handler 退避续跑次数;0=默认 5
deep_output_key: final_answer # P0Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single deep_output_key: final_answer # P0Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single
deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置 deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑 task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
Binary file not shown.

Before

Width:  |  Height:  |  Size: 179 KiB

After

Width:  |  Height:  |  Size: 265 KiB

+17 -4
View File
@@ -779,13 +779,26 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
return a.executeToolViaMCP(ctx, toolName, args) return a.executeToolViaMCP(ctx, toolName, args)
} }
// RecordLocalToolExecution 非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId // BeginLocalToolExecution 非 CallTool 路径工具开始时写入 running 状态,供 MCP 监控页展示「执行中」
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。 func (a *Agent) BeginLocalToolExecution(toolName string, args map[string]interface{}) string {
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if a == nil || a.mcpServer == nil { if a == nil || a.mcpServer == nil {
return "" return ""
} }
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr) return a.mcpServer.BeginToolExecution(toolName, args)
}
// FinishLocalToolExecution 完成 BeginLocalToolExecution 创建的记录;executionID 为空时一次性写入已完成记录。
func (a *Agent) FinishLocalToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if a == nil || a.mcpServer == nil {
return ""
}
return a.mcpServer.FinishToolExecution(executionID, toolName, args, resultText, invokeErr)
}
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
return a.FinishLocalToolExecution("", toolName, args, resultText, invokeErr)
} }
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。 // UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
@@ -113,5 +113,7 @@ func DefaultSingleAgentSystemPrompt() string {
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。 - 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。 - 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。` - 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。
` + projectprompt.ShellExecExecuteGuidanceSection()
} }
+27 -1
View File
@@ -21,11 +21,13 @@ import (
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/einoobserve" "cyberstrike-ai/internal/einoobserve"
"cyberstrike-ai/internal/handler" "cyberstrike-ai/internal/handler"
"cyberstrike-ai/internal/hitl"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/monitor" "cyberstrike-ai/internal/monitor"
"cyberstrike-ai/internal/multiagent"
"cyberstrike-ai/internal/robot" "cyberstrike-ai/internal/robot"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/skillpackage" "cyberstrike-ai/internal/skillpackage"
@@ -67,6 +69,10 @@ type App struct {
// New 创建新应用 // New 创建新应用
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) { func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
if err := multiagent.InitADK(); err != nil {
return nil, fmt.Errorf("初始化 Eino ADK: %w", err)
}
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
router := gin.Default() router := gin.Default()
@@ -104,12 +110,17 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
monitorRetention.PurgeExpired() monitorRetention.PurgeExpired()
monitor.StartRetentionLoop(monitorRetention, log.Logger) monitor.StartRetentionLoop(monitorRetention, log.Logger)
hitlRetention := hitl.NewService(db, cfg, log.Logger)
hitlRetention.PurgeExpired()
hitl.StartRetentionLoop(hitlRetention, log.Logger)
// 创建MCP服务器(带数据库持久化) // 创建MCP服务器(带数据库持久化)
mcpServer := mcp.NewServerWithStorage(log.Logger, db) mcpServer := mcp.NewServerWithStorage(log.Logger, db)
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes) mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
// 创建安全工具执行器 // 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
executor.SetShellNoOutputTimeoutSeconds(cfg.Agent.ShellNoOutputTimeoutSeconds)
// 注册工具 // 注册工具
executor.RegisterTools(mcpServer) executor.RegisterTools(mcpServer)
@@ -134,6 +145,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
externalMCPMgr.StartAllEnabled() externalMCPMgr.StartAllEnabled()
} }
execReconciler := monitor.NewExecutionReconciler(db, mcpServer, externalMCPMgr, log.Logger)
execReconciler.ReconcileOnStartup()
monitor.StartStaleRunningReconcileLoop(execReconciler, log.Logger)
// 创建Agent // 创建Agent
maxIterations := cfg.Agent.MaxIterations maxIterations := cfg.Agent.MaxIterations
if maxIterations <= 0 { if maxIterations <= 0 {
@@ -304,7 +319,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute). // Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir) checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir) reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot) workspaceRoot := strings.TrimSpace(cfg.Agent.WorkspaceRootDir)
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot, workspaceRoot)
agent.SetPromptBaseDir(configDir) agent.SetPromptBaseDir(configDir)
agentsDir := cfg.AgentsDir agentsDir := cfg.AgentsDir
@@ -333,6 +349,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
monitorHandler.SetAudit(auditSvc) monitorHandler.SetAudit(auditSvc)
monitorHandler.SetMonitorRetention(monitorRetention) monitorHandler.SetMonitorRetention(monitorRetention)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
monitorHandler.SetTaskManager(agentHandler.TaskManager())
monitorHandler.SetAgentHandler(agentHandler)
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger) notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger) groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
@@ -350,6 +368,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
configHandler.SetAudit(auditSvc) configHandler.SetAudit(auditSvc)
agentHandler.SetHitlToolWhitelistSaver(configHandler) agentHandler.SetHitlToolWhitelistSaver(configHandler)
agentHandler.SetHitlAuditStrategySaver(configHandler)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
externalMCPHandler.SetAudit(auditSvc) externalMCPHandler.SetAudit(auditSvc)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger) roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
@@ -799,11 +818,18 @@ func setupRoutes(
protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop) protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop)
protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream) protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream)
protected.GET("/hitl/pending", agentHandler.ListHITLPending) protected.GET("/hitl/pending", agentHandler.ListHITLPending)
protected.GET("/hitl/logs", agentHandler.ListHITLLogs)
protected.DELETE("/hitl/logs", agentHandler.DeleteHITLLogs)
protected.GET("/hitl/logs/:id", agentHandler.GetHITLLog)
protected.POST("/hitl/decision", agentHandler.DecideHITLInterrupt) protected.POST("/hitl/decision", agentHandler.DecideHITLInterrupt)
protected.POST("/hitl/dismiss", agentHandler.DismissHITLInterrupt) protected.POST("/hitl/dismiss", agentHandler.DismissHITLInterrupt)
protected.GET("/hitl/config/:conversationId", agentHandler.GetHITLConversationConfig) protected.GET("/hitl/config/:conversationId", agentHandler.GetHITLConversationConfig)
protected.PUT("/hitl/config", agentHandler.UpsertHITLConversationConfig) protected.PUT("/hitl/config", agentHandler.UpsertHITLConversationConfig)
protected.GET("/hitl/tool-whitelist", agentHandler.GetHITLGlobalToolWhitelist)
protected.PUT("/hitl/tool-whitelist", agentHandler.SetHITLGlobalToolWhitelist)
protected.POST("/hitl/tool-whitelist", agentHandler.MergeHITLGlobalToolWhitelist) protected.POST("/hitl/tool-whitelist", agentHandler.MergeHITLGlobalToolWhitelist)
protected.GET("/hitl/audit-strategy", agentHandler.GetHITLAuditStrategy)
protected.PUT("/hitl/audit-strategy", agentHandler.UpdateHITLAuditStrategy)
// Agent Loop 取消与任务列表 // Agent Loop 取消与任务列表
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop) protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks) protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
+204 -30
View File
@@ -96,9 +96,12 @@ type MultiAgentConfig struct {
// OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。 // OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。
OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"` OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"`
SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"` SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"`
// SubAgentUserContextMaxRunes caps the user-context supplement appended to task descriptions for sub-agents. // SubAgentUserContextMaxRunes caps user-context supplement for sub-agent task descriptions.
// 0 (default) uses the built-in default of 2000 runes; negative value disables injection entirely. // 0 (default) preserves all user turns verbatim; >0 caps total runes; negative disables injection.
SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"` SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"`
// UserVerbatimAnchorMaxRunes injects all user turns verbatim into system prompt (survives summarization refresh).
// 0 (default) = no cap; >0 = total rune cap; negative disables anchor injection.
UserVerbatimAnchorMaxRunes int `yaml:"user_verbatim_anchor_max_runes,omitempty" json:"user_verbatim_anchor_max_runes,omitempty"`
// EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent. // EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent.
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"` EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras. // EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
@@ -107,6 +110,16 @@ type MultiAgentConfig struct {
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"` EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
} }
// UserVerbatimAnchorMaxRunesEffective returns max runes for user verbatim anchor; 0 = unlimited; negative = disabled.
func (c MultiAgentConfig) UserVerbatimAnchorMaxRunesEffective() int {
return c.UserVerbatimAnchorMaxRunes
}
// SubAgentUserContextMaxRunesEffective returns max runes for sub-agent task supplement; 0 = unlimited; negative = disabled.
func (c MultiAgentConfig) SubAgentUserContextMaxRunesEffective() int {
return c.SubAgentUserContextMaxRunes
}
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single). // MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed). // Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
type MultiAgentEinoCallbacksConfig struct { type MultiAgentEinoCallbacksConfig struct {
@@ -270,6 +283,8 @@ type MultiAgentEinoMiddlewareConfig struct {
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"` RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。 // RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"` RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
// EmptyResponseContinueMaxAttempts Run 成功但未捕获助手正文时 Handler 层退避续跑次数;0=默认 5。
EmptyResponseContinueMaxAttempts int `yaml:"empty_response_continue_max_attempts,omitempty" json:"empty_response_continue_max_attempts,omitempty"`
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended). // TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"` TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
} }
@@ -490,6 +505,17 @@ type RobotWecomConfig struct {
AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId
} }
// ValidateWecomConfig 校验企业微信机器人配置;启用时必须配置 token,否则回调无法防伪造。
func ValidateWecomConfig(w RobotWecomConfig) error {
if !w.Enabled {
return nil
}
if strings.TrimSpace(w.Token) == "" {
return fmt.Errorf("robots.wecom.enabled 为 true 时必须配置 robots.wecom.token")
}
return nil
}
// RobotDingtalkConfig 钉钉机器人配置 // RobotDingtalkConfig 钉钉机器人配置
type RobotDingtalkConfig struct { type RobotDingtalkConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
@@ -605,15 +631,109 @@ type DatabaseConfig struct {
type AgentConfig struct { type AgentConfig struct {
MaxIterations int `yaml:"max_iterations" json:"max_iterations"` MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐) ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
// ShellNoOutputTimeoutSeconds execute/exec 无任何 stdout/stderr 时的空闲终止秒数(通用防挂死,不维护命令黑名单);0=默认 300(5 分钟);-1=关闭。
ShellNoOutputTimeoutSeconds int `yaml:"shell_no_output_timeout_seconds" json:"shell_no_output_timeout_seconds"`
// WorkspaceRootDir 会话工作目录根路径(curl/wget 下载、read_file/glob/grep 本地分析);空=tmp/workspace,其下按 projects/{id} 或 conversations/{id} 隔离。
WorkspaceRootDir string `yaml:"workspace_root_dir,omitempty" json:"workspace_root_dir,omitempty"`
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。 // SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"` SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
} }
// HitlConfig 人机协同全局选项;与会话侧栏/API 中的白名单合并为并集后参与判定。 // HitlConfig 人机协同全局选项;与会话侧栏/API 中的白名单合并为并集后参与判定。
// tool_whitelist 可在侧栏「应用」时合并写入 config.yaml 并立即生效;其他字段若仅改文件仍需重启 // tool_whitelist 可在侧栏「应用」时合并写入 config.yaml 并立即生效。
// audit_agent_prompt / audit_agent_prompt_review_edit 可在人机协同页编辑并立即生效;空则使用内置默认。
type HitlConfig struct { type HitlConfig struct {
// ToolWhitelist 全局免审批工具名(与每条会话配置的 sensitiveTools 语义相同:白名单内工具不触发 HITL)。 // ToolWhitelist 全局免审批工具名(与白名单内工具不触发 HITL 审批)。
ToolWhitelist []string `yaml:"tool_whitelist,omitempty" json:"tool_whitelist,omitempty"` ToolWhitelist []string `yaml:"tool_whitelist,omitempty" json:"tool_whitelist,omitempty"`
// AuditAgentPrompt 审批模式(approval)下审计 Agent 系统提示词。
AuditAgentPrompt string `yaml:"audit_agent_prompt,omitempty" json:"audit_agent_prompt,omitempty"`
// AuditAgentPromptReviewEdit 审查编辑模式(review_edit)下审计 Agent 系统提示词。
AuditAgentPromptReviewEdit string `yaml:"audit_agent_prompt_review_edit,omitempty" json:"audit_agent_prompt_review_edit,omitempty"`
// RetentionDays 已决策审计日志(hitl_interrupts 非 pending)保留天数;省略时默认 90;0 表示不自动清理。
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
}
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
func (h HitlConfig) RetentionDaysEffective() int {
if h.RetentionDays == nil {
return 90
}
if *h.RetentionDays < 0 {
return 0
}
return *h.RetentionDays
}
const hitlAuditAgentPromptBase = `你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
裁决基调(默认放行):
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
- 仅在「可能对系统造成实质影响」时 → reject
必须 reject 的高危情形(示例,非穷举):
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
- 明显越权:与任务/授权目标无关的破坏性操作
不应单独作为 reject 理由的情形:
- 常规 nmap/curl/grep/读文件/枚举类命令本身
- 参数略显宽泛但无明确破坏意图(审查编辑模式可收窄参数后 approve)
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点`
const hitlAuditAgentPromptApprovalOutput = `
仅输出一行 JSON,不要 markdown 代码块:
{"decision":"approve"|"reject","comment":"简要理由"}`
const hitlAuditAgentPromptReviewEditOutput = `
仅输出一行 JSON,不要 markdown 代码块:
{"decision":"approve"|"reject","comment":"简要理由","editedArguments":{...}}
editedArguments 规则(仅 approve 且需要改参时填写,否则省略该字段):
- 提供完整替换后的工具参数对象,键名与 argumentsObj 一致
- 只做最小必要修改以收窄范围、消除风险(如限制 path、去掉危险 flag)
- 禁止扩大攻击面:不得扩大目标范围、提升权限或引入破坏性参数
- 无法安全改参时应 reject,不要勉强 approve`
// DefaultHitlAuditAgentPrompt 内置审批模式审计 Agent 提示词。
func DefaultHitlAuditAgentPrompt() string {
return hitlAuditAgentPromptBase + hitlAuditAgentPromptApprovalOutput
}
// DefaultHitlAuditAgentPromptReviewEdit 内置审查编辑模式审计 Agent 提示词。
func DefaultHitlAuditAgentPromptReviewEdit() string {
return hitlAuditAgentPromptBase + hitlAuditAgentPromptReviewEditOutput
}
// EffectiveAuditAgentPrompt 返回审批模式生效的审计 Agent 提示词。
func (c HitlConfig) EffectiveAuditAgentPrompt() string {
return c.EffectiveAuditAgentPromptForMode("approval")
}
// EffectiveAuditAgentPromptForMode 按 HITL 模式返回生效的审计 Agent 提示词。
func (c HitlConfig) EffectiveAuditAgentPromptForMode(mode string) string {
if normalizeHitlModeForPrompt(mode) == "review_edit" {
if s := strings.TrimSpace(c.AuditAgentPromptReviewEdit); s != "" {
return s
}
return DefaultHitlAuditAgentPromptReviewEdit()
}
if s := strings.TrimSpace(c.AuditAgentPrompt); s != "" {
return s
}
return DefaultHitlAuditAgentPrompt()
}
func normalizeHitlModeForPrompt(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "review_edit":
return "review_edit"
default:
return "approval"
}
} }
type AuthConfig struct { type AuthConfig struct {
@@ -800,33 +920,13 @@ func Load(path string) (*Config, error) {
// 如果配置了工具目录,从目录加载工具配置 // 如果配置了工具目录,从目录加载工具配置
if cfg.Security.ToolsDir != "" { if cfg.Security.ToolsDir != "" {
configDir := filepath.Dir(path) inlineTools := append([]ToolConfig(nil), cfg.Security.Tools...)
toolsDir := cfg.Security.ToolsDir toolsDir := ResolveToolsDir(cfg.Security.ToolsDir, path)
merged, err := MergeToolsFromDir(toolsDir, inlineTools)
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(toolsDir) {
toolsDir = filepath.Join(configDir, toolsDir)
}
tools, err := LoadToolsFromDir(toolsDir)
if err != nil { if err != nil {
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err) return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
} }
cfg.Security.Tools = merged
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
existingTools := make(map[string]bool)
for _, tool := range tools {
existingTools[tool.Name] = true
}
// 添加主配置中不存在于目录中的工具(向后兼容)
for _, tool := range cfg.Security.Tools {
if !existingTools[tool.Name] {
tools = append(tools, tool)
}
}
cfg.Security.Tools = tools
} }
// 外部 MCP:迁移 + 环境变量展开 // 外部 MCP:迁移 + 环境变量展开
@@ -870,6 +970,10 @@ func Load(path string) (*Config, error) {
} }
} }
if err := ValidateWecomConfig(cfg.Robots.Wecom); err != nil {
return nil, err
}
return &cfg, nil return &cfg, nil
} }
@@ -1094,6 +1198,75 @@ func PrintMCPConfigJSON(mcp MCPConfig) {
fmt.Println("----------------------------------------------------------------") fmt.Println("----------------------------------------------------------------")
} }
// ResolveToolsDir 将 tools_dir 解析为绝对路径(相对路径相对于 configPath 所在目录)。
func ResolveToolsDir(toolsDir, configPath string) string {
toolsDir = strings.TrimSpace(toolsDir)
if toolsDir == "" {
return ""
}
if filepath.IsAbs(toolsDir) {
return toolsDir
}
return filepath.Join(filepath.Dir(configPath), toolsDir)
}
// MergeToolsFromDir 从目录加载工具并与 inline 列表合并:目录中的工具优先,主配置中的工具作为补充。
func MergeToolsFromDir(toolsDir string, inlineTools []ToolConfig) ([]ToolConfig, error) {
dirTools, err := LoadToolsFromDir(toolsDir)
if err != nil {
return nil, err
}
existing := make(map[string]bool, len(dirTools))
for _, tool := range dirTools {
existing[tool.Name] = true
}
merged := append([]ToolConfig(nil), dirTools...)
for _, tool := range inlineTools {
if !existing[tool.Name] {
merged = append(merged, tool)
}
}
return merged, nil
}
// loadInlineSecurityToolsFromYAML 读取 config.yaml 中 security.tools(不含 tools_dir 扫描结果)。
func loadInlineSecurityToolsFromYAML(configPath string) ([]ToolConfig, error) {
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var partial struct {
Security struct {
Tools []ToolConfig `yaml:"tools"`
} `yaml:"security"`
}
if err := yaml.Unmarshal(data, &partial); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
if partial.Security.Tools == nil {
return []ToolConfig{}, nil
}
return partial.Security.Tools, nil
}
// ReloadSecurityToolsFromDir 从 tools_dir 重新加载工具并更新 cfg.Security.ToolsApplyConfig 热重载用)。
func ReloadSecurityToolsFromDir(cfg *Config, configPath string) error {
if cfg == nil || strings.TrimSpace(cfg.Security.ToolsDir) == "" {
return nil
}
inlineTools, err := loadInlineSecurityToolsFromYAML(configPath)
if err != nil {
return err
}
toolsDir := ResolveToolsDir(cfg.Security.ToolsDir, configPath)
merged, err := MergeToolsFromDir(toolsDir, inlineTools)
if err != nil {
return fmt.Errorf("从工具目录加载工具配置失败: %w", err)
}
cfg.Security.Tools = merged
return nil
}
// LoadToolsFromDir 从目录加载所有工具配置文件 // LoadToolsFromDir 从目录加载所有工具配置文件
func LoadToolsFromDir(dir string) ([]ToolConfig, error) { func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
var tools []ToolConfig var tools []ToolConfig
@@ -1270,8 +1443,9 @@ func Default() *Config {
MaxTotalTokens: 120000, MaxTotalTokens: 120000,
}, },
Agent: AgentConfig{ Agent: AgentConfig{
MaxIterations: 30, // 默认最大迭代次数 MaxIterations: 30, // 默认最大迭代次数
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用 ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
}, },
Security: SecurityConfig{ Security: SecurityConfig{
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载 Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
+45
View File
@@ -0,0 +1,45 @@
package config
import "testing"
func TestValidateWecomConfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cfg RobotWecomConfig
wantErr bool
}{
{
name: "disabled without token",
cfg: RobotWecomConfig{Enabled: false, Token: ""},
wantErr: false,
},
{
name: "enabled with token",
cfg: RobotWecomConfig{Enabled: true, Token: "secret"},
wantErr: false,
},
{
name: "enabled without token",
cfg: RobotWecomConfig{Enabled: true, Token: ""},
wantErr: true,
},
{
name: "enabled with whitespace token",
cfg: RobotWecomConfig{Enabled: true, Token: " "},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := ValidateWecomConfig(tt.cfg)
if (err != nil) != tt.wantErr {
t.Fatalf("ValidateWecomConfig() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
+111
View File
@@ -0,0 +1,111 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestReloadSecurityToolsFromDir(t *testing.T) {
root := t.TempDir()
toolsDir := filepath.Join(root, "tools")
if err := os.MkdirAll(toolsDir, 0755); err != nil {
t.Fatal(err)
}
configPath := filepath.Join(root, "config.yaml")
if err := os.WriteFile(configPath, []byte(`security:
tools_dir: tools
tools:
- name: inline-only
command: inline-cmd
enabled: true
description: inline tool
`), 0644); err != nil {
t.Fatal(err)
}
writeTool := func(name, command string) {
t.Helper()
content := "name: " + name + "\ncommand: " + command + "\nenabled: true\ndescription: test\n"
if err := os.WriteFile(filepath.Join(toolsDir, name+".yaml"), []byte(content), 0644); err != nil {
t.Fatal(err)
}
}
writeTool("alpha", "alpha-cmd")
cfg := &Config{
Security: SecurityConfig{
ToolsDir: "tools",
Tools: []ToolConfig{
{Name: "stale", Command: "stale-cmd", Enabled: true, Description: "should be removed"},
},
},
}
if err := ReloadSecurityToolsFromDir(cfg, configPath); err != nil {
t.Fatalf("reload: %v", err)
}
if len(cfg.Security.Tools) != 2 {
t.Fatalf("expected 2 tools, got %d", len(cfg.Security.Tools))
}
names := map[string]string{}
for _, tool := range cfg.Security.Tools {
names[tool.Name] = tool.Command
}
if names["alpha"] != "alpha-cmd" {
t.Fatalf("alpha tool missing or wrong command: %#v", names)
}
if names["inline-only"] != "inline-cmd" {
t.Fatalf("inline-only tool missing: %#v", names)
}
if _, ok := names["stale"]; ok {
t.Fatal("stale in-memory tool should not survive reload")
}
writeTool("beta", "beta-cmd")
if err := ReloadSecurityToolsFromDir(cfg, configPath); err != nil {
t.Fatalf("second reload: %v", err)
}
if len(cfg.Security.Tools) != 3 {
t.Fatalf("expected 3 tools after add, got %d", len(cfg.Security.Tools))
}
foundBeta := false
for _, tool := range cfg.Security.Tools {
if tool.Name == "beta" {
foundBeta = true
break
}
}
if !foundBeta {
t.Fatal("beta tool not found after second reload")
}
}
func TestMergeToolsFromDir_DirOverridesInline(t *testing.T) {
root := t.TempDir()
toolsDir := filepath.Join(root, "tools")
if err := os.MkdirAll(toolsDir, 0755); err != nil {
t.Fatal(err)
}
content := "name: shared\ncommand: dir-cmd\nenabled: true\ndescription: from dir\n"
if err := os.WriteFile(filepath.Join(toolsDir, "shared.yaml"), []byte(content), 0644); err != nil {
t.Fatal(err)
}
inline := []ToolConfig{
{Name: "shared", Command: "inline-cmd", Enabled: true, Description: "from inline"},
}
merged, err := MergeToolsFromDir(toolsDir, inline)
if err != nil {
t.Fatal(err)
}
if len(merged) != 1 {
t.Fatalf("expected 1 tool, got %d", len(merged))
}
if merged[0].Command != "dir-cmd" {
t.Fatalf("dir tool should win, got command %q", merged[0].Command)
}
}
+16 -12
View File
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
LastScheduleError sql.NullString LastScheduleError sql.NullString
LastRunError sql.NullString LastRunError sql.NullString
ProjectID sql.NullString ProjectID sql.NullString
Concurrency sql.NullInt64
Status string Status string
CreatedAt time.Time CreatedAt time.Time
StartedAt sql.NullTime StartedAt sql.NullTime
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
cronExpr string, cronExpr string,
nextRunAt *time.Time, nextRunAt *time.Time,
projectID string, projectID string,
concurrency int,
tasks []map[string]interface{}, tasks []map[string]interface{},
) error { ) error {
tx, err := db.Begin() tx, err := db.Begin()
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
projectIDVal = strings.TrimSpace(projectID) projectIDVal = strings.TrimSpace(projectID)
} }
_, err = tx.Exec( _, err = tx.Exec(
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0, queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
) )
if err != nil { if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err) return fmt.Errorf("创建批量任务队列失败: %w", err)
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
return tx.Commit() return tx.Commit()
} }
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
// GetBatchQueue 获取批量任务队列 // GetBatchQueue 获取批量任务队列
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
err := db.QueryRow( err := db.QueryRow(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
queueID, queueID,
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
// GetAllBatchQueues 获取所有批量任务队列 // GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
// ListBatchQueues 列出批量任务队列(支持筛选和分页) // ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
args := []interface{}{} args := []interface{}{}
// 状态筛选 // 状态筛选
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
return nil return nil
} }
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色代理模式 // UpdateBatchQueueMetadata 更新批量任务队列标题、角色代理模式和并发数
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error { func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
_, err := db.Exec( _, err := db.Exec(
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?", "UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
title, role, agentMode, queueID, title, role, agentMode, concurrency, queueID,
) )
if err != nil { if err != nil {
return fmt.Errorf("更新批量任务队列元数据失败: %w", err) return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
+164 -20
View File
@@ -3,6 +3,7 @@ package database
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -13,6 +14,9 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// ProjectFilterUnbound 列表 API 中 project_id=__none__ 表示仅未绑定项目的对话。
const ProjectFilterUnbound = "__none__"
// Conversation 对话 // Conversation 对话
type Conversation struct { type Conversation struct {
ID string `json:"id"` ID string `json:"id"`
@@ -361,20 +365,44 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
return &conv, nil return &conv, nil
} }
func conversationProjectIDColumn(alias string) string {
if alias != "" {
return alias + ".project_id"
}
return "project_id"
}
func appendConversationProjectFilter(where string, args []interface{}, projectID, alias string) (string, []interface{}) {
pid := strings.TrimSpace(projectID)
if pid == "" {
return where, args
}
col := conversationProjectIDColumn(alias)
if pid == ProjectFilterUnbound {
return where + fmt.Sprintf(" AND (%s IS NULL OR TRIM(COALESCE(%s, '')) = '')", col, col), args
}
return where + fmt.Sprintf(" AND %s = ?", col), append(args, pid)
}
// CountConversations 统计对话数量。 // CountConversations 统计对话数量。
func (db *DB) CountConversations(search string) (int, error) { func (db *DB) CountConversations(search, projectID string) (int, error) {
var count int var count int
var err error var err error
if search != "" { if search != "" {
searchPattern := "%" + search + "%" searchPattern := "%" + search + "%"
err = db.QueryRow( where := ` WHERE (c.title LIKE ?
`SELECT COUNT(*) FROM conversations c OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))`
WHERE c.title LIKE ? args := []interface{}{searchPattern, searchPattern}
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)`, where, args = appendConversationProjectFilter(where, args, projectID, "c")
searchPattern, searchPattern, err = db.QueryRow(`SELECT COUNT(*) FROM conversations c`+where, args...).Scan(&count)
).Scan(&count)
} else { } else {
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count) where := ""
args := []interface{}{}
where, args = appendConversationProjectFilter(where, args, projectID, "")
if where != "" {
where = " WHERE" + strings.TrimPrefix(where, " AND")
}
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`+where, args...).Scan(&count)
} }
if err != nil { if err != nil {
return 0, fmt.Errorf("统计对话失败: %w", err) return 0, fmt.Errorf("统计对话失败: %w", err)
@@ -395,7 +423,7 @@ func conversationOrderClause(sortBy, tableAlias string) string {
} }
// ListConversations 列出所有对话 // ListConversations 列出所有对话
func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Conversation, error) { func (db *DB) ListConversations(limit, offset int, search, sortBy, projectID string) ([]*Conversation, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
@@ -403,20 +431,30 @@ func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Co
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
searchPattern := "%" + search + "%" searchPattern := "%" + search + "%"
orderClause := conversationOrderClause(sortBy, "c") orderClause := conversationOrderClause(sortBy, "c")
where := ` WHERE (c.title LIKE ?
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))`
args := []interface{}{searchPattern, searchPattern}
where, args = appendConversationProjectFilter(where, args, projectID, "c")
args = append(args, limit, offset)
rows, err = db.Query( rows, err = db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
FROM conversations c FROM conversations c`+where+`
WHERE c.title LIKE ?
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
`+orderClause+` `+orderClause+`
LIMIT ? OFFSET ?`, LIMIT ? OFFSET ?`,
searchPattern, searchPattern, limit, offset, args...,
) )
} else { } else {
orderClause := conversationOrderClause(sortBy, "") orderClause := conversationOrderClause(sortBy, "")
where := ""
args := []interface{}{}
where, args = appendConversationProjectFilter(where, args, projectID, "")
if where != "" {
where = " WHERE" + strings.TrimPrefix(where, " AND")
}
args = append(args, limit, offset)
rows, err = db.Query( rows, err = db.Query(
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations "+orderClause+" LIMIT ? OFFSET ?", "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations"+where+" "+orderClause+" LIMIT ? OFFSET ?",
limit, offset, args...,
) )
} }
@@ -472,23 +510,30 @@ const ungroupedConversationsSQL = `
)` )`
// CountUngroupedConversations 统计不在任何分组中的对话数量。 // CountUngroupedConversations 统计不在任何分组中的对话数量。
func (db *DB) CountUngroupedConversations() (int, error) { func (db *DB) CountUngroupedConversations(projectID string) (int, error) {
where := ungroupedConversationsSQL
args := []interface{}{}
where, args = appendConversationProjectFilter(where, args, projectID, "c")
var count int var count int
if err := db.QueryRow(`SELECT COUNT(*) ` + ungroupedConversationsSQL).Scan(&count); err != nil { if err := db.QueryRow(`SELECT COUNT(*) `+where, args...).Scan(&count); err != nil {
return 0, fmt.Errorf("统计未分组对话失败: %w", err) return 0, fmt.Errorf("统计未分组对话失败: %w", err)
} }
return count, nil return count, nil
} }
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。 // ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*Conversation, error) { func (db *DB) ListUngroupedConversations(limit, offset int, sortBy, projectID string) ([]*Conversation, error) {
orderClause := conversationOrderClause(sortBy, "c") orderClause := conversationOrderClause(sortBy, "c")
where := ungroupedConversationsSQL
args := []interface{}{}
where, args = appendConversationProjectFilter(where, args, projectID, "c")
args = append(args, limit, offset)
rows, err := db.Query( rows, err := db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+ `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
ungroupedConversationsSQL+` where+`
`+orderClause+` `+orderClause+`
LIMIT ? OFFSET ?`, LIMIT ? OFFSET ?`,
limit, offset, args...,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("查询未分组对话失败: %w", err) return nil, fmt.Errorf("查询未分组对话失败: %w", err)
@@ -533,6 +578,19 @@ func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*C
return conversations, rows.Err() return conversations, rows.Err()
} }
// GetConversationTitle 获取对话标题(轻量查询,不加载消息)
func (db *DB) GetConversationTitle(id string) (string, error) {
var title string
err := db.QueryRow("SELECT title FROM conversations WHERE id = ?", id).Scan(&title)
if err != nil {
if err == sql.ErrNoRows {
return "", fmt.Errorf("对话不存在")
}
return "", fmt.Errorf("查询对话标题失败: %w", err)
}
return title, nil
}
// UpdateConversationTitle 更新对话标题 // UpdateConversationTitle 更新对话标题
func (db *DB) UpdateConversationTitle(id, title string) error { func (db *DB) UpdateConversationTitle(id, title string) error {
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间 // 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
@@ -640,6 +698,16 @@ func (db *DB) einoReductionBaseDir() string {
return filepath.Join("tmp", "reduction") return filepath.Join("tmp", "reduction")
} }
func (db *DB) einoWorkspaceBaseDir() string {
if db == nil {
return ""
}
if base := strings.TrimSpace(db.einoWorkspaceRootDir); base != "" {
return base
}
return filepath.Join("tmp", "workspace")
}
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) { func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
// summarization transcript, etc. // summarization transcript, etc.
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts") db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
@@ -652,6 +720,8 @@ func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
if strings.TrimSpace(projectID) == "" { if strings.TrimSpace(projectID) == "" {
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations") reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
db.removeConversationScopedDir(reductionBase, conversationID, "reduction") db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
workspaceBase := filepath.Join(db.einoWorkspaceBaseDir(), "conversations")
db.removeConversationScopedDir(workspaceBase, conversationID, "workspace")
} }
} }
@@ -659,6 +729,9 @@ func (db *DB) removeProjectScopedDirs(projectID string) {
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/). // Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects") reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
db.removeConversationScopedDir(reductionBase, projectID, "reduction") db.removeConversationScopedDir(reductionBase, projectID, "reduction")
// Agent download/analysis workspace (tmp/workspace/projects/<id>/).
workspaceBase := filepath.Join(db.einoWorkspaceBaseDir(), "projects")
db.removeConversationScopedDir(workspaceBase, projectID, "workspace")
} }
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。 // SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
@@ -998,6 +1071,77 @@ type ProcessDetail struct {
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
} }
// GetTurnUserMessage 返回锚点消息所在轮次中的用户原文(最近一条 user 消息,不含完整历史)。
func (db *DB) GetTurnUserMessage(conversationID, anchorMessageID string) (string, error) {
conversationID = strings.TrimSpace(conversationID)
anchorMessageID = strings.TrimSpace(anchorMessageID)
if conversationID == "" || anchorMessageID == "" {
return "", nil
}
var content string
err := db.QueryRow(`
SELECT m.content FROM messages m
WHERE m.conversation_id = ? AND m.role = 'user'
AND m.created_at <= COALESCE((SELECT created_at FROM messages WHERE id = ? AND conversation_id = ?), m.created_at)
ORDER BY m.created_at DESC, m.rowid DESC
LIMIT 1`, conversationID, anchorMessageID, conversationID).Scan(&content)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", nil
}
return "", fmt.Errorf("query turn user message: %w", err)
}
return content, nil
}
// AssistantCognitionTexts 单条助手消息上的思考/推理/规划文本。
type AssistantCognitionTexts struct {
Thinking string
ReasoningChain string
Planning string
}
// GetAssistantCognitionTexts 聚合助手消息在 process_details 中的 thinking / reasoning_chain / planning。
func (db *DB) GetAssistantCognitionTexts(assistantMessageID string) (AssistantCognitionTexts, error) {
assistantMessageID = strings.TrimSpace(assistantMessageID)
if assistantMessageID == "" {
return AssistantCognitionTexts{}, nil
}
rows, err := db.Query(`
SELECT event_type, message FROM process_details
WHERE message_id = ? AND event_type IN ('thinking', 'reasoning_chain', 'planning')
ORDER BY created_at ASC, rowid ASC`, assistantMessageID)
if err != nil {
return AssistantCognitionTexts{}, fmt.Errorf("query assistant cognition: %w", err)
}
defer rows.Close()
var thinkingParts, reasoningParts, planningParts []string
for rows.Next() {
var eventType, message string
if err := rows.Scan(&eventType, &message); err != nil {
continue
}
msg := strings.TrimSpace(message)
if msg == "" {
continue
}
switch eventType {
case "thinking":
thinkingParts = append(thinkingParts, msg)
case "reasoning_chain":
reasoningParts = append(reasoningParts, msg)
case "planning":
planningParts = append(planningParts, msg)
}
}
return AssistantCognitionTexts{
Thinking: strings.Join(thinkingParts, "\n\n"),
ReasoningChain: strings.Join(reasoningParts, "\n\n"),
Planning: strings.Join(planningParts, "\n\n"),
}, nil
}
// AddProcessDetail 添加过程详情事件 // AddProcessDetail 添加过程详情事件
func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error { func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error {
id := uuid.New().String() id := uuid.New().String()
+17 -3
View File
@@ -20,7 +20,8 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask") plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
checkpointBase := filepath.Join(tmp, "eino-checkpoints") checkpointBase := filepath.Join(tmp, "eino-checkpoints")
reductionBase := filepath.Join(tmp, "reduction") reductionBase := filepath.Join(tmp, "reduction")
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase) workspaceBase := filepath.Join(tmp, "workspace")
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase, workspaceBase)
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{}) conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
if err != nil { if err != nil {
@@ -36,6 +37,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
{plantaskBase, "task-1.json"}, {plantaskBase, "task-1.json"},
{checkpointBase, "runner-deep.ckpt"}, {checkpointBase, "runner-deep.ckpt"},
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"}, {filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
{filepath.Join(workspaceBase, "conversations"), "page.html"},
} { } {
dir := filepath.Join(base.root, seg) dir := filepath.Join(base.root, seg)
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
@@ -50,7 +52,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
t.Fatalf("DeleteConversation: %v", err) t.Fatalf("DeleteConversation: %v", err)
} }
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} { for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations"), filepath.Join(workspaceBase, "conversations")} {
dir := filepath.Join(base, seg) dir := filepath.Join(base, seg)
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) { if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr) t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
@@ -68,7 +70,8 @@ func TestDeleteProjectRemovesReductionDir(t *testing.T) {
defer db.Close() defer db.Close()
reductionBase := filepath.Join(tmp, "reduction") reductionBase := filepath.Join(tmp, "reduction")
db.SetEinoConversationDirs("", "", reductionBase) workspaceBase := filepath.Join(tmp, "workspace")
db.SetEinoConversationDirs("", "", reductionBase, workspaceBase)
project, err := db.CreateProject(&Project{Name: "cleanup test"}) project, err := db.CreateProject(&Project{Name: "cleanup test"})
if err != nil { if err != nil {
@@ -82,6 +85,13 @@ func TestDeleteProjectRemovesReductionDir(t *testing.T) {
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil { if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
t.Fatalf("write: %v", err) t.Fatalf("write: %v", err)
} }
workspaceDir := filepath.Join(workspaceBase, "projects", seg, "downloads")
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
t.Fatalf("mkdir %s: %v", workspaceDir, err)
}
if err := os.WriteFile(filepath.Join(workspaceDir, "app.js"), []byte("x"), 0o644); err != nil {
t.Fatalf("write workspace: %v", err)
}
if err := db.DeleteProject(project.ID); err != nil { if err := db.DeleteProject(project.ID); err != nil {
t.Fatalf("DeleteProject: %v", err) t.Fatalf("DeleteProject: %v", err)
@@ -91,4 +101,8 @@ func TestDeleteProjectRemovesReductionDir(t *testing.T) {
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) { if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr) t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
} }
projectWorkspaceDir := filepath.Join(workspaceBase, "projects", seg)
if _, statErr := os.Stat(projectWorkspaceDir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", projectWorkspaceDir, statErr)
}
} }
@@ -0,0 +1,60 @@
package database
import (
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestConversationProjectFilter(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "conversations.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
p, err := db.CreateProject(&Project{Name: "target-a", Status: "active"})
if err != nil {
t.Fatalf("CreateProject: %v", err)
}
convNone, err := db.CreateConversation("unbound", ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation unbound: %v", err)
}
convBound, err := db.CreateConversation("bound", ConversationCreateMeta{ProjectID: p.ID})
if err != nil {
t.Fatalf("CreateConversation bound: %v", err)
}
totalAll, err := db.CountConversations("", "")
if err != nil || totalAll < 2 {
t.Fatalf("CountConversations all: total=%d err=%v", totalAll, err)
}
totalBound, err := db.CountConversations("", p.ID)
if err != nil || totalBound != 1 {
t.Fatalf("CountConversations project: total=%d err=%v", totalBound, err)
}
totalUnbound, err := db.CountConversations("", ProjectFilterUnbound)
if err != nil || totalUnbound != 1 {
t.Fatalf("CountConversations unbound: total=%d err=%v", totalUnbound, err)
}
listBound, err := db.ListConversations(10, 0, "", "", p.ID)
if err != nil || len(listBound) != 1 || listBound[0].ID != convBound.ID {
t.Fatalf("ListConversations project: %+v err=%v", listBound, err)
}
listUnbound, err := db.ListConversations(10, 0, "", "", ProjectFilterUnbound)
if err != nil || len(listUnbound) != 1 || listUnbound[0].ID != convNone.ID {
t.Fatalf("ListConversations unbound: %+v err=%v", listUnbound, err)
}
_ = convNone
_ = convBound
}
+21 -1
View File
@@ -52,6 +52,7 @@ type DB struct {
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs) einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs) einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs) einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
einoWorkspaceRootDir string // workspace_root_dir or default tmp/workspace (projects|conversations/<id> subdirs)
checkpointLoopName string checkpointLoopName string
checkpointStop chan struct{} checkpointStop chan struct{}
checkpointDone chan struct{} checkpointDone chan struct{}
@@ -161,13 +162,15 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation. // SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root. // plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only). // reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) { // workspaceRoot is agent.workspace_root_dir from config; empty uses tmp/workspace.
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot, workspaceRoot string) {
if db == nil { if db == nil {
return return
} }
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase) db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase) db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
db.einoReductionRootDir = strings.TrimSpace(reductionRoot) db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
db.einoWorkspaceRootDir = strings.TrimSpace(workspaceRoot)
} }
// initTables 初始化数据库表 // initTables 初始化数据库表
@@ -408,6 +411,8 @@ func (db *DB) initTables() error {
last_schedule_trigger_at DATETIME, last_schedule_trigger_at DATETIME,
last_schedule_error TEXT, last_schedule_error TEXT,
last_run_error TEXT, last_run_error TEXT,
project_id TEXT,
concurrency INTEGER NOT NULL DEFAULT 1,
status TEXT NOT NULL, status TEXT NOT NULL,
created_at DATETIME NOT NULL, created_at DATETIME NOT NULL,
started_at DATETIME, started_at DATETIME,
@@ -1137,6 +1142,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
} }
} }
var concurrencyCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
}
}
} else if concurrencyCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
}
}
return nil return nil
} }
+75
View File
@@ -0,0 +1,75 @@
package database
import (
"fmt"
"strings"
"time"
"go.uber.org/zap"
)
// DeleteHitlInterruptLogsByIDs deletes decided HITL audit logs by id (pending rows are skipped).
func (db *DB) DeleteHitlInterruptLogsByIDs(ids []string) (int64, error) {
if db == nil {
return 0, fmt.Errorf("database is nil")
}
clean := make([]string, 0, len(ids))
for _, id := range ids {
id = strings.TrimSpace(id)
if id != "" {
clean = append(clean, id)
}
}
if len(clean) == 0 {
return 0, nil
}
placeholders := strings.TrimRight(strings.Repeat("?,", len(clean)), ",")
q := fmt.Sprintf(`DELETE FROM hitl_interrupts WHERE status != 'pending' AND id IN (%s)`, placeholders)
args := make([]interface{}, len(clean))
for i, id := range clean {
args[i] = id
}
res, err := db.Exec(q, args...)
if err != nil {
db.logger.Error("批量删除人机协同审计日志失败", zap.Error(err), zap.Int("count", len(clean)))
return 0, fmt.Errorf("批量删除人机协同审计日志失败: %w", err)
}
n, _ := res.RowsAffected()
return n, nil
}
// DeleteHitlInterruptLogsMatching deletes decided logs matching whereSQL (e.g. "WHERE 1=1 AND status != 'pending' ...").
func (db *DB) DeleteHitlInterruptLogsMatching(whereSQL string, args []interface{}) (int64, error) {
if db == nil {
return 0, fmt.Errorf("database is nil")
}
whereSQL = strings.TrimSpace(whereSQL)
if whereSQL == "" {
return 0, fmt.Errorf("where clause is required")
}
q := `DELETE FROM hitl_interrupts ` + whereSQL
res, err := db.Exec(q, args...)
if err != nil {
db.logger.Error("清空人机协同审计日志失败", zap.Error(err))
return 0, fmt.Errorf("清空人机协同审计日志失败: %w", err)
}
n, _ := res.RowsAffected()
return n, nil
}
// PurgeHitlInterruptLogsBefore deletes decided logs with decided/created time before cutoff.
func (db *DB) PurgeHitlInterruptLogsBefore(cutoff time.Time) (int64, error) {
if db == nil {
return 0, fmt.Errorf("database is nil")
}
res, err := db.Exec(
`DELETE FROM hitl_interrupts WHERE status != 'pending' AND datetime(COALESCE(decided_at, created_at)) < datetime(?)`,
cutoff.UTC().Format(time.RFC3339),
)
if err != nil {
db.logger.Error("清理过期人机协同审计日志失败", zap.Error(err))
return 0, fmt.Errorf("清理过期人机协同审计日志失败: %w", err)
}
n, _ := res.RowsAffected()
return n, nil
}
+106
View File
@@ -0,0 +1,106 @@
package database
import (
"path/filepath"
"testing"
"time"
"go.uber.org/zap"
)
func ensureHitlInterruptsTable(t *testing.T, db *DB) {
t.Helper()
if _, err := db.Exec(`
CREATE TABLE IF NOT EXISTS hitl_interrupts (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
message_id TEXT,
mode TEXT NOT NULL,
tool_name TEXT NOT NULL,
tool_call_id TEXT,
payload TEXT,
status TEXT NOT NULL,
decision TEXT,
decision_comment TEXT,
created_at DATETIME NOT NULL,
decided_at DATETIME
);`); err != nil {
t.Fatalf("create hitl_interrupts: %v", err)
}
}
func TestDeleteHitlInterruptLogsByIDs_skipsPending(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "hitl.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
ensureHitlInterruptsTable(t, db)
now := time.Now().UTC().Format(time.RFC3339)
if _, err := db.Exec(`INSERT INTO hitl_interrupts
(id, conversation_id, mode, tool_name, status, created_at)
VALUES ('pending-1', 'c1', 'approval', 'exec', 'pending', ?)`, now); err != nil {
t.Fatalf("insert pending: %v", err)
}
if _, err := db.Exec(`INSERT INTO hitl_interrupts
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
VALUES ('done-1', 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, now, now); err != nil {
t.Fatalf("insert decided: %v", err)
}
deleted, err := db.DeleteHitlInterruptLogsByIDs([]string{"pending-1", "done-1"})
if err != nil {
t.Fatalf("DeleteHitlInterruptLogsByIDs: %v", err)
}
if deleted != 1 {
t.Fatalf("deleted = %d, want 1", deleted)
}
var status string
if err := db.QueryRow(`SELECT status FROM hitl_interrupts WHERE id = 'pending-1'`).Scan(&status); err != nil {
t.Fatalf("pending row missing: %v", err)
}
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'done-1'`).Scan(new(string)); err == nil {
t.Fatal("decided row should be deleted")
}
}
func TestPurgeHitlInterruptLogsBefore(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "hitl.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
ensureHitlInterruptsTable(t, db)
old := time.Now().AddDate(0, 0, -100).UTC().Format(time.RFC3339)
recent := time.Now().AddDate(0, 0, -1).UTC().Format(time.RFC3339)
for _, row := range []struct{ id, decided string }{
{"old-1", old},
{"new-1", recent},
} {
if _, err := db.Exec(`INSERT INTO hitl_interrupts
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
VALUES (?, 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, row.id, row.decided, row.decided); err != nil {
t.Fatalf("insert %s: %v", row.id, err)
}
}
cutoff := time.Now().AddDate(0, 0, -90)
deleted, err := db.PurgeHitlInterruptLogsBefore(cutoff)
if err != nil {
t.Fatalf("PurgeHitlInterruptLogsBefore: %v", err)
}
if deleted != 1 {
t.Fatalf("deleted = %d, want 1", deleted)
}
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'old-1'`).Scan(new(string)); err == nil {
t.Fatal("old row should be purged")
}
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'new-1'`).Scan(new(string)); err != nil {
t.Fatalf("new row should remain: %v", err)
}
}
+288 -26
View File
@@ -3,7 +3,6 @@ package database
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"sort"
"strings" "strings"
"time" "time"
@@ -227,6 +226,167 @@ func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolNa
return executions, nil return executions, nil
} }
func toolExecutionsFilterSQL(status, toolName string) (string, []interface{}) {
args := []interface{}{}
conditions := []string{}
if status != "" {
conditions = append(conditions, "status = ?")
args = append(args, status)
}
if toolName != "" {
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
args = append(args, "%"+strings.ToLower(toolName)+"%")
}
if len(conditions) == 0 {
return "", args
}
return ` WHERE ` + strings.Join(conditions, ` AND `), args
}
// ToolStatsSummary 工具调用汇总(全量聚合,不含逐工具明细)
type ToolStatsSummary struct {
TotalCalls int
SuccessCalls int
FailedCalls int
LastCallTime *time.Time
ToolCount int
}
// ToolStatsSummaryResult 汇总 + Top N 工具排行
type ToolStatsSummaryResult struct {
Summary ToolStatsSummary
TopTools []*mcp.ToolStats
}
// LoadToolStatsSummary 聚合统计信息,仅返回汇总与 Top N 工具(避免全量 map 传输)
func (db *DB) LoadToolStatsSummary(topN int) (*ToolStatsSummaryResult, error) {
if topN <= 0 {
topN = 6
}
if topN > 100 {
topN = 100
}
result := &ToolStatsSummaryResult{
TopTools: make([]*mcp.ToolStats, 0, topN),
}
summaryQuery := `
SELECT COUNT(*),
COALESCE(SUM(total_calls), 0),
COALESCE(SUM(success_calls), 0),
COALESCE(SUM(failed_calls), 0),
MAX(last_call_time)
FROM tool_stats
`
var lastCallRaw sql.NullString
err := db.QueryRow(summaryQuery).Scan(
&result.Summary.ToolCount,
&result.Summary.TotalCalls,
&result.Summary.SuccessCalls,
&result.Summary.FailedCalls,
&lastCallRaw,
)
if err != nil {
return nil, err
}
if lastCallRaw.Valid && strings.TrimSpace(lastCallRaw.String) != "" {
if t, parseErr := time.Parse(time.RFC3339Nano, lastCallRaw.String); parseErr == nil {
result.Summary.LastCallTime = &t
} else if t, parseErr := time.Parse("2006-01-02 15:04:05.999999999-07:00", lastCallRaw.String); parseErr == nil {
result.Summary.LastCallTime = &t
} else if t, parseErr := time.Parse("2006-01-02 15:04:05", lastCallRaw.String); parseErr == nil {
result.Summary.LastCallTime = &t
}
}
topQuery := `
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
FROM tool_stats
WHERE total_calls > 0
ORDER BY total_calls DESC, tool_name ASC
LIMIT ?
`
rows, err := db.Query(topQuery, topN)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var stat mcp.ToolStats
var lastCallTime sql.NullTime
if err := rows.Scan(
&stat.ToolName,
&stat.TotalCalls,
&stat.SuccessCalls,
&stat.FailedCalls,
&lastCallTime,
); err != nil {
db.logger.Warn("加载 Top 工具统计失败", zap.Error(err))
continue
}
if lastCallTime.Valid {
stat.LastCallTime = &lastCallTime.Time
}
result.TopTools = append(result.TopTools, &stat)
}
return result, nil
}
// LoadToolExecutionListPage 分页加载执行记录列表(不含 arguments/result,供监控列表使用)
func (db *DB) LoadToolExecutionListPage(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
if limit <= 0 {
limit = 20
}
if limit > 100 {
limit = 100
}
query := `
SELECT id, tool_name, status, start_time, end_time, duration_ms
FROM tool_executions
`
whereSQL, args := toolExecutionsFilterSQL(status, toolName)
query += whereSQL + ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
executions := make([]*mcp.ToolExecution, 0, limit)
for rows.Next() {
var exec mcp.ToolExecution
var endTime sql.NullTime
var durationMs sql.NullInt64
if err := rows.Scan(
&exec.ID,
&exec.ToolName,
&exec.Status,
&exec.StartTime,
&endTime,
&durationMs,
); err != nil {
db.logger.Warn("加载执行记录列表失败", zap.Error(err))
continue
}
if endTime.Valid {
exec.EndTime = &endTime.Time
}
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
executions = append(executions, &exec)
}
return executions, nil
}
// GetToolExecution 根据ID获取单条工具执行记录 // GetToolExecution 根据ID获取单条工具执行记录
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) { func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
query := ` query := `
@@ -288,6 +448,93 @@ func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
return &exec, nil return &exec, nil
} }
// CancelOrphanedRunningToolExecutions 将仍为 running 的记录批量标记为 cancelled(如进程重启后无对应执行协程)。
func (db *DB) CancelOrphanedRunningToolExecutions(endTime time.Time, errMsg string) (int64, error) {
errMsg = strings.TrimSpace(errMsg)
if errMsg == "" {
errMsg = "执行已中断(服务重启或会话结束)"
}
query := `
UPDATE tool_executions
SET status = 'cancelled',
error = ?,
end_time = ?,
duration_ms = MAX(0, CAST((julianday(?) - julianday(start_time)) * 86400000 AS INTEGER))
WHERE status = 'running'
`
res, err := db.Exec(query, errMsg, endTime, endTime)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// FinalizeStaleRunningToolExecutions 将「非活跃且超过 minAge」的 running 记录标记为 cancelled。
// activeIDs 为当前进程内仍登记 cancel 的 executionId;不在集合内且已超时的视为孤儿记录。
func (db *DB) FinalizeStaleRunningToolExecutions(endTime time.Time, minAge time.Duration, activeIDs map[string]struct{}, errMsg string) (int64, error) {
errMsg = strings.TrimSpace(errMsg)
if errMsg == "" {
errMsg = "执行已中断(会话已结束)"
}
if minAge < 0 {
minAge = 0
}
cutoff := endTime.Add(-minAge)
rows, err := db.Query(`
SELECT id, start_time FROM tool_executions
WHERE status = 'running' AND start_time <= ?
`, cutoff)
if err != nil {
return 0, err
}
defer rows.Close()
type staleRow struct {
id string
startTime time.Time
}
var stale []staleRow
for rows.Next() {
var row staleRow
if err := rows.Scan(&row.id, &row.startTime); err != nil {
db.logger.Warn("读取 stale running 执行记录失败", zap.Error(err))
continue
}
if activeIDs != nil {
if _, active := activeIDs[row.id]; active {
continue
}
}
stale = append(stale, row)
}
if err := rows.Err(); err != nil {
return 0, err
}
if len(stale) == 0 {
return 0, nil
}
var affected int64
for _, row := range stale {
durationMs := endTime.Sub(row.startTime).Milliseconds()
if durationMs < 0 {
durationMs = 0
}
res, err := db.Exec(`
UPDATE tool_executions
SET status = 'cancelled', error = ?, end_time = ?, duration_ms = ?
WHERE id = ? AND status = 'running'
`, errMsg, endTime, durationMs, row.id)
if err != nil {
db.logger.Warn("更新 stale running 执行记录失败", zap.Error(err), zap.String("executionId", row.id))
continue
}
n, _ := res.RowsAffected()
affected += n
}
return affected, nil
}
// DeleteToolExecution 删除工具执行记录 // DeleteToolExecution 删除工具执行记录
func (db *DB) DeleteToolExecution(id string) error { func (db *DB) DeleteToolExecution(id string) error {
query := `DELETE FROM tool_executions WHERE id = ?` query := `DELETE FROM tool_executions WHERE id = ?`
@@ -600,13 +847,28 @@ func truncateCallsTimelineBucket(t time.Time, dailyBuckets bool) time.Time {
// LoadCallsTimeline 按时间范围加载调用趋势(since 起至今,含边界) // LoadCallsTimeline 按时间范围加载调用趋势(since 起至今,含边界)
func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTimelineBucket, error) { func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTimelineBucket, error) {
// 在 Go 侧按本地时区分桶,避免 SQLite strftime 对 UTC 存储时间分桶后再误当本地时间解析(差 8h 等问题) var query string
query := ` if dailyBuckets {
SELECT start_time, query = `
CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END AS failed SELECT date(start_time, 'localtime') AS bucket,
FROM tool_executions COUNT(*) AS total,
WHERE start_time >= ? SUM(CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END) AS failed
` FROM tool_executions
WHERE start_time >= ?
GROUP BY bucket
ORDER BY bucket
`
} else {
query = `
SELECT strftime('%Y-%m-%d %H:00:00', start_time, 'localtime') AS bucket,
COUNT(*) AS total,
SUM(CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END) AS failed
FROM tool_executions
WHERE start_time >= ?
GROUP BY bucket
ORDER BY bucket
`
}
rows, err := db.Query(query, since) rows, err := db.Query(query, since)
if err != nil { if err != nil {
@@ -614,35 +876,35 @@ func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTime
} }
defer rows.Close() defer rows.Close()
bucketMap := make(map[time.Time]struct{ total, failed int }) buckets := make([]CallsTimelineBucket, 0)
for rows.Next() { for rows.Next() {
var startTime time.Time var bucketStr string
var failed int var total, failed int
if err := rows.Scan(&startTime, &failed); err != nil { if err := rows.Scan(&bucketStr, &total, &failed); err != nil {
db.logger.Warn("加载调用趋势失败", zap.Error(err)) db.logger.Warn("加载调用趋势失败", zap.Error(err))
continue continue
} }
key := truncateCallsTimelineBucket(startTime, dailyBuckets) bucketTime, err := parseCallsTimelineBucket(bucketStr, dailyBuckets)
entry := bucketMap[key] if err != nil {
entry.total++ db.logger.Warn("解析调用趋势时间桶失败", zap.Error(err), zap.String("bucket", bucketStr))
entry.failed += failed continue
bucketMap[key] = entry }
}
buckets := make([]CallsTimelineBucket, 0, len(bucketMap))
for bucketTime, counts := range bucketMap {
buckets = append(buckets, CallsTimelineBucket{ buckets = append(buckets, CallsTimelineBucket{
BucketTime: bucketTime, BucketTime: bucketTime,
Total: counts.total, Total: total,
Failed: counts.failed, Failed: failed,
}) })
} }
sort.Slice(buckets, func(i, j int) bool {
return buckets[i].BucketTime.Before(buckets[j].BucketTime)
})
return buckets, nil return buckets, nil
} }
func parseCallsTimelineBucket(bucketStr string, dailyBuckets bool) (time.Time, error) {
if dailyBuckets {
return time.ParseInLocation("2006-01-02", bucketStr, time.Local)
}
return time.ParseInLocation("2006-01-02 15:04:05", bucketStr, time.Local)
}
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时) // DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
// 如果统计信息变为0,则删除该统计记录 // 如果统计信息变为0,则删除该统计记录
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error { func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
+102
View File
@@ -0,0 +1,102 @@
package database
import (
"path/filepath"
"testing"
"time"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
func TestCancelOrphanedRunningToolExecutions(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
start := time.Now().Add(-2 * time.Hour)
exec := &mcp.ToolExecution{
ID: "orphan-hydra",
ToolName: "hydra",
Arguments: map[string]interface{}{"target": "127.0.0.1"},
Status: "running",
StartTime: start,
}
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
end := time.Now()
n, err := db.CancelOrphanedRunningToolExecutions(end, "执行已中断(服务重启)")
if err != nil {
t.Fatalf("CancelOrphanedRunningToolExecutions: %v", err)
}
if n != 1 {
t.Fatalf("expected 1 row updated, got %d", n)
}
got, err := db.GetToolExecution("orphan-hydra")
if err != nil {
t.Fatalf("GetToolExecution: %v", err)
}
if got.Status != "cancelled" {
t.Fatalf("expected cancelled, got %s", got.Status)
}
if got.EndTime == nil {
t.Fatal("expected end_time to be set")
}
if got.Duration <= 0 {
t.Fatalf("expected positive duration, got %v", got.Duration)
}
}
func TestFinalizeStaleRunningToolExecutions_skipsActive(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
now := time.Now()
oldStart := now.Add(-5 * time.Minute)
if err := db.SaveToolExecution(&mcp.ToolExecution{
ID: "stale", ToolName: "hydra", Status: "running", StartTime: oldStart,
}); err != nil {
t.Fatalf("SaveToolExecution stale: %v", err)
}
if err := db.SaveToolExecution(&mcp.ToolExecution{
ID: "active", ToolName: "hydra", Status: "running", StartTime: oldStart,
}); err != nil {
t.Fatalf("SaveToolExecution active: %v", err)
}
active := map[string]struct{}{"active": {}}
n, err := db.FinalizeStaleRunningToolExecutions(now, time.Minute, active, "执行已中断(会话已结束)")
if err != nil {
t.Fatalf("FinalizeStaleRunningToolExecutions: %v", err)
}
if n != 1 {
t.Fatalf("expected 1 stale row updated, got %d", n)
}
stale, err := db.GetToolExecution("stale")
if err != nil {
t.Fatalf("GetToolExecution stale: %v", err)
}
if stale.Status != "cancelled" {
t.Fatalf("stale expected cancelled, got %s", stale.Status)
}
activeExec, err := db.GetToolExecution("active")
if err != nil {
t.Fatalf("GetToolExecution active: %v", err)
}
if activeExec.Status != "running" {
t.Fatalf("active expected running, got %s", activeExec.Status)
}
}
+86
View File
@@ -0,0 +1,86 @@
package database
import (
"fmt"
"path/filepath"
"testing"
"time"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
func TestLoadToolStatsSummaryAndListPage(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor-summary.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
now := time.Now()
tools := []struct {
name string
calls int
ok int
fail int
result string
}{
{"alpha::run", 10, 9, 1, `{"content":[{"type":"text","text":"` + string(make([]byte, 64*1024)) + `"}]}`},
{"beta::scan", 5, 5, 0, `{"content":[{"type":"text","text":"ok"}]}`},
{"gamma::ping", 1, 1, 0, `{"content":[{"type":"text","text":"pong"}]}`},
}
for _, tool := range tools {
if err := db.UpdateToolStats(tool.name, tool.calls, tool.ok, tool.fail, &now); err != nil {
t.Fatalf("UpdateToolStats(%s): %v", tool.name, err)
}
for j := 0; j < tool.calls; j++ {
exec := &mcp.ToolExecution{
ID: fmt.Sprintf("%s-exec-%d", tool.name, j),
ToolName: tool.name,
Arguments: map[string]interface{}{"n": j},
Status: "completed",
StartTime: now.Add(-time.Duration(j) * time.Minute),
Result: &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: tool.result}}},
}
end := exec.StartTime.Add(time.Second)
exec.EndTime = &end
exec.Duration = time.Second
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
}
}
summary, err := db.LoadToolStatsSummary(2)
if err != nil {
t.Fatalf("LoadToolStatsSummary: %v", err)
}
if summary.Summary.ToolCount != 3 {
t.Fatalf("toolCount = %d, want 3", summary.Summary.ToolCount)
}
if summary.Summary.TotalCalls != 16 {
t.Fatalf("totalCalls = %d, want 16", summary.Summary.TotalCalls)
}
if len(summary.TopTools) != 2 {
t.Fatalf("top tools = %d, want 2", len(summary.TopTools))
}
if summary.TopTools[0].ToolName != "alpha::run" {
t.Fatalf("top tool = %q, want alpha::run", summary.TopTools[0].ToolName)
}
list, err := db.LoadToolExecutionListPage(0, 5, "", "")
if err != nil {
t.Fatalf("LoadToolExecutionListPage: %v", err)
}
if len(list) != 5 {
t.Fatalf("list len = %d, want 5", len(list))
}
for _, exec := range list {
if exec.Arguments != nil || exec.Result != nil || exec.Error != "" {
t.Fatalf("expected lite execution row, got args/result/error on %s", exec.ID)
}
}
}
+2 -2
View File
@@ -2,8 +2,8 @@ package einomcp
import "sync" import "sync"
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP/execute 桥在工具调用结束时 Fire, // ToolInvokeNotifyHolder 由 Eino run loop 与 MCP/execute 桥共享;Fire 在工具原始返回时触发。
// 用于清除 pending tool_calltool_result ADK schema.Tool 事件推送,含流式工具与 reduction 后正文)。 // UI 的 tool_result 须等 ADK schema.Tool 事件reduction 后正文),不在此 holder 的回调里推送
type ToolInvokeNotifyHolder struct { type ToolInvokeNotifyHolder struct {
mu sync.RWMutex mu sync.RWMutex
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
+190 -401
View File
@@ -21,7 +21,6 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
@@ -78,6 +77,13 @@ type responsePlanAgg struct {
b strings.Builder b strings.Builder
} }
// thinkingBuf aggregates thinking_stream_* / reasoning_chain_stream_* before flush to process_details.
type thinkingBuf struct {
b strings.Builder
meta map[string]interface{}
persistAs string // "thinking" | "reasoning_chain"
}
func normalizeProcessDetailText(s string) string { func normalizeProcessDetailText(s string) string {
s = strings.ReplaceAll(s, "\r\n", "\n") s = strings.ReplaceAll(s, "\r\n", "\n")
s = strings.ReplaceAll(s, "\r", "\n") s = strings.ReplaceAll(s, "\r", "\n")
@@ -178,10 +184,10 @@ type AgentHandler struct {
} }
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并) agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
batchCronParser cron.Parser batchCronParser cron.Parser
batchRunnerMu sync.Mutex
batchRunning map[string]struct{}
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
hitlWhitelistSaver HitlToolWhitelistSaver hitlWhitelistSaver HitlToolWhitelistSaver
hitlStrategySaver HitlAuditStrategySaver
auditLLM *openai.Client
audit *audit.Service audit *audit.Service
} }
@@ -190,14 +196,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
h.audit = s h.audit = s
} }
// TaskManager 返回 Agent 任务管理器(供 MCP 监控页终止 Eino execute 等)。
func (h *AgentHandler) TaskManager() *AgentTaskManager {
if h == nil {
return nil
}
return h.tasks
}
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent). // CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) { func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
if h == nil || conversationID == "" || h.tasks == nil { if h == nil || conversationID == "" || h.tasks == nil {
return return
} }
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" { h.cancelActiveMCPToolForConversation(conversationID)
h.agent.CancelMCPToolExecutionWithNote(execID, "") h.tasks.AbortActiveEinoExecute(conversationID, "")
}
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok { if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID)) h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
} else if err != nil { } else if err != nil {
@@ -205,9 +218,19 @@ func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
} }
} }
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 func (h *AgentHandler) cancelActiveMCPToolForConversation(conversationID string) {
if h == nil || h.tasks == nil || h.agent == nil {
return
}
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
h.agent.CancelMCPToolExecutionWithNote(execID, "")
}
}
// HitlToolWhitelistSaver 合并/设置 HITL 免审批工具到全局配置并落盘
type HitlToolWhitelistSaver interface { type HitlToolWhitelistSaver interface {
MergeHitlToolWhitelistIntoConfig(add []string) error MergeHitlToolWhitelistIntoConfig(add []string) error
SetHitlToolWhitelist(tools []string) error
} }
// NewAgentHandler 创建新的Agent处理器 // NewAgentHandler 创建新的Agent处理器
@@ -223,6 +246,11 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
bus := NewTaskEventBus() bus := NewTaskEventBus()
tm := NewAgentTaskManager() tm := NewAgentTaskManager()
tm.SetTaskEventBus(bus) tm.SetTaskEventBus(bus)
llmHTTP := &http.Client{Timeout: 2 * time.Minute}
var llmCfg *config.OpenAIConfig
if cfg != nil {
llmCfg = &cfg.OpenAI
}
handler := &AgentHandler{ handler := &AgentHandler{
agent: agent, agent: agent,
db: db, db: db,
@@ -233,8 +261,9 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
config: cfg, config: cfg,
hitlManager: NewHITLManager(db, logger), hitlManager: NewHITLManager(db, logger),
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
batchRunning: make(map[string]struct{}), auditLLM: openai.NewClient(llmCfg, llmHTTP, logger),
} }
tm.SetToolCanceler(handler.cancelActiveMCPToolForConversation)
if err := handler.hitlManager.EnsureSchema(); err != nil { if err := handler.hitlManager.EnsureSchema(); err != nil {
logger.Warn("初始化 HITL 表失败", zap.Error(err)) logger.Warn("初始化 HITL 表失败", zap.Error(err))
} }
@@ -307,6 +336,7 @@ func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientInten
type HITLRequest struct { type HITLRequest struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Mode string `json:"mode,omitempty"` Mode string `json:"mode,omitempty"`
Reviewer string `json:"reviewer,omitempty"` // human | audit_agent
SensitiveTools []string `json:"sensitiveTools,omitempty"` SensitiveTools []string `json:"sensitiveTools,omitempty"`
TimeoutSeconds int `json:"timeoutSeconds,omitempty"` TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
} }
@@ -648,7 +678,7 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
) (string, string, error) { ) (string, string, error) {
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent( resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID), conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.agentSessionContextBlock(conversationID),
) )
if errMA != nil { if errMA != nil {
*taskStatus = "failed" *taskStatus = "failed"
@@ -669,7 +699,7 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
resultMA, errMA := multiagent.RunDeepAgent( resultMA, errMA := multiagent.RunDeepAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID), h.agentsMarkdownDir, orchestration, nil, h.agentSessionContextBlock(conversationID),
) )
if errMA != nil { if errMA != nil {
*taskStatus = "failed" *taskStatus = "failed"
@@ -836,11 +866,6 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
// thinking_stream_*ReAct 等助手正文流)与 reasoning_chain_stream_*Eino ReasoningContent): // thinking_stream_*ReAct 等助手正文流)与 reasoning_chain_stream_*Eino ReasoningContent):
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。 // 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
type thinkingBuf struct {
b strings.Builder
meta map[string]interface{}
persistAs string // "thinking" | "reasoning_chain"
}
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
flushedThinking := make(map[string]bool) // streamId -> flushed flushedThinking := make(map[string]bool) // streamId -> flushed
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
@@ -853,6 +878,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta // response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。 // 聚合为一条 planning 写入 process_details,刷新后与线上一致。
var respPlan responsePlanAgg var respPlan responsePlanAgg
if assistantMessageID != "" {
h.tasks.SetHitlAssistantMessageID(conversationID, assistantMessageID)
}
syncHitlCognition := func() {
h.syncHitlCognitionFromProgress(conversationID, assistantMessageID, thinkingStreams, &respPlan)
}
flushResponsePlan := func() { flushResponsePlan := func() {
if assistantMessageID == "" { if assistantMessageID == "" {
return return
@@ -872,6 +903,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil { if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil {
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning")) h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning"))
} }
syncHitlCognition()
respPlan.meta = nil respPlan.meta = nil
respPlan.b.Reset() respPlan.b.Reset()
} }
@@ -908,6 +940,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
flushedThinking[sid] = true flushedThinking[sid] = true
} }
syncHitlCognition()
} }
return func(eventType, message string, data interface{}) { return func(eventType, message string, data interface{}) {
@@ -968,6 +1001,25 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
} }
if eventType == "tool_result" {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
toolCallID, _ := dataMap["toolCallId"].(string)
success := true
if v, ok := dataMap["success"].(bool); ok {
success = v
}
resultText := ""
if r, ok := dataMap["result"].(string); ok {
resultText = r
}
if strings.TrimSpace(resultText) == "" {
resultText = message
}
h.recordHitlToolExecutionResult(conversationID, toolCallID, toolName, success, resultText)
}
}
// 处理知识检索日志记录 // 处理知识检索日志记录
if eventType == "tool_result" && h.knowledgeManager != nil { if eventType == "tool_result" && h.knowledgeManager != nil {
if dataMap, ok := data.(map[string]interface{}); ok { if dataMap, ok := data.(map[string]interface{}); ok {
@@ -1175,6 +1227,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
respPlan.meta[k] = v respPlan.meta[k] = v
} }
} }
syncHitlCognition()
return return
} }
if eventType == "response" { if eventType == "response" {
@@ -1244,6 +1297,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
} }
} }
syncHitlCognition()
return return
} }
@@ -1295,10 +1349,60 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
} }
// cancelToolContinueAfter 仅终止当前工具调用,不停止整条 Agent 任务(对话「中断并继续」与 MCP 监控终止共用)。
func (h *AgentHandler) cancelToolContinueAfter(conversationID, preferredExecID, note string) (bool, gin.H) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" || h.tasks.GetTask(conversationID) == nil {
return false, nil
}
note = strings.TrimSpace(note)
execID := strings.TrimSpace(preferredExecID)
if execID == "" {
execID = h.tasks.ActiveMCPExecutionID(conversationID)
}
if execID != "" {
if h.agent.CancelMCPToolExecutionWithNote(execID, note) {
return true, gin.H{
"status": "tool_abort_requested",
"conversationId": conversationID,
"executionId": execID,
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
}
}
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
return true, gin.H{
"status": "tool_abort_requested",
"conversationId": conversationID,
"executionId": execID,
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
}
}
return false, nil
}
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
return true, gin.H{
"status": "tool_abort_requested",
"conversationId": conversationID,
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
}
}
return false, nil
}
// CancelAgentLoop 取消正在执行的任务 // CancelAgentLoop 取消正在执行的任务
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
var req struct { var req struct {
ConversationID string `json:"conversationId" binding:"required"` ConversationID string `json:"conversationId" binding:"required"`
ExecutionID string `json:"executionId,omitempty"`
Reason string `json:"reason,omitempty"` Reason string `json:"reason,omitempty"`
ContinueAfter bool `json:"continueAfter,omitempty"` ContinueAfter bool `json:"continueAfter,omitempty"`
} }
@@ -1313,42 +1417,20 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
return return
} }
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
note := strings.TrimSpace(req.Reason) note := strings.TrimSpace(req.Reason)
if execID != "" { activeExec := strings.TrimSpace(h.tasks.ActiveMCPExecutionID(req.ConversationID))
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) { if ok, payload := h.cancelToolContinueAfter(req.ConversationID, strings.TrimSpace(req.ExecutionID), note); ok {
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"}) execID, _ := payload["executionId"].(string)
return h.logger.Info("对话页仅终止当前工具",
}
h.logger.Info("对话页仅终止当前 MCP 工具",
zap.String("conversationId", req.ConversationID), zap.String("conversationId", req.ConversationID),
zap.String("executionId", execID), zap.String("executionId", execID),
zap.Bool("hasNote", note != ""), zap.Bool("hasNote", note != ""),
) )
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, payload)
"status": "tool_abort_requested",
"conversationId": req.ConversationID,
"executionId": execID,
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
})
return return
} }
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) { if activeExec != "" {
h.logger.Info("对话页仅终止当前 Eino execute", c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
zap.String("conversationId", req.ConversationID),
zap.Bool("hasNote", note != ""),
)
c.JSON(http.StatusOK, gin.H{
"status": "tool_abort_requested",
"conversationId": req.ConversationID,
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
})
return return
} }
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。 // 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
@@ -1380,6 +1462,8 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
var cause error = ErrTaskCancelled var cause error = ErrTaskCancelled
msg := "已提交取消请求,任务将在当前步骤完成后停止。" msg := "已提交取消请求,任务将在当前步骤完成后停止。"
h.cancelActiveMCPToolForConversation(req.ConversationID)
h.tasks.AbortActiveEinoExecute(req.ConversationID, "")
ok, err := h.tasks.CancelTask(req.ConversationID, cause) ok, err := h.tasks.CancelTask(req.ConversationID, cause)
if err != nil { if err != nil {
h.logger.Error("取消任务失败", zap.Error(err)) h.logger.Error("取消任务失败", zap.Error(err))
@@ -1446,17 +1530,51 @@ func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) {
} }
} }
// enrichAgentTasksWithConversationTitles 为任务列表附加当前会话标题(供顶栏/任务页展示,重命名后自动同步)
func (h *AgentHandler) enrichAgentTasksWithConversationTitles(tasks []*AgentTask) {
if h == nil || h.db == nil {
return
}
for _, task := range tasks {
if task == nil || strings.TrimSpace(task.ConversationID) == "" {
continue
}
if title, err := h.db.GetConversationTitle(task.ConversationID); err == nil {
task.Title = strings.TrimSpace(title)
}
}
}
// enrichCompletedTasksWithConversationTitles 为已完成任务附加当前会话标题
func (h *AgentHandler) enrichCompletedTasksWithConversationTitles(tasks []*CompletedTask) {
if h == nil || h.db == nil {
return
}
for _, task := range tasks {
if task == nil || strings.TrimSpace(task.ConversationID) == "" {
continue
}
if title, err := h.db.GetConversationTitle(task.ConversationID); err == nil {
task.Title = strings.TrimSpace(title)
}
}
}
// ListAgentTasks 列出所有运行中的任务 // ListAgentTasks 列出所有运行中的任务
func (h *AgentHandler) ListAgentTasks(c *gin.Context) { func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
tasks := h.tasks.GetActiveTasks()
h.enrichAgentTasksWithConversationTitles(tasks)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"tasks": h.tasks.GetActiveTasks(), "tasks": tasks,
}) })
} }
// ListCompletedTasks 列出最近完成的任务历史 // ListCompletedTasks 列出最近完成的任务历史
func (h *AgentHandler) ListCompletedTasks(c *gin.Context) { func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
tasks := h.tasks.GetCompletedTasks()
h.enrichCompletedTasksWithConversationTitles(tasks)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"tasks": h.tasks.GetCompletedTasks(), "tasks": tasks,
}) })
} }
@@ -1470,6 +1588,7 @@ type BatchTaskRequest struct {
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选) ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8
} }
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。 // batchQueueWantsEino 队列是否配置为走 Eino 多代理。
@@ -1529,7 +1648,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
nextRunAt = &next nextRunAt = &next
} }
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks) queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks)
if createErr != nil { if createErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
return return
@@ -1719,15 +1838,16 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) { func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
queueID := c.Param("queueId") queueID := c.Param("queueId")
var req struct { var req struct {
Title string `json:"title"` Title string `json:"title"`
Role string `json:"role"` Role string `json:"role"`
AgentMode string `json:"agentMode"` AgentMode string `json:"agentMode"`
Concurrency *int `json:"concurrency"`
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil { if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@@ -1802,9 +1922,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
// DeleteBatchQueue 删除批量任务队列 // DeleteBatchQueue 删除批量任务队列
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) { func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
queueID := c.Param("queueId") queueID := c.Param("queueId")
success := h.batchTaskManager.DeleteQueue(queueID) if err := h.batchTaskManager.DeleteQueue(queueID); err != nil {
if !success { switch {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) case errors.Is(err, ErrBatchQueueNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
case errors.Is(err, ErrBatchQueueExecutorActive):
c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"})
case errors.Is(err, ErrBatchQueueStillRunning):
c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return return
} }
if h.audit != nil { if h.audit != nil {
@@ -1898,7 +2026,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动 // 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused { if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
h.forceUnmarkBatchQueueRunning(queueID) h.batchTaskManager.ForceUnmarkQueueExecutor(queueID)
} }
autoStarted := true autoStarted := true
@@ -1957,26 +2085,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
} }
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
h.batchRunnerMu.Lock()
defer h.batchRunnerMu.Unlock()
if _, exists := h.batchRunning[queueID]; exists {
return false
}
h.batchRunning[queueID] = struct{}{}
return true
}
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
h.batchRunnerMu.Lock()
defer h.batchRunnerMu.Unlock()
delete(h.batchRunning, queueID)
}
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
h.unmarkBatchQueueRunning(queueID)
}
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
expr := strings.TrimSpace(cronExpr) expr := strings.TrimSpace(cronExpr)
if expr == "" { if expr == "" {
@@ -1992,43 +2100,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断 // 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
if !h.markBatchQueueRunning(queueID) { if !h.batchTaskManager.TryMarkQueueExecutor(queueID) {
return true, nil return true, nil
} }
queue, exists := h.batchTaskManager.GetBatchQueue(queueID) queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists { if !exists {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
return false, nil return false, nil
} }
if scheduled { if scheduled {
if queue.ScheduleMode != "cron" { if queue.ScheduleMode != "cron" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("队列未启用 cron 调度") err := fmt.Errorf("队列未启用 cron 调度")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" { if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("当前队列状态不允许被调度执行") err := fmt.Errorf("当前队列状态不允许被调度执行")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
if !h.batchTaskManager.ResetQueueForRerun(queueID) { if !h.batchTaskManager.ResetQueueForRerun(queueID) {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("重置队列失败") err := fmt.Errorf("重置队列失败")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
queue, _ = h.batchTaskManager.GetBatchQueue(queueID) queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
} else if queue.Status != "pending" && queue.Status != "paused" { } else if queue.Status != "pending" && queue.Status != "paused" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
return true, fmt.Errorf("队列状态不允许启动") return true, fmt.Errorf("队列状态不允许启动")
} }
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) { if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理") err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
if scheduled { if scheduled {
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
@@ -2080,325 +2188,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() {
} }
} }
// executeBatchQueue 执行批量任务队列
func (h *AgentHandler) executeBatchQueue(queueID string) {
defer h.unmarkBatchQueueRunning(queueID)
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
for {
// 检查队列状态
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
break
}
// 获取下一个任务
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
if !hasNext {
// 所有任务完成:汇总子任务失败信息便于排障
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
lastRunErr := ""
if ok {
for _, t := range q.Tasks {
if t.Status == "failed" && t.Error != "" {
lastRunErr = t.Error
}
}
}
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
break
}
// 更新任务状态为运行中
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
// 创建新对话
title := safeTruncateString(task.Message, 50)
batchMeta := audit.ConversationCreateMeta("batch_task")
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
conv, err := h.db.CreateConversation(title, batchMeta)
var conversationID string
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
break
}
continue
}
conversationID = conv.ID
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
// 应用角色用户提示词和工具配置
finalMessage := task.Message
var roleTools []string // 角色配置的工具列表
if queue.Role != "" && queue.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + task.Message
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
// 预先创建助手消息,以便关联过程详情
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
// 如果创建失败,继续执行但不保存过程详情
assistantMsg = nil
}
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
// 需要把进度事件镜像到 TaskEventBusGET /api/agent-loop/task-events 会订阅这里)。
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
var progressCallback func(eventType, message string, data interface{})
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
func() {
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
registered := false
finishStatus := "completed"
defer func() {
h.batchTaskManager.SetTaskCancel(queueID, nil)
timeoutCancel()
if registered {
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
if h.taskEventBus != nil {
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
if b, err := json.Marshal(ev); err == nil {
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
}
}
h.tasks.FinishTask(conversationID, finishStatus)
}
cancelWithCause(nil)
}()
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
sendEvent := func(eventType, message string, data interface{}) {
if h.taskEventBus == nil {
return
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, err := json.Marshal(ev)
if err != nil {
b = []byte(`{"type":"error","message":"marshal failed"}`)
}
line := make([]byte, 0, len(b)+8)
line = append(line, []byte("data: ")...)
line = append(line, b...)
line = append(line, '\n', '\n')
h.taskEventBus.Publish(conversationID, line)
}
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
h.logger.Warn("批量队列子任务注册会话运行状态失败",
zap.String("queueId", queueID),
zap.String("taskId", task.ID),
zap.String("conversationId", conversationID),
zap.Error(err))
failMsg := err.Error()
if errors.Is(err, ErrTaskAlreadyRunning) {
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
return
}
registered = true
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
useBatchMulti := false
batchOrch := "deep"
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
if am == "multi" {
am = "deep"
}
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
useBatchMulti = true
batchOrch = config.NormalizeMultiAgentOrchestration(am)
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
useBatchMulti = true
batchOrch = "deep"
}
var resultMA *multiagent.RunResult
var runErr error
switch {
case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
default:
if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载")
} else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
}
}
if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errStr := runErr.Error()
partialResp := ""
if resultMA != nil {
partialResp = resultMA.Response
}
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
errors.Is(runErr, context.Canceled) ||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
if isTimeout {
finishStatus = "timeout"
} else if isCancelled {
finishStatus = "cancelled"
} else {
finishStatus = "failed"
}
if isCancelled {
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
cancelMsg := "任务已被用户取消,后续操作已停止。"
// 如果执行结果中有更具体的取消消息,使用它
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
cancelMsg = partialResp
}
// 更新助手消息内容
if assistantMessageID != "" {
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存取消详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
if errMsg != nil {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
} else {
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
errorMsg := "执行失败: " + runErr.Error()
// 更新助手消息内容
if assistantMessageID != "" {
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
errorMsg,
time.Now(), assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存错误详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error())
}
} else {
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
resText := resultMA.Response
mcpIDs := resultMA.MCPExecutionIDs
lastIn := resultMA.LastAgentTraceInput
lastOut := resultMA.LastAgentTraceOutput
// 更新助手消息内容
if assistantMessageID != "" {
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
// 如果更新失败,尝试创建新消息
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
// 保存代理轨迹
if lastIn != "" || lastOut != "" {
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else {
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
}
}
// 保存结果
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
}
}()
// 移动到下一个任务
h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
break
}
// 检查是否被取消或暂停
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue.Status == "cancelled" || queue.Status == "paused" {
break
}
}
}
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。 // loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。 // 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) { func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
+352
View File
@@ -0,0 +1,352 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/multiagent"
"go.uber.org/zap"
)
const batchQueueWorkerIdlePoll = 200 * time.Millisecond
// executeBatchQueue 使用并发 worker 池执行批量任务队列。
func (h *AgentHandler) executeBatchQueue(queueID string) {
defer h.batchTaskManager.UnmarkQueueExecutor(queueID)
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
return
}
concurrency := normalizeBatchQueueConcurrency(queue.Concurrency)
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", concurrency))
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
h.runBatchQueueWorker(queueID)
}()
}
wg.Wait()
h.tryFinalizeBatchQueue(queueID)
}
func (h *AgentHandler) runBatchQueueWorker(queueID string) {
for {
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if batchQueueExecutionShouldStop(queue, exists) {
return
}
task, ok := h.batchTaskManager.ClaimNextPendingTask(queueID)
if !ok {
if !h.batchTaskManager.HasRunningTasks(queueID) {
return
}
time.Sleep(batchQueueWorkerIdlePoll)
continue
}
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue == nil {
return
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusRunning, "", "")
h.executeOneBatchSubTask(queueID, queue, task)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused)
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
return
}
queue, exists = h.batchTaskManager.GetBatchQueue(queueID)
if batchQueueExecutionShouldStop(queue, exists) {
if !exists {
h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID))
}
return
}
}
}
func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) {
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists || queue == nil {
return
}
if queue.Status != BatchQueueStatusRunning {
return
}
if h.batchTaskManager.HasPendingOrRunningTasks(queueID) {
return
}
lastRunErr := ""
for _, t := range queue.Tasks {
if t != nil && t.Status == BatchTaskStatusFailed && t.Error != "" {
lastRunErr = t.Error
}
}
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted)
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
}
// executeOneBatchSubTask 执行单条批量子任务(各自独立会话)。
func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) {
title := safeTruncateString(task.Message, 50)
batchMeta := audit.ConversationCreateMeta("batch_task")
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
conv, err := h.db.CreateConversation(title, batchMeta)
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error())
return
}
conversationID := conv.ID
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID)
finalMessage := task.Message
var roleTools []string
if queue.Role != "" && queue.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + task.Message
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
}
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil {
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
assistantMsg = nil
}
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
registered := false
finishStatus := "completed"
defer func() {
h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil)
timeoutCancel()
if registered {
if h.taskEventBus != nil {
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
if b, err := json.Marshal(ev); err == nil {
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
}
}
h.tasks.FinishTask(conversationID, finishStatus)
}
cancelWithCause(nil)
}()
sendEvent := func(eventType, message string, data interface{}) {
if h.taskEventBus == nil {
return
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, err := json.Marshal(ev)
if err != nil {
b = []byte(`{"type":"error","message":"marshal failed"}`)
}
line := make([]byte, 0, len(b)+8)
line = append(line, []byte("data: ")...)
line = append(line, b...)
line = append(line, '\n', '\n')
h.taskEventBus.Publish(conversationID, line)
}
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
h.logger.Warn("批量队列子任务注册会话运行状态失败",
zap.String("queueId", queueID),
zap.String("taskId", task.ID),
zap.String("conversationId", conversationID),
zap.Error(err))
failMsg := err.Error()
if errors.Is(err, ErrTaskAlreadyRunning) {
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg)
return
}
registered = true
h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel)
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
useBatchMulti := false
batchOrch := "deep"
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
if am == "multi" {
am = "deep"
}
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
useBatchMulti = true
batchOrch = config.NormalizeMultiAgentOrchestration(am)
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
useBatchMulti = true
batchOrch = "deep"
}
var resultMA *multiagent.RunResult
var runErr error
switch {
case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.agentSessionContextBlock(conversationID))
default:
if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载")
} else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.agentSessionContextBlock(conversationID))
}
}
if runErr != nil {
h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus)
return
}
if resultMA == nil {
h.logger.Error("批量任务执行成功但无结果对象",
zap.String("queueId", queueID),
zap.String("taskId", task.ID),
zap.String("conversationId", conversationID))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果")
return
}
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
resText := resultMA.Response
mcpIDs := resultMA.MCPExecutionIDs
lastIn := resultMA.LastAgentTraceInput
lastOut := resultMA.LastAgentTraceOutput
if assistantMessageID != "" {
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
} else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
if lastIn != "" || lastOut != "" {
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID)
}
func (h *AgentHandler) handleBatchSubTaskRunError(
queueID string,
task *BatchTask,
conversationID, assistantMessageID string,
baseCtx, taskCtx context.Context,
resultMA *multiagent.RunResult,
runErr error,
finishStatus *string,
) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errStr := runErr.Error()
partialResp := ""
if resultMA != nil {
partialResp = resultMA.Response
}
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
errors.Is(runErr, context.Canceled) ||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
if isTimeout {
*finishStatus = "timeout"
} else if isCancelled {
*finishStatus = "cancelled"
} else {
*finishStatus = "failed"
}
if isCancelled {
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
cancelMsg := "任务已被用户取消,后续操作已停止。"
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
cancelMsg = partialResp
}
if assistantMessageID != "" {
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else if _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); errMsg != nil {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID)
return
}
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
errorMsg := "执行失败: " + runErr.Error()
if assistantMessageID != "" {
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
errorMsg,
time.Now(), assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error())
}
+216 -43
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
@@ -17,6 +18,15 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
var (
// ErrBatchQueueNotFound 队列不存在或已从内存卸载。
ErrBatchQueueNotFound = errors.New("batch queue not found")
// ErrBatchQueueExecutorActive executeBatchQueue 协程仍在收尾,禁止删除。
ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active")
// ErrBatchQueueStillRunning 队列状态仍为 running(无活跃执行器时的兜底保护)。
ErrBatchQueueStillRunning = errors.New("batch queue is still running")
)
// 批量任务状态常量 // 批量任务状态常量
const ( const (
BatchQueueStatusPending = "pending" BatchQueueStatusPending = "pending"
@@ -39,6 +49,12 @@ const (
// MaxBatchQueueRoleLen 角色名最大长度 // MaxBatchQueueRoleLen 角色名最大长度
MaxBatchQueueRoleLen = 100 MaxBatchQueueRoleLen = 100
// DefaultBatchQueueConcurrency 批量队列默认并发数(串行)
DefaultBatchQueueConcurrency = 1
// MaxBatchQueueConcurrency 批量队列最大并发数
MaxBatchQueueConcurrency = 8
) )
// BatchTask 批量任务项 // BatchTask 批量任务项
@@ -67,6 +83,7 @@ type BatchTaskQueue struct {
LastScheduleError string `json:"lastScheduleError,omitempty"` LastScheduleError string `json:"lastScheduleError,omitempty"`
LastRunError string `json:"lastRunError,omitempty"` LastRunError string `json:"lastRunError,omitempty"`
ProjectID string `json:"projectId,omitempty"` ProjectID string `json:"projectId,omitempty"`
Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1
Tasks []*BatchTask `json:"tasks"` Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
@@ -80,8 +97,9 @@ type BatchTaskManager struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
queues map[string]*BatchTaskQueue queues map[string]*BatchTaskQueue
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列 singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦)
mu sync.RWMutex mu sync.RWMutex
} }
@@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
return &BatchTaskManager{ return &BatchTaskManager{
logger: logger, logger: logger,
queues: make(map[string]*BatchTaskQueue), queues: make(map[string]*BatchTaskQueue),
taskCancels: make(map[string]context.CancelFunc), taskCancels: make(map[string]map[string]context.CancelFunc),
singleRunTasks: make(map[string]string), singleRunTasks: make(map[string]string),
queueExecutors: make(map[string]struct{}),
} }
} }
// batchQueueExecutionShouldStop 判断 executeBatchQueue 主循环是否应退出。
func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool {
if !exists || queue == nil {
return true
}
switch queue.Status {
case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused:
return true
default:
return false
}
}
// TryMarkQueueExecutor 标记队列执行协程已启动;若已有执行协程则返回 false。
func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.queueExecutors[queueID]; exists {
return false
}
m.queueExecutors[queueID] = struct{}{}
return true
}
// UnmarkQueueExecutor 清除队列执行协程标记(executeBatchQueue defer 调用)。
func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.queueExecutors, queueID)
}
// ForceUnmarkQueueExecutor 强制清除执行协程标记(暂停态单条重跑等场景回收陈旧槽位)。
func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) {
m.UnmarkQueueExecutor(queueID)
}
// IsQueueExecutorActive 队列 executeBatchQueue 协程是否仍在运行。
func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, ok := m.queueExecutors[queueID]
return ok
}
// SetDB 设置数据库连接 // SetDB 设置数据库连接
func (m *BatchTaskManager) SetDB(db *database.DB) { func (m *BatchTaskManager) SetDB(db *database.DB) {
m.mu.Lock() m.mu.Lock()
@@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
m.db = db m.db = db
} }
// normalizeBatchQueueConcurrency 规范化队列并发数。
func normalizeBatchQueueConcurrency(n int) int {
if n < 1 {
return DefaultBatchQueueConcurrency
}
if n > MaxBatchQueueConcurrency {
return MaxBatchQueueConcurrency
}
return n
}
// CreateBatchQueue 创建批量任务队列 // CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue( func (m *BatchTaskManager) CreateBatchQueue(
title, role, agentMode, scheduleMode, cronExpr, projectID string, title, role, agentMode, scheduleMode, cronExpr, projectID string,
nextRunAt *time.Time, nextRunAt *time.Time,
concurrency int,
tasks []string, tasks []string,
) (*BatchTaskQueue, error) { ) (*BatchTaskQueue, error) {
// 输入校验 // 输入校验
@@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
CronExpr: strings.TrimSpace(cronExpr), CronExpr: strings.TrimSpace(cronExpr),
NextRunAt: nextRunAt, NextRunAt: nextRunAt,
ScheduleEnabled: true, ScheduleEnabled: true,
Concurrency: normalizeBatchQueueConcurrency(concurrency),
Tasks: make([]*BatchTask, 0, len(tasks)), Tasks: make([]*BatchTask, 0, len(tasks)),
Status: BatchQueueStatusPending, Status: BatchQueueStatusPending,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
queue.CronExpr, queue.CronExpr,
queue.NextRunAt, queue.NextRunAt,
queue.ProjectID, queue.ProjectID,
queue.Concurrency,
dbTasks, dbTasks,
); err != nil { ); err != nil {
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
@@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
if queueRow.ProjectID.Valid { if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
} }
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
@@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
if queueRow.ProjectID.Valid { if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
} }
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
@@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
} }
} }
// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用) // batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error { func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int {
if row == nil || !row.Concurrency.Valid {
return DefaultBatchQueueConcurrency
}
return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64))
}
// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用)
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error {
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
} }
@@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
queue.Title = title queue.Title = title
queue.Role = role queue.Role = role
queue.AgentMode = agentMode queue.AgentMode = agentMode
if concurrency != nil {
queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency)
}
if m.db != nil { if m.db != nil {
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil { if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil {
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err)) m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
} }
} }
@@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引 // PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error { func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
var cancelFunc context.CancelFunc
var siblingRunningIDs []string var siblingRunningIDs []string
m.mu.Lock() m.mu.Lock()
@@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
} }
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项 // 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
var cancelFuncs []context.CancelFunc
if queue.Status == BatchQueueStatusPaused { if queue.Status == BatchQueueStatusPaused {
if c, ok := m.taskCancels[queueID]; ok { cancelFuncs = m.drainTaskCancelsLocked(queueID)
cancelFunc = c
delete(m.taskCancels, queueID)
}
for _, t := range queue.Tasks { for _, t := range queue.Tasks {
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning { if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
siblingRunningIDs = append(siblingRunningIDs, t.ID) siblingRunningIDs = append(siblingRunningIDs, t.ID)
@@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
m.mu.Unlock() m.mu.Unlock()
if cancelFunc != nil { for _, c := range cancelFuncs {
cancelFunc() if c != nil {
c()
}
} }
const staleRunMsg = "为单条执行其它任务,已中止" const staleRunMsg = "为单条执行其它任务,已中止"
for _, sid := range siblingRunningIDs { for _, sid := range siblingRunningIDs {
@@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool
} }
} }
// GetNextTask 取下一个待执行任务 // ClaimNextPendingTask 原子领取下一个待执行任务(并发 worker 安全)。
func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists || queue == nil {
return nil, false
}
if queue.Status == BatchQueueStatusCancelled || queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusPaused {
return nil, false
}
onlyTaskID := ""
if m.singleRunTasks != nil {
onlyTaskID = m.singleRunTasks[queueID]
}
for i, task := range queue.Tasks {
if task == nil || task.Status != BatchTaskStatusPending {
continue
}
if onlyTaskID != "" && task.ID != onlyTaskID {
continue
}
task.Status = BatchTaskStatusRunning
queue.CurrentIndex = i
return task, true
}
return nil, false
}
// HasRunningTasks 队列是否仍有 running 状态的子任务。
func (m *BatchTaskManager) HasRunningTasks(queueID string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
queue, exists := m.queues[queueID]
if !exists || queue == nil {
return false
}
for _, task := range queue.Tasks {
if task != nil && task.Status == BatchTaskStatusRunning {
return true
}
}
return false
}
// HasPendingOrRunningTasks 队列是否仍有未完成的子任务。
func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
queue, exists := m.queues[queueID]
if !exists || queue == nil {
return false
}
for _, task := range queue.Tasks {
if task == nil {
continue
}
if task.Status == BatchTaskStatusPending || task.Status == BatchTaskStatusRunning {
return true
}
}
return false
}
// drainTaskCancelsLocked 取出并清空队列下所有子任务取消函数(调用方须已持 m.mu)。
func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc {
taskMap, ok := m.taskCancels[queueID]
if !ok || len(taskMap) == 0 {
return nil
}
cancels := make([]context.CancelFunc, 0, len(taskMap))
for _, c := range taskMap {
if c != nil {
cancels = append(cancels, c)
}
}
delete(m.taskCancels, queueID)
return cancels
}
// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
} }
} }
// SetTaskCancel 设置当前任务的取消函数 // SetTaskCancel 设置任务的取消函数
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) { func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if cancel != nil { if cancel == nil {
m.taskCancels[queueID] = cancel if taskMap, ok := m.taskCancels[queueID]; ok {
} else { delete(taskMap, taskID)
delete(m.taskCancels, queueID) if len(taskMap) == 0 {
delete(m.taskCancels, queueID)
}
}
return
} }
if m.taskCancels[queueID] == nil {
m.taskCancels[queueID] = make(map[string]context.CancelFunc)
}
m.taskCancels[queueID][taskID] = cancel
} }
// PauseQueue 暂停队列 // PauseQueue 暂停队列
func (m *BatchTaskManager) PauseQueue(queueID string) bool { func (m *BatchTaskManager) PauseQueue(queueID string) bool {
var cancelFunc context.CancelFunc var cancelFuncs []context.CancelFunc
m.mu.Lock() m.mu.Lock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
@@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
} }
queue.Status = BatchQueueStatusPaused queue.Status = BatchQueueStatusPaused
cancelFuncs = m.drainTaskCancelsLocked(queueID)
// 取消当前正在执行的任务(通过取消context)
if cancel, ok := m.taskCancels[queueID]; ok {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
m.mu.Unlock() m.mu.Unlock()
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) for _, c := range cancelFuncs {
if cancelFunc != nil { c()
cancelFunc()
} }
return true return true
@@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue) // CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
func (m *BatchTaskManager) CancelQueue(queueID string) bool { func (m *BatchTaskManager) CancelQueue(queueID string) bool {
now := time.Now() now := time.Now()
var cancelFunc context.CancelFunc var cancelFuncs []context.CancelFunc
m.mu.Lock() m.mu.Lock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
@@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
} }
} }
// 取消当前正在执行的任务 cancelFuncs = m.drainTaskCancelsLocked(queueID)
if cancel, ok := m.taskCancels[queueID]; ok {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
m.mu.Unlock() m.mu.Unlock()
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) for _, c := range cancelFuncs {
if cancelFunc != nil { c()
cancelFunc()
} }
return true return true
} }
// DeleteQueue 删除队列(运行中的队列不允许删除) // DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。
func (m *BatchTaskManager) DeleteQueue(queueID string) bool { func (m *BatchTaskManager) DeleteQueue(queueID string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
if !exists { if !exists {
return false return ErrBatchQueueNotFound
}
if _, exec := m.queueExecutors[queueID]; exec {
return ErrBatchQueueExecutorActive
} }
// 运行中的队列不允许删除,防止孤儿协程和数据丢失 // 运行中的队列不允许删除,防止孤儿协程和数据丢失
if queue.Status == BatchQueueStatusRunning { if queue.Status == BatchQueueStatusRunning {
return false return ErrBatchQueueStillRunning
} }
// 清理取消函数 // 清理取消函数
@@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
} }
delete(m.queues, queueID) delete(m.queues, queueID)
return true return nil
} }
// generateShortID 生成短ID // generateShortID 生成短ID
+121
View File
@@ -0,0 +1,121 @@
package handler
import (
"errors"
"testing"
"go.uber.org/zap"
)
func TestNormalizeBatchQueueConcurrency(t *testing.T) {
if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency {
t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got)
}
if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency {
t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got)
}
}
func TestClaimNextPendingTaskParallel(t *testing.T) {
m := NewBatchTaskManager(zap.NewNop())
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"})
if err != nil {
t.Fatalf("CreateBatchQueue: %v", err)
}
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
t1, ok1 := m.ClaimNextPendingTask(queue.ID)
t2, ok2 := m.ClaimNextPendingTask(queue.ID)
if !ok1 || !ok2 || t1.ID == t2.ID {
t.Fatalf("expected two distinct claims, got ok1=%v ok2=%v t1=%v t2=%v", ok1, ok2, t1, t2)
}
if t1.Status != BatchTaskStatusRunning || t2.Status != BatchTaskStatusRunning {
t.Fatalf("claimed tasks should be running")
}
t3, ok3 := m.ClaimNextPendingTask(queue.ID)
if !ok3 {
t.Fatal("expected third claim")
}
_, ok4 := m.ClaimNextPendingTask(queue.ID)
if ok4 {
t.Fatal("expected no fourth pending task")
}
_ = t3
}
func TestBatchQueueExecutionShouldStop(t *testing.T) {
t.Parallel()
if !batchQueueExecutionShouldStop(nil, false) {
t.Fatal("expected stop when queue missing")
}
if !batchQueueExecutionShouldStop(nil, true) {
t.Fatal("expected stop when queue is nil but exists=true")
}
q := &BatchTaskQueue{Status: BatchQueueStatusRunning}
if batchQueueExecutionShouldStop(q, true) {
t.Fatal("expected continue when running")
}
q.Status = BatchQueueStatusCancelled
if !batchQueueExecutionShouldStop(q, true) {
t.Fatal("expected stop when cancelled")
}
}
func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) {
t.Parallel()
m := NewBatchTaskManager(zap.NewNop())
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
if err != nil {
t.Fatalf("CreateBatchQueue: %v", err)
}
if !m.TryMarkQueueExecutor(queue.ID) {
t.Fatal("expected to mark executor")
}
m.UpdateQueueStatus(queue.ID, BatchQueueStatusCancelled)
err = m.DeleteQueue(queue.ID)
if !errors.Is(err, ErrBatchQueueExecutorActive) {
t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", err)
}
if _, ok := m.GetBatchQueue(queue.ID); !ok {
t.Fatal("queue should still exist while executor active")
}
m.UnmarkQueueExecutor(queue.ID)
if err := m.DeleteQueue(queue.ID); err != nil {
t.Fatalf("expected delete after executor unmarked, got %v", err)
}
if _, ok := m.GetBatchQueue(queue.ID); ok {
t.Fatal("queue should be deleted")
}
}
func TestDeleteQueueBlockedWhileRunning(t *testing.T) {
t.Parallel()
m := NewBatchTaskManager(zap.NewNop())
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
if err != nil {
t.Fatalf("CreateBatchQueue: %v", err)
}
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
err = m.DeleteQueue(queue.ID)
if !errors.Is(err, ErrBatchQueueStillRunning) {
t.Fatalf("expected ErrBatchQueueStillRunning, got %v", err)
}
}
func TestTryMarkQueueExecutorDedupes(t *testing.T) {
t.Parallel()
m := NewBatchTaskManager(zap.NewNop())
if !m.TryMarkQueueExecutor("q-1") {
t.Fatal("first mark should succeed")
}
if m.TryMarkQueueExecutor("q-1") {
t.Fatal("second mark should fail")
}
m.UnmarkQueueExecutor("q-1")
if !m.TryMarkQueueExecutor("q-1") {
t.Fatal("mark after unmark should succeed")
}
}
+30 -4
View File
@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"type": "string", "type": "string",
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id", "description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id",
}, },
"concurrency": map[string]interface{}{
"type": "integer",
"description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。",
},
}, },
}, },
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
@@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
executeNow = false executeNow = false
} }
projectID := strings.TrimSpace(mcpArgString(args, "project_id")) projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks) concurrency := int(mcpArgFloat(args, "concurrency"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks)
if createErr != nil { if createErr != nil {
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
} }
@@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
if qid == "" { if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil return batchMCPTextResult("queue_id 不能为空", true), nil
} }
if !h.batchTaskManager.DeleteQueue(qid) { if err := h.batchTaskManager.DeleteQueue(qid); err != nil {
return batchMCPTextResult("删除失败:队列不存在", true), nil switch {
case errors.Is(err, ErrBatchQueueNotFound):
return batchMCPTextResult("删除失败:队列不存在", true), nil
case errors.Is(err, ErrBatchQueueExecutorActive):
return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil
case errors.Is(err, ErrBatchQueueStillRunning):
return batchMCPTextResult("删除失败:队列正在运行中", true), nil
default:
return batchMCPTextResult("删除失败:"+err.Error(), true), nil
}
} }
logger.Info("MCP batch_task_delete", zap.String("queueId", qid)) logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
return batchMCPTextResult("队列已删除。", false), nil return batchMCPTextResult("队列已删除。", false), nil
@@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"description": "代理模式:eino_single、deep、plan_execute、supervisor", "description": "代理模式:eino_single、deep、plan_execute、supervisor",
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}, "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
}, },
"concurrency": map[string]interface{}{
"type": "integer",
"description": "同时执行的子任务数,默认 1,最大 8",
},
}, },
"required": []string{"queue_id"}, "required": []string{"queue_id"},
}, },
@@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
title := mcpArgString(args, "title") title := mcpArgString(args, "title")
role := mcpArgString(args, "role") role := mcpArgString(args, "role")
agentMode := mcpArgString(args, "agent_mode") agentMode := mcpArgString(args, "agent_mode")
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil { var concurrency *int
if raw, ok := args["concurrency"]; ok && raw != nil {
v := int(mcpArgFloat(args, "concurrency"))
concurrency = &v
}
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil {
return batchMCPTextResult(err.Error(), true), nil return batchMCPTextResult(err.Error(), true), nil
} }
updated, _ := h.batchTaskManager.GetBatchQueue(qid) updated, _ := h.batchTaskManager.GetBatchQueue(qid)
@@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct {
StartedAt *time.Time `json:"startedAt,omitempty"` StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"` CompletedAt *time.Time `json:"completedAt,omitempty"`
CurrentIndex int `json:"currentIndex"` CurrentIndex int `json:"currentIndex"`
Concurrency int `json:"concurrency"`
TaskTotal int `json:"task_total"` TaskTotal int `json:"task_total"`
TaskCounts map[string]int `json:"task_counts"` TaskCounts map[string]int `json:"task_counts"`
Tasks []batchTaskMCPListSummary `json:"tasks"` Tasks []batchTaskMCPListSummary `json:"tasks"`
@@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
StartedAt: q.StartedAt, StartedAt: q.StartedAt,
CompletedAt: q.CompletedAt, CompletedAt: q.CompletedAt,
CurrentIndex: q.CurrentIndex, CurrentIndex: q.CurrentIndex,
Concurrency: q.Concurrency,
TaskTotal: len(tasks), TaskTotal: len(tasks),
TaskCounts: counts, TaskCounts: counts,
Tasks: tasks, Tasks: tasks,
+44
View File
@@ -798,6 +798,10 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// 更新机器人配置 // 更新机器人配置
if req.Robots != nil { if req.Robots != nil {
if err := config.ValidateWecomConfig(req.Robots.Wecom); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.config.Robots = *req.Robots h.config.Robots = *req.Robots
h.logger.Info("更新机器人配置", h.logger.Info("更新机器人配置",
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled), zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
@@ -1329,6 +1333,17 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("已更新嵌入模型配置记录") h.logger.Info("已更新嵌入模型配置记录")
} }
// 从 tools 目录重新加载工具配置(新增/修改/删除 yaml 后无需重启)
if err := config.ReloadSecurityToolsFromDir(h.config, h.configPath); err != nil {
h.logger.Error("重新加载工具配置失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新加载工具", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新加载工具配置失败: " + err.Error()})
return
}
h.logger.Info("已从 tools 目录重新加载工具配置", zap.Int("tools_count", len(h.config.Security.Tools)))
// 重新注册工具(根据新的启用状态) // 重新注册工具(根据新的启用状态)
h.logger.Info("重新注册工具") h.logger.Info("重新注册工具")
@@ -1751,6 +1766,20 @@ func mergeHitlToolWhitelistSlice(existing, add []string) []string {
return out return out
} }
// SetHitlToolWhitelist 将全局免审批工具白名单整表写入 config.yaml(替换,非合并)。
func (h *ConfigHandler) SetHitlToolWhitelist(tools []string) error {
h.mu.Lock()
defer h.mu.Unlock()
h.config.Hitl.ToolWhitelist = mergeHitlToolWhitelistSlice(nil, tools)
if err := h.saveConfig(); err != nil {
return err
}
h.logger.Info("HITL 全局工具白名单已写入配置文件",
zap.Int("count", len(h.config.Hitl.ToolWhitelist)),
)
return nil
}
// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。 // MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。
func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error { func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error {
h.mu.Lock() h.mu.Lock()
@@ -1771,6 +1800,21 @@ func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) {
hitlNode := ensureMap(root, "hitl") hitlNode := ensureMap(root, "hitl")
// flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数 // flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数
setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist) setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist)
setStringInMap(hitlNode, "audit_agent_prompt", cfg.AuditAgentPrompt)
setStringInMap(hitlNode, "audit_agent_prompt_review_edit", cfg.AuditAgentPromptReviewEdit)
}
// UpdateHitlAuditAgentStrategy 更新审批/审查编辑两套审计 Agent 提示词并写入 config.yaml。
func (h *ConfigHandler) UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error {
h.mu.Lock()
defer h.mu.Unlock()
h.config.Hitl.AuditAgentPrompt = strings.TrimSpace(approvalPrompt)
h.config.Hitl.AuditAgentPromptReviewEdit = strings.TrimSpace(reviewEditPrompt)
if err := h.saveConfig(); err != nil {
return err
}
h.logger.Info("HITL 审计 Agent 提示词已写入配置文件")
return nil
} }
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) { func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
+6 -5
View File
@@ -103,6 +103,7 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50") limitStr := c.DefaultQuery("limit", "50")
offsetStr := c.DefaultQuery("offset", "0") offsetStr := c.DefaultQuery("offset", "0")
search := c.Query("search") // 获取搜索参数 search := c.Query("search") // 获取搜索参数
projectID := strings.TrimSpace(c.Query("project_id"))
limit, _ := strconv.Atoi(limitStr) limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr) offset, _ := strconv.Atoi(offsetStr)
@@ -114,7 +115,7 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
limit = 1000 limit = 1000
} }
excludeGrouped := strings.TrimSpace(search) == "" && excludeGrouped := strings.TrimSpace(search) == "" && projectID == "" &&
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1") (c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
sortBy := strings.TrimSpace(c.Query("sort_by")) sortBy := strings.TrimSpace(c.Query("sort_by"))
@@ -122,14 +123,14 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
var total int var total int
var err error var err error
if excludeGrouped { if excludeGrouped {
conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy) conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy, projectID)
if err == nil { if err == nil {
total, err = h.db.CountUngroupedConversations() total, err = h.db.CountUngroupedConversations(projectID)
} }
} else { } else {
conversations, err = h.db.ListConversations(limit, offset, search, sortBy) conversations, err = h.db.ListConversations(limit, offset, search, sortBy, projectID)
if err == nil { if err == nil {
total, err = h.db.CountConversations(search) total, err = h.db.CountConversations(search, projectID)
} }
} }
if err != nil { if err != nil {
@@ -0,0 +1,83 @@
package handler
import (
"context"
"fmt"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/multiagent"
"go.uber.org/zap"
)
// rebindEinoRunningTask 中断并继续 / 空正文续跑:重建 cancel 链与超时 ctx,保持任务 running。
func (h *AgentHandler) rebindEinoRunningTask(conversationID string, timeoutCancel context.CancelFunc) (context.Context, context.CancelCauseFunc, context.Context, context.CancelFunc) {
if timeoutCancel != nil {
timeoutCancel()
}
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, newTimeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
return baseCtx, cancelWithCause, taskCtx, newTimeoutCancel
}
// tryContinueOnEinoEmptyResponse Run 成功但 Response 为 emptyHint 时退避续跑;true 表示已准备下一段 Run。
func (h *AgentHandler) tryContinueOnEinoEmptyResponse(
taskCtx context.Context,
mw *config.MultiAgentEinoMiddlewareConfig,
conversationID string,
result *multiagent.RunResult,
attempt *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
progressCallback func(eventType, message string, data interface{}),
) bool {
if result == nil || !multiagent.IsEinoEmptyResponseResult(result) || !multiagent.HasEinoResumeTrace(result) {
return false
}
maxAttempts := multiagent.EmptyResponseContinueMaxAttemptsFromConfig(mw)
if *attempt >= maxAttempts {
if h.logger != nil {
h.logger.Warn("eino empty response continue exhausted",
zap.String("conversationId", conversationID),
zap.Int("maxAttempts", maxAttempts))
}
return false
}
*attempt++
h.persistEinoAgentTraceForResume(conversationID, result)
backoff := multiagent.EmptyResponseContinueBackoff(*attempt-1, mw)
waitMsg := fmt.Sprintf("会话已结束但未捕获到助手正文,%d 秒后第 %d/%d 次自动续跑…",
int(backoff.Seconds()), *attempt, maxAttempts)
if progressCallback != nil {
progressCallback("eino_empty_response_continue", waitMsg, map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": *attempt,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
}
select {
case <-taskCtx.Done():
return false
case <-time.After(backoff):
}
inject := multiagent.FormatEmptyResponseContinueUserMessage()
h.applyEinoTraceResumeSegment(conversationID, result, curHistory, curFinalMessage, inject)
if progressCallback != nil {
progressCallback("eino_empty_response_continue", "已恢复上下文,正在续跑…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": *attempt,
"maxAttempts": maxAttempts,
"contextSource": "empty_response_continue",
})
}
return true
}
+10 -2
View File
@@ -178,6 +178,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
var emptyResponseContinueAttempt int
for { for {
segmentMainIterationMax := 0 segmentMainIterationMax := 0
@@ -231,7 +232,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
roleTools, roleTools,
progressCallback, progressCallback,
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID), h.agentSessionContextBlock(conversationID),
) )
if result != nil && len(result.MCPExecutionIDs) > 0 { if result != nil && len(result.MCPExecutionIDs) > 0 {
@@ -239,6 +240,13 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
} }
if runErr == nil { if runErr == nil {
mw := &h.config.MultiAgent.EinoMiddleware
if h.tryContinueOnEinoEmptyResponse(taskCtx, mw, conversationID, result, &emptyResponseContinueAttempt, &curHistory, &curFinalMessage, progressCallback) {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause, taskCtx, timeoutCancel = h.rebindEinoRunningTask(conversationID, timeoutCancel)
continue
}
timeoutCancel() timeoutCancel()
break break
} }
@@ -416,7 +424,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.RoleTools, prep.RoleTools,
progressCallback, progressCallback,
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID), h.agentSessionContextBlock(prep.ConversationID),
) )
if runErr == nil { if runErr == nil {
break break
+136 -89
View File
@@ -23,6 +23,7 @@ import (
type hitlRuntimeConfig struct { type hitlRuntimeConfig struct {
Enabled bool Enabled bool
Mode string Mode string
Reviewer string
SensitiveTools map[string]struct{} SensitiveTools map[string]struct{}
Timeout time.Duration Timeout time.Duration
} }
@@ -49,6 +50,8 @@ type HITLManager struct {
mu sync.RWMutex mu sync.RWMutex
runtime map[string]hitlRuntimeConfig runtime map[string]hitlRuntimeConfig
pending map[string]*pendingInterrupt pending map[string]*pendingInterrupt
// approvedExec 审批通过、待回写 tool_result 的队列(按会话 FIFO
approvedExec map[string][]hitlApprovedExecTrack
} }
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager { func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
@@ -90,6 +93,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
if err != nil { if err != nil {
return err return err
} }
m.migrateHitlSchemaColumns()
// On startup, cancel all orphaned pending interrupts from previous process. // On startup, cancel all orphaned pending interrupts from previous process.
// Their in-memory channels are gone, so they can never be resolved. // Their in-memory channels are gone, so they can never be resolved.
@@ -141,6 +145,7 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque
m.runtime[conversationID] = hitlRuntimeConfig{ m.runtime[conversationID] = hitlRuntimeConfig{
Enabled: true, Enabled: true,
Mode: normalizeHitlMode(req.Mode), Mode: normalizeHitlMode(req.Mode),
Reviewer: normalizeHitlReviewer(req.Reviewer),
SensitiveTools: tools, SensitiveTools: tools,
Timeout: timeout, Timeout: timeout,
} }
@@ -153,17 +158,14 @@ func (m *HITLManager) DeactivateConversation(conversationID string) {
m.mu.Unlock() m.mu.Unlock()
} }
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。 // hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空),并合并内置元工具免审批项
func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string { func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
if h == nil || h.config == nil { if h == nil || h.config == nil {
return nil return multiagent.MergeHitlExemptMetaTools(nil)
} }
raw := h.config.Hitl.ToolWhitelist raw := h.config.Hitl.ToolWhitelist
if len(raw) == 0 {
return nil
}
seen := make(map[string]struct{}) seen := make(map[string]struct{})
out := make([]string, 0, len(raw)) out := make([]string, 0, len(raw)+len(multiagent.HitlExemptMetaTools))
for _, t := range raw { for _, t := range raw {
n := strings.ToLower(strings.TrimSpace(t)) n := strings.ToLower(strings.TrimSpace(t))
if n == "" { if n == "" {
@@ -175,44 +177,35 @@ func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
seen[n] = struct{}{} seen[n] = struct{}{}
out = append(out, strings.TrimSpace(t)) out = append(out, strings.TrimSpace(t))
} }
return out return multiagent.MergeHitlExemptMetaTools(out)
} }
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。 // hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单及内置元工具免审批项合并(并集),仅用于运行时 Activate;不写入数据库。
func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest { func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest {
gw := h.hitlConfigGlobalToolWhitelist()
if len(gw) == 0 {
return req
}
if req == nil { if req == nil {
return nil return nil
} }
seen := make(map[string]struct{}) seen := make(map[string]struct{})
union := make([]string, 0, len(gw)+len(req.SensitiveTools)) union := make([]string, 0, len(req.SensitiveTools)+16)
for _, t := range gw { add := func(t string) {
n := strings.ToLower(strings.TrimSpace(t)) n := strings.ToLower(strings.TrimSpace(t))
if n == "" { if n == "" {
continue return
} }
if _, ok := seen[n]; ok { if _, ok := seen[n]; ok {
continue return
} }
seen[n] = struct{}{} seen[n] = struct{}{}
union = append(union, strings.TrimSpace(t)) union = append(union, strings.TrimSpace(t))
} }
for _, t := range h.hitlConfigGlobalToolWhitelist() {
add(t)
}
for _, t := range req.SensitiveTools { for _, t := range req.SensitiveTools {
n := strings.ToLower(strings.TrimSpace(t)) add(t)
if n == "" {
continue
}
if _, ok := seen[n]; ok {
continue
}
seen[n] = struct{}{}
union = append(union, strings.TrimSpace(t))
} }
out := *req out := *req
out.SensitiveTools = union out.SensitiveTools = multiagent.MergeHitlExemptMetaTools(union)
return &out return &out
} }
@@ -362,22 +355,22 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
timeout = 0 timeout = 0
} }
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs _, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at) (conversation_id, enabled, mode, reviewer, sensitive_tools, timeout_seconds, updated_at)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(conversation_id) DO UPDATE SET ON CONFLICT(conversation_id) DO UPDATE SET
enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`, enabled=excluded.enabled, mode=excluded.mode, reviewer=excluded.reviewer, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now()) conversationID, boolToInt(req.Enabled), mode, normalizeHitlReviewer(req.Reviewer), string(tools), timeout, time.Now())
return err return err
} }
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) { func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
var enabledInt int var enabledInt int
var mode, toolsJSON string var mode, reviewer, toolsJSON string
var timeout int var timeout int
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID). err := m.db.QueryRow(`SELECT enabled, mode, COALESCE(reviewer,'human'), sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
Scan(&enabledInt, &mode, &toolsJSON, &timeout) Scan(&enabledInt, &mode, &reviewer, &toolsJSON, &timeout)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil return &HITLRequest{Enabled: false, Mode: "off", Reviewer: "human", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -390,6 +383,7 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques
return &HITLRequest{ return &HITLRequest{
Enabled: enabledInt == 1, Enabled: enabledInt == 1,
Mode: mode, Mode: mode,
Reviewer: normalizeHitlReviewer(reviewer),
SensitiveTools: tools, SensitiveTools: tools,
TimeoutSeconds: timeout, TimeoutSeconds: timeout,
}, nil }, nil
@@ -413,15 +407,16 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 { if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
d.EditedArguments = nil d.EditedArguments = nil
} }
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`, _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='human' WHERE id=?`,
d.Decision, d.Comment, time.Now(), p.InterruptID) d.Decision, d.Comment, time.Now(), p.InterruptID)
return d, nil return d, nil
case <-timeoutCh: case <-timeoutCh:
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`, comment := "HITL timeout auto-reject for safety"
time.Now(), p.InterruptID) _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='reject', decision_comment=?, decided_at=?, decided_by='system' WHERE id=?`,
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil comment, time.Now(), p.InterruptID)
return hitlDecision{Decision: "reject", Comment: comment}, nil
case <-ctx.Done(): case <-ctx.Done():
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`, _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=?, decided_by='system' WHERE id=?`,
time.Now(), p.InterruptID) time.Now(), p.InterruptID)
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err() return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
} }
@@ -445,12 +440,57 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
if !need { if !need {
return nil, nil return nil, nil
} }
h.enrichHitlApprovalPayload(conversationID, assistantMessageID, payload)
payloadRaw, _ := json.Marshal(payload) payloadRaw, _ := json.Marshal(payload)
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw)) p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
if err != nil { if err != nil {
h.logger.Warn("创建 HITL 中断失败", zap.Error(err)) h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
return nil, err return nil, err
} }
if cfg.Reviewer == "audit_agent" {
ad := h.auditAgentReview(runCtx, cfg.Mode, toolName, payload)
now := time.Now()
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='audit_agent' WHERE id=?`,
ad.Decision, ad.Comment, now, p.InterruptID)
if sendEventFunc != nil {
sendEventFunc("hitl_audit_agent", "审计 Agent 已裁决", map[string]interface{}{
"conversationId": conversationID,
"interruptId": p.InterruptID,
"toolName": toolName,
"mode": cfg.Mode,
"decision": ad.Decision,
"comment": ad.Comment,
"editedArgs": ad.EditedArguments,
"decidedBy": "audit_agent",
})
}
if ad.Decision == "reject" {
if sendEventFunc != nil {
sendEventFunc("hitl_rejected", "审计 Agent 拒绝本次工具调用", map[string]interface{}{
"conversationId": conversationID,
"interruptId": p.InterruptID,
"toolName": toolName,
"comment": ad.Comment,
"decidedBy": "audit_agent",
})
}
return &ad, nil
}
if sendEventFunc != nil {
sendEventFunc("hitl_resumed", "审计 Agent 已通过,继续执行", map[string]interface{}{
"conversationId": conversationID,
"interruptId": p.InterruptID,
"toolName": toolName,
"comment": ad.Comment,
"editedArgs": ad.EditedArguments,
"decidedBy": "audit_agent",
})
}
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
return &ad, nil
}
if sendEventFunc != nil { if sendEventFunc != nil {
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{ sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
@@ -479,8 +519,12 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
return nil, waitErr return nil, waitErr
} }
if d.Decision == "reject" { if d.Decision == "reject" {
rejectMsg := "人工拒绝本次工具调用,模型将基于反馈继续迭代"
if strings.Contains(strings.ToLower(strings.TrimSpace(d.Comment)), "timeout") {
rejectMsg = "审批超时,安全起见已自动拒绝,模型将基于反馈继续迭代"
}
if sendEventFunc != nil { if sendEventFunc != nil {
sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{ sendEventFunc("hitl_rejected", rejectMsg, map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"interruptId": p.InterruptID, "interruptId": p.InterruptID,
"toolName": toolName, "toolName": toolName,
@@ -498,6 +542,7 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
"editedArgs": d.EditedArguments, "editedArgs": d.EditedArguments,
}) })
} }
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
return &d, nil return &d, nil
} }
@@ -527,11 +572,6 @@ func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun cont
} }
func (h *AgentHandler) ListHITLPending(c *gin.Context) { func (h *AgentHandler) ListHITLPending(c *gin.Context) {
conversationID := strings.TrimSpace(c.Query("conversationId"))
status := strings.TrimSpace(c.Query("status"))
if status == "" {
status = "pending"
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
if page < 1 { if page < 1 {
page = 1 page = 1
@@ -539,15 +579,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200))) pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1` q, args := h.buildHitlListQuery(false)
args := []interface{}{} q, args = h.appendHitlListFilters(q, args, c)
if conversationID != "" { total, err := h.countHitlQuery(q, args)
q += " AND conversation_id = ?" if err != nil {
args = append(args, conversationID) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} return
if status != "all" {
q += " AND status = ?"
args = append(args, status)
} }
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?" q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, pageSize, offset) args = append(args, pageSize, offset)
@@ -557,41 +594,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
return return
} }
defer rows.Close() defer rows.Close()
items := make([]map[string]interface{}, 0) items, err := h.scanHitlInterruptRows(rows)
for rows.Next() { if err != nil {
var id, cid, mode, toolName, toolCallID, payload, rowStatus string c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
var messageID sql.NullString return
var decision, comment sql.NullString
var createdAt time.Time
var decidedAt sql.NullTime
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil {
continue
}
msgID := ""
if messageID.Valid {
msgID = messageID.String
}
items = append(items, map[string]interface{}{
"id": id,
"conversationId": cid,
"messageId": msgID,
"mode": mode,
"toolName": toolName,
"toolCallId": toolCallID,
"payload": payload,
"status": rowStatus,
"decision": decision.String,
"comment": comment.String,
"createdAt": createdAt,
"decidedAt": func() interface{} {
if decidedAt.Valid {
return decidedAt.Time
}
return nil
}(),
})
} }
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize}) c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total})
} }
type hitlDecisionReq struct { type hitlDecisionReq struct {
@@ -636,7 +644,7 @@ func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) {
return return
} }
res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP, decided_by='human'
WHERE id=? AND status='pending'`, req.InterruptID) WHERE id=? AND status='pending'`, req.InterruptID)
if err != nil { if err != nil {
c.JSON(500, gin.H{"error": err.Error()}) c.JSON(500, gin.H{"error": err.Error()})
@@ -732,6 +740,7 @@ func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) {
return return
} }
req.Mode = normalizeHitlMode(req.Mode) req.Mode = normalizeHitlMode(req.Mode)
req.Reviewer = normalizeHitlReviewer(req.Reviewer)
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil { if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -753,6 +762,44 @@ type mergeHitlGlobalWhitelistReq struct {
SensitiveTools []string `json:"sensitiveTools"` SensitiveTools []string `json:"sensitiveTools"`
} }
type setHitlGlobalWhitelistReq struct {
ToolWhitelist []string `json:"toolWhitelist"`
}
// GetHITLGlobalToolWhitelist 返回 config.yaml 中的全局免审批工具白名单。
func (h *AgentHandler) GetHITLGlobalToolWhitelist(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
})
}
// SetHITLGlobalToolWhitelist 整表替换 config.yaml 中的全局免审批工具白名单。
func (h *AgentHandler) SetHITLGlobalToolWhitelist(c *gin.Context) {
if h.hitlWhitelistSaver == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
return
}
var req setHitlGlobalWhitelistReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.hitlWhitelistSaver.SetHitlToolWhitelist(req.ToolWhitelist); err != nil {
h.logger.Warn("写入 HITL 工具白名单到 config.yaml 失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "tool_whitelist_update", "HITL 全局白名单更新", "hitl_config", "tool_whitelist", nil)
}
c.JSON(http.StatusOK, gin.H{
"ok": true,
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
"hitlGlobalWhitelistMerged": false,
})
}
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。 // MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) { func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
if h.hitlWhitelistSaver == nil { if h.hitlWhitelistSaver == nil {
+357
View File
@@ -0,0 +1,357 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// auditAgentReview 在 reviewer=audit_agent 时由 LLM 代行审批。
// 白名单工具在 shouldInterrupt 阶段已跳过,到达此处的一律需要裁决。
func (h *AgentHandler) auditAgentReview(ctx context.Context, hitlMode, toolName string, payload map[string]interface{}) hitlDecision {
if h == nil {
return hitlDecision{Decision: "reject", Comment: "audit agent: handler unavailable"}
}
mode := normalizeHitlMode(hitlMode)
prompt := config.DefaultHitlAuditAgentPrompt()
if h.config != nil {
prompt = h.config.Hitl.EffectiveAuditAgentPromptForMode(mode)
}
if h.auditLLM == nil {
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 未配置"}
}
if ctx == nil {
ctx = context.Background()
}
callCtx, cancel := context.WithTimeout(ctx, 90*time.Second)
defer cancel()
userContent := buildAuditAgentReviewInput(mode, toolName, payload)
requestBody := map[string]interface{}{
"model": h.auditLLMModel(),
"messages": []map[string]interface{}{
{"role": "system", "content": prompt},
{"role": "user", "content": userContent},
},
"temperature": 0.1,
"max_completion_tokens": 1024,
// 审计裁决需要结构化 JSON;关闭 thinking 避免 Qwen 等把正文放进 reasoning_content 导致解析失败。
"thinking": map[string]interface{}{"type": "disabled"},
}
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
} `json:"message"`
} `json:"choices"`
}
if err := h.auditLLM.ChatCompletion(callCtx, requestBody, &apiResponse); err != nil {
h.logger.Warn("审计 Agent LLM 调用失败", zap.Error(err), zap.String("tool", toolName))
return hitlDecision{
Decision: "reject",
Comment: "audit agent: LLM 调用失败,保守拒绝",
}
}
if len(apiResponse.Choices) == 0 {
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 无有效响应,保守拒绝"}
}
msg := apiResponse.Choices[0].Message
raw := strings.TrimSpace(msg.Content)
if raw == "" {
raw = strings.TrimSpace(msg.ReasoningContent)
}
dec, err := parseAuditAgentLLMContent(raw)
if err != nil {
snippet := raw
if len(snippet) > 240 {
snippet = snippet[:240] + "..."
}
h.logger.Warn("审计 Agent 响应解析失败",
zap.Error(err),
zap.String("tool", toolName),
zap.String("mode", mode),
zap.String("snippet", snippet),
)
return hitlDecision{Decision: "reject", Comment: "audit agent: 响应无法解析,保守拒绝"}
}
if mode != "review_edit" && len(dec.EditedArguments) > 0 {
h.logger.Warn("审计 Agent 在审批模式下返回 editedArguments,已忽略",
zap.String("tool", toolName),
)
dec.EditedArguments = nil
}
if dec.Comment == "" {
dec.Comment = "audit agent: " + dec.Decision
} else if !strings.HasPrefix(strings.ToLower(dec.Comment), "audit agent") {
dec.Comment = "audit agent: " + dec.Comment
}
return dec
}
func (h *AgentHandler) auditLLMModel() string {
if h.config != nil && strings.TrimSpace(h.config.OpenAI.Model) != "" {
return strings.TrimSpace(h.config.OpenAI.Model)
}
return ""
}
func buildAuditAgentReviewInput(hitlMode, toolName string, payload map[string]interface{}) string {
review := map[string]interface{}{
"hitlMode": normalizeHitlMode(hitlMode),
"toolName": strings.TrimSpace(toolName),
}
if payload != nil {
for _, k := range []string{"arguments", "argumentsObj", "command", hitlPayloadUserMessage, hitlPayloadThinking, hitlPayloadReasoningChain, hitlPayloadPlanning} {
if v, ok := payload[k]; ok && v != nil && fmt.Sprint(v) != "" {
review[k] = v
}
}
}
b, err := json.MarshalIndent(review, "", " ")
if err != nil {
return fmt.Sprintf(`{"hitlMode":%q,"toolName":%q}`, normalizeHitlMode(hitlMode), toolName)
}
return string(b)
}
func parseAuditAgentLLMContent(content string) (hitlDecision, error) {
s := strings.TrimSpace(content)
if s == "" {
return hitlDecision{}, errors.New("empty content")
}
for _, candidate := range auditAgentJSONCandidates(s) {
dec, comment, editedArgs, err := parseAuditAgentDecisionObject(candidate)
if err == nil {
return hitlDecision{
Decision: dec,
Comment: comment,
EditedArguments: editedArgs,
}, nil
}
}
return hitlDecision{}, fmt.Errorf("no valid decision json in response")
}
func auditAgentJSONCandidates(s string) []string {
out := make([]string, 0, 4)
seen := make(map[string]struct{})
add := func(c string) {
c = strings.TrimSpace(c)
if c == "" {
return
}
if _, ok := seen[c]; ok {
return
}
seen[c] = struct{}{}
out = append(out, c)
}
add(s)
add(stripMarkdownCodeFence(s))
if obj := extractFirstJSONObject(s); obj != "" {
add(obj)
}
if obj := extractFirstJSONObject(stripMarkdownCodeFence(s)); obj != "" {
add(obj)
}
return out
}
func stripMarkdownCodeFence(s string) string {
s = strings.TrimSpace(s)
for _, fence := range []string{"```json", "```JSON", "```"} {
if strings.HasPrefix(s, fence) {
s = strings.TrimPrefix(s, fence)
}
}
s = strings.TrimSuffix(s, "```")
return strings.TrimSpace(s)
}
func extractFirstJSONObject(s string) string {
start := strings.Index(s, "{")
if start < 0 {
return ""
}
depth := 0
inStr := false
esc := false
for i := start; i < len(s); i++ {
ch := s[i]
if inStr {
if esc {
esc = false
continue
}
if ch == '\\' {
esc = true
continue
}
if ch == '"' {
inStr = false
}
continue
}
switch ch {
case '"':
inStr = true
case '{':
depth++
case '}':
depth--
if depth == 0 {
return s[start : i+1]
}
}
}
return ""
}
func parseAuditAgentDecisionObject(jsonText string) (decision, comment string, editedArgs map[string]interface{}, err error) {
var parsed map[string]interface{}
if err := json.Unmarshal([]byte(jsonText), &parsed); err != nil {
return "", "", nil, err
}
rawDecision := auditAgentPickString(parsed, "decision", "Decision", "result", "action", "verdict", "决策", "决定")
decision = normalizeAuditAgentDecision(rawDecision)
if decision == "" {
return "", "", nil, fmt.Errorf("missing decision")
}
comment = auditAgentPickString(parsed, "comment", "Comment", "reason", "message", "rationale", "备注", "理由", "说明")
editedArgs = auditAgentPickObject(parsed, "editedArguments", "edited_arguments", "editedArgs")
return decision, strings.TrimSpace(comment), editedArgs, nil
}
func auditAgentPickString(m map[string]interface{}, keys ...string) string {
for _, k := range keys {
if v, ok := m[k]; ok && v != nil {
s := strings.TrimSpace(fmt.Sprint(v))
if s != "" {
return s
}
}
}
return ""
}
func auditAgentPickObject(m map[string]interface{}, keys ...string) map[string]interface{} {
for _, k := range keys {
v, ok := m[k]
if !ok || v == nil {
continue
}
switch t := v.(type) {
case map[string]interface{}:
if len(t) > 0 {
return t
}
case string:
s := strings.TrimSpace(t)
if s == "" || s == "{}" {
continue
}
var obj map[string]interface{}
if err := json.Unmarshal([]byte(s), &obj); err == nil && len(obj) > 0 {
return obj
}
}
}
return nil
}
func normalizeAuditAgentDecision(v string) string {
d := strings.ToLower(strings.TrimSpace(v))
switch d {
case "approve", "approved", "pass", "passed", "allow", "allowed", "yes", "ok", "accept", "accepted":
return "approve"
case "reject", "rejected", "deny", "denied", "no", "block", "blocked", "refuse", "refused":
return "reject"
}
switch strings.TrimSpace(v) {
case "通过", "批准", "允许", "同意", "放行":
return "approve"
case "拒绝", "驳回", "禁止", "否决":
return "reject"
}
return ""
}
type hitlAuditStrategyReq struct {
AuditAgentPrompt string `json:"auditAgentPrompt"`
AuditAgentPromptReviewEdit string `json:"auditAgentPromptReviewEdit"`
}
func (h *AgentHandler) GetHITLAuditStrategy(c *gin.Context) {
approvalPrompt := config.DefaultHitlAuditAgentPrompt()
reviewEditPrompt := config.DefaultHitlAuditAgentPromptReviewEdit()
approvalCustom := false
reviewEditCustom := false
if h.config != nil {
approvalPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("approval")
reviewEditPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("review_edit")
approvalCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPrompt) != ""
reviewEditCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPromptReviewEdit) != ""
}
c.JSON(http.StatusOK, gin.H{
"auditAgentPrompt": approvalPrompt,
"auditAgentPromptCustom": approvalCustom,
"auditAgentPromptReviewEdit": reviewEditPrompt,
"auditAgentPromptReviewEditCustom": reviewEditCustom,
"defaultAuditAgentPrompt": config.DefaultHitlAuditAgentPrompt(),
"defaultAuditAgentPromptReviewEdit": config.DefaultHitlAuditAgentPromptReviewEdit(),
})
}
func (h *AgentHandler) UpdateHITLAuditStrategy(c *gin.Context) {
if h.hitlStrategySaver == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 策略持久化不可用"})
return
}
var req hitlAuditStrategyReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
approvalPrompt := strings.TrimSpace(req.AuditAgentPrompt)
reviewEditPrompt := strings.TrimSpace(req.AuditAgentPromptReviewEdit)
if err := h.hitlStrategySaver.UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt); err != nil {
h.logger.Warn("保存审计 Agent 提示词失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "audit_strategy_update", "HITL 审计策略更新", "hitl_config", "audit_agent_prompt", nil)
}
if h.config != nil {
h.config.Hitl.AuditAgentPrompt = approvalPrompt
h.config.Hitl.AuditAgentPromptReviewEdit = reviewEditPrompt
}
c.JSON(http.StatusOK, gin.H{
"ok": true,
"auditAgentPrompt": config.HitlConfig{AuditAgentPrompt: approvalPrompt}.EffectiveAuditAgentPromptForMode("approval"),
"auditAgentPromptCustom": approvalPrompt != "",
"auditAgentPromptReviewEdit": config.HitlConfig{AuditAgentPromptReviewEdit: reviewEditPrompt}.EffectiveAuditAgentPromptForMode("review_edit"),
"auditAgentPromptReviewEditCustom": reviewEditPrompt != "",
})
}
// HitlAuditStrategySaver 持久化审计 Agent 提示词到 config.yaml。
type HitlAuditStrategySaver interface {
UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error
}
// SetHitlAuditStrategySaver 设置审计策略落盘。
func (h *AgentHandler) SetHitlAuditStrategySaver(s HitlAuditStrategySaver) {
h.hitlStrategySaver = s
}
+88
View File
@@ -0,0 +1,88 @@
package handler
import (
"strings"
"testing"
)
func TestParseAuditAgentLLMContentApprove(t *testing.T) {
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"与任务一致"}`)
if err != nil {
t.Fatal(err)
}
if d.Decision != "approve" || d.Comment != "与任务一致" {
t.Fatalf("unexpected %+v", d)
}
}
func TestParseAuditAgentLLMContentReject(t *testing.T) {
d, err := parseAuditAgentLLMContent("```json\n{\"decision\":\"reject\",\"comment\":\"风险过高\"}\n```")
if err != nil {
t.Fatal(err)
}
if d.Decision != "reject" {
t.Fatalf("expected reject, got %s", d.Decision)
}
}
func TestParseAuditAgentLLMContentInvalid(t *testing.T) {
_, err := parseAuditAgentLLMContent(`{"decision":"maybe"}`)
if err == nil {
t.Fatal("expected error for invalid decision")
}
}
func TestParseAuditAgentLLMContentProseWrapped(t *testing.T) {
d, err := parseAuditAgentLLMContent("好的,裁决如下:\n```json\n{\"decision\":\"approve\",\"comment\":\"只读 ls\"}\n```\n以上。")
if err != nil {
t.Fatal(err)
}
if d.Decision != "approve" {
t.Fatalf("expected approve, got %s", d.Decision)
}
}
func TestParseAuditAgentLLMContentChineseDecision(t *testing.T) {
d, err := parseAuditAgentLLMContent(`{"decision":"通过","comment":"风险低"}`)
if err != nil {
t.Fatal(err)
}
if d.Decision != "approve" {
t.Fatalf("expected approve, got %s", d.Decision)
}
}
func TestParseAuditAgentLLMContentWithEditedArguments(t *testing.T) {
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"收窄路径","editedArguments":{"path":"/safe"}}`)
if err != nil {
t.Fatal(err)
}
if d.Decision != "approve" {
t.Fatalf("expected approve, got %s", d.Decision)
}
if d.EditedArguments == nil || d.EditedArguments["path"] != "/safe" {
t.Fatalf("unexpected edited args: %+v", d.EditedArguments)
}
}
func TestBuildAuditAgentReviewInputIncludesMode(t *testing.T) {
s := buildAuditAgentReviewInput("review_edit", "execute", map[string]interface{}{
"arguments": `{"command":"pwd"}`,
})
if !strings.Contains(s, "review_edit") || !strings.Contains(s, "execute") {
t.Fatalf("unexpected input: %s", s)
}
}
func TestBuildAuditAgentReviewInput(t *testing.T) {
s := buildAuditAgentReviewInput("approval", "nmap", map[string]interface{}{
"arguments": `{"target":"10.0.0.1"}`,
"userMessage": "扫描内网",
})
if s == "" {
t.Fatal("expected non-empty input")
}
if !strings.Contains(s, "nmap") || !strings.Contains(s, "10.0.0.1") || !strings.Contains(s, "扫描内网") {
t.Fatalf("unexpected input: %s", s)
}
}
+97
View File
@@ -0,0 +1,97 @@
package handler
import (
"strings"
)
type hitlCognitionState struct {
AssistantMessageID string
UserMessage string
Thinking string
ReasoningChain string
Planning string
}
// GetHitlCognition 返回当前运行任务上缓存的本轮 HITL 上下文(不含会话历史)。
func (m *AgentTaskManager) GetHitlCognition(conversationID string) hitlCognitionFields {
conversationID = strings.TrimSpace(conversationID)
if m == nil || conversationID == "" {
return hitlCognitionFields{}
}
m.mu.RLock()
defer m.mu.RUnlock()
t, ok := m.tasks[conversationID]
if !ok || t == nil || t.hitlCognition == nil {
return hitlCognitionFields{}
}
c := t.hitlCognition
return hitlCognitionFields{
UserMessage: c.UserMessage,
Thinking: c.Thinking,
ReasoningChain: c.ReasoningChain,
Planning: c.Planning,
}
}
// ResetHitlCognition 新任务开始时重置本轮 HITL 上下文。
func (m *AgentTaskManager) ResetHitlCognition(conversationID, userMessage string) {
conversationID = strings.TrimSpace(conversationID)
if m == nil || conversationID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
t, ok := m.tasks[conversationID]
if !ok || t == nil {
return
}
t.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(userMessage)}
}
// SetHitlAssistantMessageID 记录当前助手消息 ID,供 HITL 与 DB 回退对齐。
func (m *AgentTaskManager) SetHitlAssistantMessageID(conversationID, assistantMessageID string) {
conversationID = strings.TrimSpace(conversationID)
assistantMessageID = strings.TrimSpace(assistantMessageID)
if m == nil || conversationID == "" || assistantMessageID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
t, ok := m.tasks[conversationID]
if !ok || t == nil {
return
}
if t.hitlCognition == nil {
t.hitlCognition = &hitlCognitionState{}
}
t.hitlCognition.AssistantMessageID = assistantMessageID
}
// UpdateHitlCognitionSnapshot 从进行中的进度流快照更新 thinking / reasoning / planning。
func (m *AgentTaskManager) UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoningChain, planning string) {
conversationID = strings.TrimSpace(conversationID)
if m == nil || conversationID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
t, ok := m.tasks[conversationID]
if !ok || t == nil {
return
}
if t.hitlCognition == nil {
t.hitlCognition = &hitlCognitionState{}
}
if id := strings.TrimSpace(assistantMessageID); id != "" {
t.hitlCognition.AssistantMessageID = id
}
if s := strings.TrimSpace(thinking); s != "" {
t.hitlCognition.Thinking = s
}
if s := strings.TrimSpace(reasoningChain); s != "" {
t.hitlCognition.ReasoningChain = s
}
if s := strings.TrimSpace(planning); s != "" {
t.hitlCognition.Planning = s
}
}
+102
View File
@@ -0,0 +1,102 @@
package handler
import (
"strings"
)
const (
hitlPayloadUserMessage = "userMessage"
hitlPayloadThinking = "thinking"
hitlPayloadReasoningChain = "reasoningChain"
hitlPayloadPlanning = "planning"
)
type hitlCognitionFields struct {
UserMessage string
Thinking string
ReasoningChain string
Planning string
}
func (h *AgentHandler) enrichHitlApprovalPayload(conversationID, assistantMessageID string, payload map[string]interface{}) {
if h == nil || payload == nil {
return
}
cog := h.collectHitlCognition(conversationID, assistantMessageID)
if s := strings.TrimSpace(cog.UserMessage); s != "" {
payload[hitlPayloadUserMessage] = s
}
if s := strings.TrimSpace(cog.Thinking); s != "" {
payload[hitlPayloadThinking] = s
}
if s := strings.TrimSpace(cog.ReasoningChain); s != "" {
payload[hitlPayloadReasoningChain] = s
}
if s := strings.TrimSpace(cog.Planning); s != "" {
payload[hitlPayloadPlanning] = s
}
}
func (h *AgentHandler) collectHitlCognition(conversationID, assistantMessageID string) hitlCognitionFields {
var out hitlCognitionFields
if h.tasks != nil {
out = h.tasks.GetHitlCognition(conversationID)
}
if strings.TrimSpace(out.UserMessage) == "" && h.db != nil {
if msg, err := h.db.GetTurnUserMessage(conversationID, assistantMessageID); err == nil {
out.UserMessage = msg
}
}
if h.db != nil && assistantMessageID != "" {
dbCog, err := h.db.GetAssistantCognitionTexts(assistantMessageID)
if err == nil {
if strings.TrimSpace(out.Thinking) == "" {
out.Thinking = dbCog.Thinking
}
if strings.TrimSpace(out.ReasoningChain) == "" {
out.ReasoningChain = dbCog.ReasoningChain
}
if strings.TrimSpace(out.Planning) == "" {
out.Planning = dbCog.Planning
}
}
}
return out
}
func snapshotHitlCognitionFromStreams(thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) (thinking, reasoningChain, planning string) {
if len(thinkingStreams) > 0 {
var thinkingParts, reasoningParts []string
for _, tb := range thinkingStreams {
if tb == nil {
continue
}
content := strings.TrimSpace(tb.b.String())
if content == "" {
continue
}
if tb.persistAs == "reasoning_chain" {
reasoningParts = append(reasoningParts, content)
} else {
thinkingParts = append(thinkingParts, content)
}
}
thinking = strings.Join(thinkingParts, "\n\n")
reasoningChain = strings.Join(reasoningParts, "\n\n")
}
if respPlan != nil {
planning = strings.TrimSpace(respPlan.b.String())
}
return thinking, reasoningChain, planning
}
func (h *AgentHandler) syncHitlCognitionFromProgress(conversationID, assistantMessageID string, thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) {
if h == nil || h.tasks == nil {
return
}
thinking, reasoning, planning := snapshotHitlCognitionFromStreams(thinkingStreams, respPlan)
if thinking == "" && reasoning == "" && planning == "" {
return
}
h.tasks.UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoning, planning)
}
+46
View File
@@ -0,0 +1,46 @@
package handler
import (
"os"
"path/filepath"
"testing"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
func TestEnrichHitlApprovalPayload(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
if err != nil {
t.Fatalf("db: %v", err)
}
defer os.RemoveAll(tmp)
conv, err := db.CreateConversation("hitl ctx", database.ConversationCreateMeta{})
if err != nil {
t.Fatalf("conv: %v", err)
}
if _, err := db.AddMessage(conv.ID, "user", "scan 10.0.0.1 please", nil); err != nil {
t.Fatalf("user msg: %v", err)
}
asst, err := db.AddMessage(conv.ID, "assistant", "", nil)
if err != nil {
t.Fatalf("asst msg: %v", err)
}
if err := db.AddProcessDetail(asst.ID, conv.ID, "thinking", "need port scan first", nil); err != nil {
t.Fatalf("detail: %v", err)
}
h := &AgentHandler{db: db, tasks: NewAgentTaskManager()}
payload := map[string]interface{}{"toolName": "nmap", "arguments": "{}"}
h.enrichHitlApprovalPayload(conv.ID, asst.ID, payload)
if got := payload["userMessage"]; got != "scan 10.0.0.1 please" {
t.Fatalf("userMessage=%v", got)
}
if got := payload["thinking"]; got != "need port scan first" {
t.Fatalf("thinking=%v", got)
}
}
+132
View File
@@ -0,0 +1,132 @@
package handler
import (
"encoding/json"
"strings"
"time"
)
const hitlPayloadExecutionResult = "executionResult"
type hitlExecutionResult struct {
Success bool `json:"success"`
Result string `json:"result,omitempty"`
ToolName string `json:"toolName,omitempty"`
ToolCallID string `json:"toolCallId,omitempty"`
RecordedAt time.Time `json:"recordedAt"`
}
type hitlApprovedExecTrack struct {
InterruptID string
ConversationID string
ToolName string
ToolCallID string
}
// TrackApprovedHitlExecution 审批通过后登记,待 tool_result 回写执行结果。
func (m *HITLManager) TrackApprovedHitlExecution(interruptID, conversationID, toolName, toolCallID string) {
if m == nil {
return
}
interruptID = strings.TrimSpace(interruptID)
conversationID = strings.TrimSpace(conversationID)
if interruptID == "" || conversationID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if m.approvedExec == nil {
m.approvedExec = make(map[string][]hitlApprovedExecTrack)
}
m.approvedExec[conversationID] = append(m.approvedExec[conversationID], hitlApprovedExecTrack{
InterruptID: interruptID,
ConversationID: conversationID,
ToolName: strings.TrimSpace(toolName),
ToolCallID: strings.TrimSpace(toolCallID),
})
}
func (m *HITLManager) popApprovedInterruptForTool(conversationID, toolCallID, toolName string) string {
if m == nil {
return ""
}
conversationID = strings.TrimSpace(conversationID)
toolCallID = strings.TrimSpace(toolCallID)
toolName = strings.TrimSpace(toolName)
m.mu.Lock()
defer m.mu.Unlock()
queue := m.approvedExec[conversationID]
if len(queue) == 0 {
return ""
}
idx := -1
if toolCallID != "" {
for i, t := range queue {
if t.ToolCallID == toolCallID {
idx = i
break
}
}
}
if idx < 0 && toolName != "" {
for i, t := range queue {
if strings.EqualFold(t.ToolName, toolName) {
idx = i
break
}
}
}
if idx < 0 {
return ""
}
id := queue[idx].InterruptID
queue = append(queue[:idx], queue[idx+1:]...)
if len(queue) == 0 {
delete(m.approvedExec, conversationID)
} else {
m.approvedExec[conversationID] = queue
}
return id
}
func mergeHitlPayloadExecutionResult(payloadJSON string, exec hitlExecutionResult) (string, error) {
root := make(map[string]interface{})
if strings.TrimSpace(payloadJSON) != "" {
_ = json.Unmarshal([]byte(payloadJSON), &root)
}
if root == nil {
root = make(map[string]interface{})
}
root[hitlPayloadExecutionResult] = exec
out, err := json.Marshal(root)
if err != nil {
return payloadJSON, err
}
return string(out), nil
}
func (h *AgentHandler) recordHitlToolExecutionResult(conversationID, toolCallID, toolName string, success bool, result string) {
if h == nil || h.hitlManager == nil || h.db == nil {
return
}
interruptID := h.hitlManager.popApprovedInterruptForTool(conversationID, toolCallID, toolName)
if interruptID == "" {
return
}
var payloadJSON string
err := h.db.QueryRow(`SELECT payload FROM hitl_interrupts WHERE id = ?`, interruptID).Scan(&payloadJSON)
if err != nil {
return
}
merged, err := mergeHitlPayloadExecutionResult(payloadJSON, hitlExecutionResult{
Success: success,
Result: strings.TrimSpace(result),
ToolName: strings.TrimSpace(toolName),
ToolCallID: strings.TrimSpace(toolCallID),
RecordedAt: time.Now(),
})
if err != nil {
return
}
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET payload = ? WHERE id = ?`, merged, interruptID)
}
+39
View File
@@ -0,0 +1,39 @@
package handler
import (
"encoding/json"
"testing"
)
func TestMergeHitlPayloadExecutionResult(t *testing.T) {
merged, err := mergeHitlPayloadExecutionResult(`{"userMessage":"hi","toolName":"nmap"}`, hitlExecutionResult{
Success: true,
Result: "open ports: 80",
})
if err != nil {
t.Fatal(err)
}
var root map[string]interface{}
if err := json.Unmarshal([]byte(merged), &root); err != nil {
t.Fatal(err)
}
if root["userMessage"] != "hi" {
t.Fatalf("userMessage lost: %v", root["userMessage"])
}
exec, ok := root["executionResult"].(map[string]interface{})
if !ok || exec["success"] != true {
t.Fatalf("executionResult missing: %v", root["executionResult"])
}
}
func TestPopApprovedInterruptForTool(t *testing.T) {
m := NewHITLManager(nil, nil)
m.TrackApprovedHitlExecution("hitl_a", "conv1", "nmap", "tc1")
m.TrackApprovedHitlExecution("hitl_b", "conv1", "exec", "")
if id := m.popApprovedInterruptForTool("conv1", "tc1", "nmap"); id != "hitl_a" {
t.Fatalf("tc1 match=%q", id)
}
if id := m.popApprovedInterruptForTool("conv1", "", "exec"); id != "hitl_b" {
t.Fatalf("tool name match=%q", id)
}
}
+263
View File
@@ -0,0 +1,263 @@
package handler
import (
"database/sql"
"errors"
"math"
"net/http"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin"
)
func normalizeHitlReviewer(v string) string {
switch strings.ToLower(strings.TrimSpace(v)) {
case "audit_agent", "agent", "ai":
return "audit_agent"
default:
return "human"
}
}
func normalizeHitlDecidedBy(v string) string {
switch strings.ToLower(strings.TrimSpace(v)) {
case "audit_agent", "agent", "ai":
return "audit_agent"
case "system", "timeout":
return "system"
case "manual":
return "manual"
default:
return "human"
}
}
func (m *HITLManager) migrateHitlSchemaColumns() {
_, _ = m.db.Exec(`ALTER TABLE hitl_interrupts ADD COLUMN decided_by TEXT NOT NULL DEFAULT 'human'`)
_, _ = m.db.Exec(`ALTER TABLE hitl_conversation_configs ADD COLUMN reviewer TEXT NOT NULL DEFAULT 'human'`)
}
func hitlInterruptRowToMap(
id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string,
messageID sql.NullString,
decision, comment sql.NullString,
createdAt time.Time,
decidedAt sql.NullTime,
) map[string]interface{} {
msgID := ""
if messageID.Valid {
msgID = messageID.String
}
return map[string]interface{}{
"id": id,
"conversationId": cid,
"messageId": msgID,
"mode": mode,
"toolName": toolName,
"toolCallId": toolCallID,
"payload": payload,
"status": rowStatus,
"decision": decision.String,
"comment": comment.String,
"decidedBy": decidedBy,
"createdAt": createdAt,
"decidedAt": func() interface{} {
if decidedAt.Valid {
return decidedAt.Time
}
return nil
}(),
}
}
func (h *AgentHandler) buildHitlListQuery(logs bool) (string, []interface{}) {
where, args := h.buildHitlLogsWhere(logs)
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts` + where
return q, args
}
func (h *AgentHandler) buildHitlLogsWhere(logs bool) (string, []interface{}) {
q := " WHERE 1=1"
args := []interface{}{}
if logs {
q += " AND status != 'pending'"
} else {
q += " AND status = 'pending'"
}
return q, args
}
func (h *AgentHandler) appendHitlListFilters(q string, args []interface{}, c *gin.Context) (string, []interface{}) {
conversationID := strings.TrimSpace(c.Query("conversationId"))
toolName := strings.TrimSpace(c.Query("toolName"))
decision := strings.TrimSpace(c.Query("decision"))
decidedBy := strings.TrimSpace(c.Query("decidedBy"))
status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("q"))
if conversationID != "" {
q += " AND conversation_id = ?"
args = append(args, conversationID)
}
if toolName != "" {
q += " AND tool_name LIKE ?"
args = append(args, "%"+toolName+"%")
}
if decision != "" && decision != "all" {
q += " AND decision = ?"
args = append(args, decision)
}
if decidedBy != "" && decidedBy != "all" {
q += " AND COALESCE(decided_by,'human') = ?"
args = append(args, normalizeHitlDecidedBy(decidedBy))
}
if status != "" && status != "all" {
q += " AND status = ?"
args = append(args, status)
}
if search != "" {
like := "%" + search + "%"
q += " AND (id LIKE ? OR conversation_id LIKE ? OR tool_name LIKE ? OR payload LIKE ? OR COALESCE(decision_comment,'') LIKE ?)"
args = append(args, like, like, like, like, like)
}
return q, args
}
func (h *AgentHandler) scanHitlInterruptRows(rows *sql.Rows) ([]map[string]interface{}, error) {
items := make([]map[string]interface{}, 0)
for rows.Next() {
var id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
var messageID sql.NullString
var decision, comment sql.NullString
var createdAt time.Time
var decidedAt sql.NullTime
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt); err != nil {
continue
}
items = append(items, hitlInterruptRowToMap(id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
}
return items, nil
}
func (h *AgentHandler) countHitlQuery(baseQ string, args []interface{}) (int, error) {
countQ := "SELECT COUNT(*) FROM (" + baseQ + ") AS hitl_cnt"
var total int
if err := h.db.QueryRow(countQ, args...).Scan(&total); err != nil {
return 0, err
}
return total, nil
}
func (h *AgentHandler) ListHITLLogs(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
if page < 1 {
page = 1
}
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
offset := (page - 1) * pageSize
q, args := h.buildHitlListQuery(true)
q, args = h.appendHitlListFilters(q, args, c)
total, err := h.countHitlQuery(q, args)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
q += " ORDER BY COALESCE(decided_at, created_at) DESC LIMIT ? OFFSET ?"
args = append(args, pageSize, offset)
rows, err := h.db.Query(q, args...)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer rows.Close()
items, err := h.scanHitlInterruptRows(rows)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total, "retentionDays": h.hitlRetentionDays()})
}
func (h *AgentHandler) hitlRetentionDays() int {
if h.config != nil {
return h.config.Hitl.RetentionDaysEffective()
}
return config.HitlConfig{}.RetentionDaysEffective()
}
// DeleteHITLLogs 批量删除或按筛选清空已决策的人机协同审计日志(不删除 pending)。
func (h *AgentHandler) DeleteHITLLogs(c *gin.Context) {
var request struct {
IDs []string `json:"ids"`
All bool `json:"all"`
}
if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
var deleted int64
var err error
if request.All {
where, args := h.buildHitlLogsWhere(true)
where, args = h.appendHitlListFilters(where, args, c)
deleted, err = h.db.DeleteHitlInterruptLogsMatching(where, args)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "logs_clear", "清空人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
"deleted": deleted,
})
}
} else {
if len(request.IDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "审计日志 ID 列表不能为空"})
return
}
deleted, err = h.db.DeleteHitlInterruptLogsByIDs(request.IDs)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "logs_delete_batch", "批量删除人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
"count": len(request.IDs),
"deleted": deleted,
})
}
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功", "deleted": deleted})
}
func (h *AgentHandler) GetHITLLog(c *gin.Context) {
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE id = ?`
var rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
var messageID sql.NullString
var decision, comment sql.NullString
var createdAt time.Time
var decidedAt sql.NullTime
err := h.db.QueryRow(q, id).Scan(&rowID, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt)
if errors.Is(err, sql.ErrNoRows) {
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
return
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, hitlInterruptRowToMap(rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
}
+295 -14
View File
@@ -5,6 +5,7 @@ import (
"errors" "errors"
"io" "io"
"net/http" "net/http"
"sort"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -23,6 +24,8 @@ import (
type MonitorHandler struct { type MonitorHandler struct {
mcpServer *mcp.Server mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager externalMCPMgr *mcp.ExternalMCPManager
taskManager *AgentTaskManager
agentHandler *AgentHandler
executor *security.Executor executor *security.Executor
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
@@ -56,16 +59,44 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
h.externalMCPMgr = mgr h.externalMCPMgr = mgr
} }
// SetTaskManager 设置 Agent 任务管理器(用于 Eino execute 等按 executionId 终止)。
func (h *MonitorHandler) SetTaskManager(mgr *AgentTaskManager) {
h.taskManager = mgr
}
// SetAgentHandler 设置 Agent 处理器(MCP 监控终止与对话页「中断并继续」共用逻辑)。
func (h *MonitorHandler) SetAgentHandler(ah *AgentHandler) {
h.agentHandler = ah
}
const monitorPageTopTools = 6
// MonitorStatsSummary 工具调用汇总
type MonitorStatsSummary struct {
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
ToolCount int `json:"toolCount"`
}
// MonitorResponse 监控响应 // MonitorResponse 监控响应
type MonitorResponse struct { type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"` Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"` Summary *MonitorStatsSummary `json:"summary"`
Timestamp time.Time `json:"timestamp"` TopTools []*mcp.ToolStats `json:"topTools"`
Total int `json:"total,omitempty"` Timestamp time.Time `json:"timestamp"`
Page int `json:"page,omitempty"` Total int `json:"total"`
PageSize int `json:"page_size,omitempty"` Page int `json:"page"`
TotalPages int `json:"total_pages,omitempty"` PageSize int `json:"pageSize"`
RetentionDays int `json:"retention_days,omitempty"` TotalPages int `json:"totalPages"`
RetentionDays int `json:"retentionDays"`
}
// StatsResponse 统计信息响应(Dashboard 等)
type StatsResponse struct {
Summary *MonitorStatsSummary `json:"summary"`
TopTools []*mcp.ToolStats `json:"topTools"`
} }
// Monitor 获取监控信息 // Monitor 获取监控信息
@@ -89,8 +120,9 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
// 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool // 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool
toolName := normalizeToolNameFilter(c.Query("tool")) toolName := normalizeToolNameFilter(c.Query("tool"))
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName) executions, total := h.loadExecutionListWithPagination(page, pageSize, status, toolName)
stats := h.loadStats() h.enrichExecutionsConversationID(executions)
summary, topTools := h.loadStatsSummary(monitorPageTopTools)
totalPages := (total + pageSize - 1) / pageSize totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 { if totalPages == 0 {
@@ -99,7 +131,8 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
c.JSON(http.StatusOK, MonitorResponse{ c.JSON(http.StatusOK, MonitorResponse{
Executions: executions, Executions: executions,
Stats: stats, Summary: summary,
TopTools: topTools,
Timestamp: time.Now(), Timestamp: time.Now(),
Total: total, Total: total,
Page: page, Page: page,
@@ -121,6 +154,112 @@ func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
return executions return executions
} }
func (h *MonitorHandler) loadExecutionListWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
if h.db == nil {
allExecutions := h.mcpServer.GetAllExecutions()
if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool {
filtered = append(filtered, exec)
}
}
allExecutions = filtered
}
total := len(allExecutions)
offset := (page - 1) * pageSize
end := offset + pageSize
if end > total {
end = total
}
if offset >= total {
return []*mcp.ToolExecution{}, total
}
pageSlice := allExecutions[offset:end]
out := make([]*mcp.ToolExecution, 0, len(pageSlice))
for _, exec := range pageSlice {
if exec == nil {
continue
}
out = append(out, slimToolExecution(exec))
}
return out, total
}
offset := (page - 1) * pageSize
executions, err := h.db.LoadToolExecutionListPage(offset, pageSize, status, toolName)
if err != nil {
h.logger.Warn("从数据库加载执行记录列表失败,回退到内存数据", zap.Error(err))
return h.loadExecutionListWithPaginationFromMemory(page, pageSize, status, toolName)
}
total, err := h.db.CountToolExecutions(status, toolName)
if err != nil {
h.logger.Warn("获取执行记录总数失败", zap.Error(err))
total = offset + len(executions)
if len(executions) == pageSize {
total = offset + len(executions) + 1
}
}
return executions, total
}
func (h *MonitorHandler) loadExecutionListWithPaginationFromMemory(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
allExecutions := h.mcpServer.GetAllExecutions()
if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool {
filtered = append(filtered, exec)
}
}
allExecutions = filtered
}
total := len(allExecutions)
offset := (page - 1) * pageSize
end := offset + pageSize
if end > total {
end = total
}
if offset >= total {
return []*mcp.ToolExecution{}, total
}
pageSlice := allExecutions[offset:end]
out := make([]*mcp.ToolExecution, 0, len(pageSlice))
for _, exec := range pageSlice {
if exec == nil {
continue
}
out = append(out, slimToolExecution(exec))
}
return out, total
}
func slimToolExecution(exec *mcp.ToolExecution) *mcp.ToolExecution {
if exec == nil {
return nil
}
slim := &mcp.ToolExecution{
ID: exec.ID,
ToolName: exec.ToolName,
Status: exec.Status,
StartTime: exec.StartTime,
}
if exec.EndTime != nil {
end := *exec.EndTime
slim.EndTime = &end
}
if exec.Duration > 0 {
slim.Duration = exec.Duration
}
return slim
}
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) { func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
if h.db == nil { if h.db == nil {
allExecutions := h.mcpServer.GetAllExecutions() allExecutions := h.mcpServer.GetAllExecutions()
@@ -193,7 +332,78 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
return executions, total return executions, total
} }
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats { func (h *MonitorHandler) loadStatsSummary(topN int) (*MonitorStatsSummary, []*mcp.ToolStats) {
if topN <= 0 {
topN = monitorPageTopTools
}
if h.db != nil {
result, err := h.db.LoadToolStatsSummary(topN)
if err == nil {
return dbStatsSummaryToMonitor(result), result.TopTools
}
h.logger.Warn("从数据库加载统计汇总失败,回退到内存数据", zap.Error(err))
}
stats := h.loadStatsMap()
return summarizeToolStats(stats, topN)
}
func dbStatsSummaryToMonitor(result *database.ToolStatsSummaryResult) *MonitorStatsSummary {
if result == nil {
return &MonitorStatsSummary{}
}
summary := &MonitorStatsSummary{
TotalCalls: result.Summary.TotalCalls,
SuccessCalls: result.Summary.SuccessCalls,
FailedCalls: result.Summary.FailedCalls,
ToolCount: result.Summary.ToolCount,
}
if result.Summary.LastCallTime != nil {
t := *result.Summary.LastCallTime
summary.LastCallTime = &t
}
return summary
}
func summarizeToolStats(stats map[string]*mcp.ToolStats, topN int) (*MonitorStatsSummary, []*mcp.ToolStats) {
summary := &MonitorStatsSummary{}
if len(stats) == 0 {
return summary, nil
}
all := make([]*mcp.ToolStats, 0, len(stats))
for _, stat := range stats {
if stat == nil {
continue
}
summary.ToolCount++
summary.TotalCalls += stat.TotalCalls
summary.SuccessCalls += stat.SuccessCalls
summary.FailedCalls += stat.FailedCalls
if stat.LastCallTime != nil && (summary.LastCallTime == nil || stat.LastCallTime.After(*summary.LastCallTime)) {
t := *stat.LastCallTime
summary.LastCallTime = &t
}
if stat.TotalCalls > 0 {
statCopy := *stat
all = append(all, &statCopy)
}
}
sort.Slice(all, func(i, j int) bool {
if all[i].TotalCalls == all[j].TotalCalls {
return all[i].ToolName < all[j].ToolName
}
return all[i].TotalCalls > all[j].TotalCalls
})
if len(all) > topN {
all = all[:topN]
}
return summary, all
}
func (h *MonitorHandler) loadStatsMap() map[string]*mcp.ToolStats {
// 合并内部MCP服务器和外部MCP管理器的统计信息 // 合并内部MCP服务器和外部MCP管理器的统计信息
stats := make(map[string]*mcp.ToolStats) stats := make(map[string]*mcp.ToolStats)
@@ -247,6 +457,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
// 先从内部MCP服务器查找 // 先从内部MCP服务器查找
exec, exists := h.mcpServer.GetExecution(id) exec, exists := h.mcpServer.GetExecution(id)
if exists { if exists {
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
c.JSON(http.StatusOK, exec) c.JSON(http.StatusOK, exec)
return return
} }
@@ -255,6 +466,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
if h.externalMCPMgr != nil { if h.externalMCPMgr != nil {
exec, exists = h.externalMCPMgr.GetExecution(id) exec, exists = h.externalMCPMgr.GetExecution(id)
if exists { if exists {
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
c.JSON(http.StatusOK, exec) c.JSON(http.StatusOK, exec)
return return
} }
@@ -264,6 +476,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
if h.db != nil { if h.db != nil {
exec, err := h.db.GetToolExecution(id) exec, err := h.db.GetToolExecution(id)
if err == nil && exec != nil { if err == nil && exec != nil {
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
c.JSON(http.StatusOK, exec) c.JSON(http.StatusOK, exec)
return return
} }
@@ -290,6 +503,19 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
return return
} }
note = strings.TrimSpace(body.Note) note = strings.TrimSpace(body.Note)
convID := h.conversationIDForRunningExecution(id)
if convID != "" && h.agentHandler != nil {
if ok, payload := h.agentHandler.cancelToolContinueAfter(convID, id, note); ok {
h.logger.Info("MCP 监控页终止工具(与对话中断并继续一致)",
zap.String("executionId", id),
zap.String("conversationId", convID),
zap.Bool("hasNote", note != ""),
)
c.JSON(http.StatusOK, payload)
return
}
}
if h.mcpServer.CancelToolExecutionWithNote(id, note) { if h.mcpServer.CancelToolExecutionWithNote(id, note) {
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != "")) h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id}) c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
@@ -303,6 +529,52 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"}) c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
} }
func (h *MonitorHandler) enrichExecutionsConversationID(executions []*mcp.ToolExecution) {
for _, exec := range executions {
if exec == nil || exec.Status != "running" {
continue
}
exec.ConversationID = h.conversationIDForRunningExecution(exec.ID)
}
}
func (h *MonitorHandler) conversationIDForRunningExecution(executionID string) string {
executionID = strings.TrimSpace(executionID)
if executionID == "" || h.taskManager == nil {
return ""
}
if conv := h.taskManager.ConversationIDForActiveMCPExecution(executionID); conv != "" {
return conv
}
exec := h.lookupExecution(executionID)
if exec == nil || exec.Status != "running" {
return ""
}
if strings.TrimSpace(exec.ToolName) == "execute" {
if onlyConv, ok := h.taskManager.ConversationIDForActiveEinoExecute(); ok {
return onlyConv
}
}
return ""
}
func (h *MonitorHandler) lookupExecution(id string) *mcp.ToolExecution {
if exec, ok := h.mcpServer.GetExecution(id); ok {
return exec
}
if h.externalMCPMgr != nil {
if exec, ok := h.externalMCPMgr.GetExecution(id); ok {
return exec
}
}
if h.db != nil {
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
return exec
}
}
return nil
}
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求) // BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) { func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
var req struct { var req struct {
@@ -340,8 +612,17 @@ func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
// GetStats 获取统计信息 // GetStats 获取统计信息
func (h *MonitorHandler) GetStats(c *gin.Context) { func (h *MonitorHandler) GetStats(c *gin.Context) {
stats := h.loadStats() topN := 30
c.JSON(http.StatusOK, stats) if topStr := c.Query("top"); topStr != "" {
if t, err := strconv.Atoi(topStr); err == nil && t > 0 && t <= 100 {
topN = t
}
}
summary, topTools := h.loadStatsSummary(topN)
c.JSON(http.StatusOK, StatsResponse{
Summary: summary,
TopTools: topTools,
})
} }
// CallsTimelinePoint 调用趋势数据点 // CallsTimelinePoint 调用趋势数据点
+10 -2
View File
@@ -188,6 +188,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
var emptyResponseContinueAttempt int
for { for {
segmentMainIterationMax := 0 segmentMainIterationMax := 0
@@ -243,7 +244,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
h.agentsMarkdownDir, h.agentsMarkdownDir,
orch, orch,
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID), h.agentSessionContextBlock(conversationID),
) )
if result != nil && len(result.MCPExecutionIDs) > 0 { if result != nil && len(result.MCPExecutionIDs) > 0 {
@@ -251,6 +252,13 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
} }
if runErr == nil { if runErr == nil {
mw := &h.config.MultiAgent.EinoMiddleware
if h.tryContinueOnEinoEmptyResponse(taskCtx, mw, conversationID, result, &emptyResponseContinueAttempt, &curHistory, &curFinalMessage, progressCallback) {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause, taskCtx, timeoutCancel = h.rebindEinoRunningTask(conversationID, timeoutCancel)
continue
}
timeoutCancel() timeoutCancel()
break break
} }
@@ -430,7 +438,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
h.agentsMarkdownDir, h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration), strings.TrimSpace(req.Orchestration),
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID), h.agentSessionContextBlock(prep.ConversationID),
) )
if runErr == nil { if runErr == nil {
break break
+45 -6
View File
@@ -740,14 +740,21 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"executions": map[string]interface{}{ "executions": map[string]interface{}{
"type": "array", "type": "array",
"description": "执行记录列表", "description": "执行记录列表(轻量字段,不含 arguments/result",
"items": map[string]interface{}{ "items": map[string]interface{}{
"$ref": "#/components/schemas/ToolExecution", "$ref": "#/components/schemas/ToolExecution",
}, },
}, },
"stats": map[string]interface{}{ "summary": map[string]interface{}{
"type": "object", "type": "object",
"description": "统计信息", "description": "工具调用汇总",
},
"topTools": map[string]interface{}{
"type": "array",
"description": "调用量 Top N 工具",
"items": map[string]interface{}{
"type": "object",
},
}, },
"timestamp": map[string]interface{}{ "timestamp": map[string]interface{}{
"type": "string", "type": "string",
@@ -756,20 +763,24 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
}, },
"total": map[string]interface{}{ "total": map[string]interface{}{
"type": "integer", "type": "integer",
"description": "总数", "description": "执行记录总数",
}, },
"page": map[string]interface{}{ "page": map[string]interface{}{
"type": "integer", "type": "integer",
"description": "当前页", "description": "当前页",
}, },
"page_size": map[string]interface{}{ "pageSize": map[string]interface{}{
"type": "integer", "type": "integer",
"description": "每页数量", "description": "每页数量",
}, },
"total_pages": map[string]interface{}{ "totalPages": map[string]interface{}{
"type": "integer", "type": "integer",
"description": "总页数", "description": "总页数",
}, },
"retentionDays": map[string]interface{}{
"type": "integer",
"description": "执行记录保留天数",
},
}, },
}, },
"ConfigResponse": map[string]interface{}{ "ConfigResponse": map[string]interface{}{
@@ -1232,6 +1243,34 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"type": "string", "type": "string",
}, },
}, },
{
"name": "project_id",
"in": "query",
"required": false,
"description": "按项目筛选;传 __none__ 表示仅未绑定项目的对话",
"schema": map[string]interface{}{
"type": "string",
},
},
{
"name": "exclude_grouped",
"in": "query",
"required": false,
"description": "为 true 时排除已加入分组的对话(默认在未搜索且未按项目筛选时启用)",
"schema": map[string]interface{}{
"type": "boolean",
},
},
{
"name": "sort_by",
"in": "query",
"required": false,
"description": "排序字段:updated_at(默认)或 created_at",
"schema": map[string]interface{}{
"type": "string",
"enum": []string{"updated_at", "created_at"},
},
},
}, },
"responses": map[string]interface{}{ "responses": map[string]interface{}{
"200": map[string]interface{}{ "200": map[string]interface{}{
+62
View File
@@ -7,6 +7,45 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// agentSessionContextBlock 注入会话工作目录、项目黑板与用户原文锚点(用于 system prompt 追加块)。
func (h *AgentHandler) agentSessionContextBlock(conversationID string) string {
var parts []string
if ws := h.buildWorkspaceBlock(conversationID); ws != "" {
parts = append(parts, ws)
}
if bb := h.projectBlackboardBlock(conversationID); bb != "" {
parts = append(parts, bb)
}
if uv := h.userVerbatimAnchorBlock(conversationID); uv != "" {
parts = append(parts, uv)
}
return strings.Join(parts, "\n\n")
}
func (h *AgentHandler) buildWorkspaceBlock(conversationID string) string {
if h == nil || h.config == nil {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
projectID := h.conversationProjectID(conversationID)
rel := project.WorkspaceRootDir(h.config.Agent.WorkspaceRootDir, projectID, conversationID)
abs, err := project.EnsureWorkspace(rel)
if err != nil {
if h.logger != nil {
h.logger.Warn("创建会话工作目录失败",
zap.String("conversationId", conversationID),
zap.String("projectId", projectID),
zap.String("path", rel),
zap.Error(err))
}
return ""
}
return project.BuildWorkspaceBlock(abs)
}
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。 // projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string { func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
if h == nil || h.db == nil || h.config == nil { if h == nil || h.db == nil || h.config == nil {
@@ -31,6 +70,29 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
return strings.TrimSpace(block) return strings.TrimSpace(block)
} }
// userVerbatimAnchorBlock 从 messages 表构建用户各轮原文锚点(压缩后仍由 summarization Finalize 刷新)。
func (h *AgentHandler) userVerbatimAnchorBlock(conversationID string) string {
if h == nil || h.db == nil || h.config == nil {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
maxRunes := h.config.MultiAgent.UserVerbatimAnchorMaxRunesEffective()
if maxRunes < 0 {
return ""
}
msgs, err := h.db.GetMessages(conversationID)
if err != nil {
if h.logger != nil {
h.logger.Warn("构建用户原文锚点失败", zap.String("conversationId", conversationID), zap.Error(err))
}
return ""
}
return project.BuildUserVerbatimAnchorBlockFromMessages(msgs, maxRunes)
}
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。 // conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
func (h *AgentHandler) conversationProjectID(conversationID string) string { func (h *AgentHandler) conversationProjectID(conversationID string) string {
if h == nil || h.db == nil { if h == nil || h.db == nil {
+46 -23
View File
@@ -447,7 +447,7 @@ func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
} }
func (h *RobotHandler) cmdList() string { func (h *RobotHandler) cmdList() string {
convs, err := h.db.ListConversations(50, 0, "", "") convs, err := h.db.ListConversations(50, 0, "", "", "")
if err != nil { if err != nil {
return "获取对话列表失败: " + err.Error() return "获取对话列表失败: " + err.Error()
} }
@@ -711,12 +711,27 @@ type wecomReplyXML struct {
Content string `xml:"Content"` Content string `xml:"Content"`
} }
// wecomRequireToken 企业微信回调必须配置 Token;未配置时拒绝请求,防止未授权触发 Agent。
func (h *RobotHandler) wecomRequireToken(c *gin.Context) (string, bool) {
token := strings.TrimSpace(h.config.Robots.Wecom.Token)
if token == "" {
h.logger.Warn("企业微信已启用但未配置 token,已拒绝回调(请在配置中设置 robots.wecom.token")
c.String(http.StatusForbidden, "")
return "", false
}
return token, true
}
// HandleWecomGET 企业微信 URL 校验(GET // HandleWecomGET 企业微信 URL 校验(GET
func (h *RobotHandler) HandleWecomGET(c *gin.Context) { func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
if !h.config.Robots.Wecom.Enabled { if !h.config.Robots.Wecom.Enabled {
c.String(http.StatusNotFound, "") c.String(http.StatusNotFound, "")
return return
} }
token, ok := h.wecomRequireToken(c)
if !ok {
return
}
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串 // Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
echostr := c.Query("echostr") echostr := c.Query("echostr")
msgSignature := c.Query("msg_signature") msgSignature := c.Query("msg_signature")
@@ -724,7 +739,7 @@ func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
nonce := c.Query("nonce") nonce := c.Query("nonce")
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1 // 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr) signature := h.signWecomRequest(token, timestamp, nonce, echostr)
if signature != msgSignature { if signature != msgSignature {
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature)) h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
c.String(http.StatusBadRequest, "invalid signature") c.String(http.StatusBadRequest, "invalid signature")
@@ -865,27 +880,28 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
} }
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw))) h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段 // 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段
// 若配置了 Token 则必须校验签名,避免未授权请求触发 Agent(防止平台被接管) // 启用企业微信时必须配置 token 校验签名,避免未授权请求触发 Agent
token := h.config.Robots.Wecom.Token token, ok := h.wecomRequireToken(c)
if token != "" { if !ok {
if msgSignature == "" { return
h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需配置 token 并确保回调携带 msg_signature") }
c.String(http.StatusOK, "") if msgSignature == "" {
return h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需确保回调携带 msg_signature")
} c.String(http.StatusOK, "")
var tmp wecomXML return
if err := xml.Unmarshal(bodyRaw, &tmp); err != nil { }
h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err)) var tmp wecomXML
c.String(http.StatusOK, "") if err := xml.Unmarshal(bodyRaw, &tmp); err != nil {
return h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err))
} c.String(http.StatusOK, "")
expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt) return
if expected != msgSignature { }
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature)) expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt)
c.String(http.StatusOK, "") if expected != msgSignature {
return h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
} c.String(http.StatusOK, "")
return
} }
var body wecomXML var body wecomXML
@@ -899,6 +915,13 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
// 保存企业 ID(用于明文模式回复) // 保存企业 ID(用于明文模式回复)
enterpriseID := body.ToUserName enterpriseID := body.ToUserName
// 配置了 EncodingAESKey 时必须走加密消息,拒绝明文 XML 绕过
if strings.TrimSpace(h.config.Robots.Wecom.EncodingAESKey) != "" && strings.TrimSpace(body.Encrypt) == "" {
h.logger.Warn("企业微信已配置加密模式但收到明文消息,已拒绝")
c.String(http.StatusOK, "")
return
}
// 加密模式:先解密再解析内层 XML // 加密模式:先解密再解析内层 XML
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" { if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
h.logger.Debug("企业微信进入加密模式解密流程") h.logger.Debug("企业微信进入加密模式解密流程")
+78
View File
@@ -0,0 +1,78 @@
package handler
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func newWecomTestHandler(token string, aesKey string) *RobotHandler {
return &RobotHandler{
config: &config.Config{
Robots: config.RobotsConfig{
Wecom: config.RobotWecomConfig{
Enabled: true,
Token: token,
EncodingAESKey: aesKey,
},
},
},
logger: zap.NewNop(),
}
}
func TestHandleWecomPOST_rejectsWhenTokenEmpty(t *testing.T) {
gin.SetMode(gin.TestMode)
h := newWecomTestHandler("", "")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `<?xml version="1.0"?><xml><FromUserName>attacker</FromUserName><MsgType>text</MsgType><Content>hi</Content></xml>`
c.Request = httptest.NewRequest(http.MethodPost, "/api/robot/wecom", strings.NewReader(body))
h.HandleWecomPOST(c)
if w.Code != http.StatusForbidden {
t.Fatalf("status = %d, want %d", w.Code, http.StatusForbidden)
}
if w.Body.String() == "success" {
t.Fatal("expected rejection, got success")
}
}
func TestHandleWecomPOST_rejectsPlaintextWhenEncryptionConfigured(t *testing.T) {
gin.SetMode(gin.TestMode)
h := newWecomTestHandler("secret-token", "abcdefghijklmnopqrstuvwxyz0123456789ABCD")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `<?xml version="1.0"?><xml><FromUserName>attacker</FromUserName><MsgType>text</MsgType><Content>hi</Content></xml>`
c.Request = httptest.NewRequest(http.MethodPost, "/api/robot/wecom?timestamp=1&nonce=2&msg_signature=fake", strings.NewReader(body))
h.HandleWecomPOST(c)
if w.Body.String() == "success" {
t.Fatal("expected rejection for plaintext in encryption mode, got success")
}
}
func TestHandleWecomGET_rejectsWhenTokenEmpty(t *testing.T) {
gin.SetMode(gin.TestMode)
h := newWecomTestHandler("", "")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/robot/wecom?msg_signature=x&timestamp=1&nonce=2&echostr=abc", nil)
h.HandleWecomGET(c)
if w.Code != http.StatusForbidden {
t.Fatalf("status = %d, want %d", w.Code, http.StatusForbidden)
}
}
+58 -2
View File
@@ -26,6 +26,7 @@ func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool {
// AgentTask 描述正在运行的Agent任务 // AgentTask 描述正在运行的Agent任务
type AgentTask struct { type AgentTask struct {
ConversationID string `json:"conversationId"` ConversationID string `json:"conversationId"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"` Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"` StartedAt time.Time `json:"startedAt"`
Status string `json:"status"` Status string `json:"status"`
@@ -42,6 +43,9 @@ type AgentTask struct {
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果 // activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
activeEinoExecuteAbortNote string activeEinoExecuteAbortNote string
// hitlCognition 本轮运行中供 HITL/审计 Agent 读取的上下文(用户原话 + 思考,不含会话历史)
hitlCognition *hitlCognitionState
cancel func(error) cancel func(error)
} }
@@ -103,6 +107,40 @@ func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
} }
} }
// ConversationIDForActiveMCPExecution 根据当前登记的工具 executionId 反查会话 ID(供 MCP 监控页按 executionId 终止)。
func (m *AgentTaskManager) ConversationIDForActiveMCPExecution(executionID string) string {
executionID = strings.TrimSpace(executionID)
if executionID == "" {
return ""
}
m.mu.Lock()
defer m.mu.Unlock()
for convID, t := range m.tasks {
if t != nil && t.ActiveMCPExecutionID == executionID {
return convID
}
}
return ""
}
// ConversationIDForActiveEinoExecute 返回当前唯一进行 Eino execute 的会话 ID;多会话并行时返回空。
func (m *AgentTaskManager) ConversationIDForActiveEinoExecute() (string, bool) {
m.mu.Lock()
defer m.mu.Unlock()
var found string
count := 0
for convID, t := range m.tasks {
if t != nil && t.activeEinoExecuteCancel != nil {
found = convID
count++
}
}
if count == 1 {
return found, true
}
return "", false
}
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。 // AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool { func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
conversationID = strings.TrimSpace(conversationID) conversationID = strings.TrimSpace(conversationID)
@@ -199,6 +237,7 @@ func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
// CompletedTask 已完成的任务(用于历史记录) // CompletedTask 已完成的任务(用于历史记录)
type CompletedTask struct { type CompletedTask struct {
ConversationID string `json:"conversationId"` ConversationID string `json:"conversationId"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"` Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"` StartedAt time.Time `json:"startedAt"`
CompletedAt time.Time `json:"completedAt"` CompletedAt time.Time `json:"completedAt"`
@@ -213,6 +252,8 @@ type AgentTaskManager struct {
maxHistorySize int // 最大历史记录数 maxHistorySize int // 最大历史记录数
historyRetention time.Duration // 历史记录保留时间 historyRetention time.Duration // 历史记录保留时间
eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅 eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅
// toolCanceler 在用户整轮停止任务时终止当前 MCP 工具(非「中断并继续」)。
toolCanceler func(conversationID string)
} }
const ( const (
@@ -243,6 +284,13 @@ func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) {
m.eventBus = b m.eventBus = b
} }
// SetToolCanceler 设置整轮停止任务时终止当前 MCP 工具的回调(由 AgentHandler 注入)。
func (m *AgentTaskManager) SetToolCanceler(fn func(conversationID string)) {
m.mu.Lock()
defer m.mu.Unlock()
m.toolCanceler = fn
}
// GetTask 返回运行中任务(无则 nil)。 // GetTask 返回运行中任务(无则 nil)。
func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask { func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask {
m.mu.RLock() m.mu.RLock()
@@ -309,6 +357,7 @@ func (m *AgentTaskManager) StartTask(conversationID, message string, cancel cont
} }
m.tasks[conversationID] = task m.tasks[conversationID] = task
task.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(message)}
return task, nil return task, nil
} }
@@ -338,14 +387,21 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
task.InterruptContinueNote = "" task.InterruptContinueNote = ""
} }
cancel := task.cancel cancel := task.cancel
m.mu.Unlock()
if cause == nil { if cause == nil {
cause = ErrTaskCancelled cause = ErrTaskCancelled
} }
var toolCanceler func(string)
if errors.Is(cause, ErrTaskCancelled) {
toolCanceler = m.toolCanceler
}
m.mu.Unlock()
if cancel != nil { if cancel != nil {
cancel(cause) cancel(cause)
} }
if toolCanceler != nil {
toolCanceler(conversationID)
}
return true, nil return true, nil
} }
@@ -38,3 +38,19 @@ func TestAbortActiveEinoExecute(t *testing.T) {
t.Fatal("second abort should fail when no active execute") t.Fatal("second abort should fail when no active execute")
} }
} }
func TestConversationIDForActiveMCPExecution(t *testing.T) {
m := NewAgentTaskManager()
conv := "conv-mcp-exec"
_, err := m.StartTask(conv, "test", func(error) {})
if err != nil {
t.Fatalf("StartTask: %v", err)
}
m.RegisterRunningTool(conv, "exec-123")
if got := m.ConversationIDForActiveMCPExecution("exec-123"); got != conv {
t.Fatalf("got %q, want %q", got, conv)
}
if got := m.ConversationIDForActiveMCPExecution("missing"); got != "" {
t.Fatalf("missing should be empty, got %q", got)
}
}
@@ -0,0 +1,80 @@
package handler
import (
"context"
"errors"
"testing"
"cyberstrike-ai/internal/multiagent"
)
func TestCancelTaskInvokesToolCancelerOnFullStop(t *testing.T) {
tm := NewAgentTaskManager()
called := false
tm.SetToolCanceler(func(conversationID string) {
if conversationID == "conv-1" {
called = true
}
})
_, cancel := context.WithCancelCause(context.Background())
_, err := tm.StartTask("conv-1", "hello", cancel)
if err != nil {
t.Fatalf("StartTask: %v", err)
}
ok, err := tm.CancelTask("conv-1", ErrTaskCancelled)
if err != nil || !ok {
t.Fatalf("CancelTask: ok=%v err=%v", ok, err)
}
if !called {
t.Fatal("expected tool canceler to be invoked on full task cancel")
}
}
func TestCancelTaskSkipsToolCancelerOnInterruptContinue(t *testing.T) {
tm := NewAgentTaskManager()
called := false
tm.SetToolCanceler(func(conversationID string) {
called = true
})
_, cancel := context.WithCancelCause(context.Background())
_, err := tm.StartTask("conv-1", "hello", cancel)
if err != nil {
t.Fatalf("StartTask: %v", err)
}
ok, err := tm.CancelTask("conv-1", multiagent.ErrInterruptContinue)
if err != nil || !ok {
t.Fatalf("CancelTask: ok=%v err=%v", ok, err)
}
if called {
t.Fatal("tool canceler must not run for interrupt-continue")
}
}
func TestCancelTaskDefaultCauseIsTaskCancelled(t *testing.T) {
tm := NewAgentTaskManager()
var gotCause error
tm.SetToolCanceler(func(conversationID string) {
if conversationID == "conv-2" {
gotCause = ErrTaskCancelled
}
})
ctx, cancel := context.WithCancelCause(context.Background())
if _, err := tm.StartTask("conv-2", "hello", cancel); err != nil {
t.Fatalf("StartTask: %v", err)
}
if _, err := tm.CancelTask("conv-2", nil); err != nil {
t.Fatalf("CancelTask: %v", err)
}
if !errors.Is(context.Cause(ctx), ErrTaskCancelled) {
t.Fatalf("expected ErrTaskCancelled cause, got %v", context.Cause(ctx))
}
if gotCause != ErrTaskCancelled {
t.Fatalf("expected tool canceler path for default cancel cause")
}
}
+16
View File
@@ -0,0 +1,16 @@
//go:build windows
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RunCommandWS 交互式 PTY 终端依赖 Unix PTY(见 terminal_ws_unix.go);Windows 暂不支持。
func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
c.JSON(http.StatusNotImplemented, gin.H{
"error": "Interactive WebSocket terminal is not supported on Windows; use POST /terminal/run or /terminal/run/stream instead.",
})
}
+71
View File
@@ -0,0 +1,71 @@
package hitl
import (
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
const retentionPurgeInterval = time.Hour
// Service manages HITL audit log retention (decided hitl_interrupts rows).
type Service struct {
db *database.DB
cfg *config.Config
logger *zap.Logger
}
// NewService creates a HITL audit log retention service.
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
return &Service{db: db, cfg: cfg, logger: logger}
}
// RetentionDays returns configured retention; 0 means keep forever.
func (s *Service) RetentionDays() int {
if s == nil || s.cfg == nil {
return config.HitlConfig{}.RetentionDaysEffective()
}
return s.cfg.Hitl.RetentionDaysEffective()
}
// PurgeExpired deletes decided HITL log rows older than retention_days when configured.
func (s *Service) PurgeExpired() {
if s == nil || s.db == nil || s.cfg == nil {
return
}
days := s.cfg.Hitl.RetentionDaysEffective()
if days <= 0 {
return
}
cutoff := time.Now().AddDate(0, 0, -days)
n, err := s.db.PurgeHitlInterruptLogsBefore(cutoff)
if err != nil {
if s.logger != nil {
s.logger.Warn("清理过期人机协同审计日志失败", zap.Error(err))
}
return
}
if n > 0 && s.logger != nil {
s.logger.Info("已清理过期人机协同审计日志", zap.Int64("deleted", n), zap.Int("retention_days", days))
}
}
// StartRetentionLoop periodically purges expired HITL audit log rows.
func StartRetentionLoop(s *Service, logger *zap.Logger) {
if s == nil {
return
}
go func() {
ticker := time.NewTicker(retentionPurgeInterval)
defer ticker.Stop()
for range ticker.C {
s.PurgeExpired()
if logger != nil {
logger.Debug("hitl audit log retention tick completed")
}
}
}()
}
+50
View File
@@ -0,0 +1,50 @@
package hitl
import (
"path/filepath"
"testing"
"time"
appconfig "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "hitl.db")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS hitl_interrupts (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
mode TEXT NOT NULL,
tool_name TEXT NOT NULL,
status TEXT NOT NULL,
decision TEXT,
created_at DATETIME NOT NULL,
decided_at DATETIME
)`); err != nil {
t.Fatalf("create table: %v", err)
}
old := time.Now().AddDate(0, 0, -100).UTC().Format(time.RFC3339)
if _, err := db.Exec(`INSERT INTO hitl_interrupts
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
VALUES ('old-1', 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, old, old); err != nil {
t.Fatalf("insert: %v", err)
}
zero := 0
svc := NewService(db, &appconfig.Config{
Hitl: appconfig.HitlConfig{RetentionDays: &zero},
}, zap.NewNop())
svc.PurgeExpired()
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'old-1'`).Scan(new(string)); err != nil {
t.Fatalf("record should remain when retention_days=0: %v", err)
}
}
+17
View File
@@ -814,6 +814,23 @@ func (m *ExternalMCPManager) CancelToolExecution(id string) bool {
return m.CancelToolExecutionWithNote(id, "") return m.CancelToolExecutionWithNote(id, "")
} }
// ActiveRunningExecutionIDs 返回当前进程内仍登记 cancel 的外部 MCP executionId 快照。
func (m *ExternalMCPManager) ActiveRunningExecutionIDs() map[string]struct{} {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
if len(m.runningCancels) == 0 {
return nil
}
out := make(map[string]struct{}, len(m.runningCancels))
for id := range m.runningCancels {
out[id] = struct{}{}
}
return out
}
// updateStats 更新统计信息 // updateStats 更新统计信息
func (m *ExternalMCPManager) updateStats(toolName string, failed bool) { func (m *ExternalMCPManager) updateStats(toolName string, failed bool) {
now := time.Now() now := time.Now()
+100 -16
View File
@@ -921,9 +921,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
return finalResult, executionID, nil return finalResult, executionID, nil
} }
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致), // BeginToolExecution 创建 running 状态的执行记录,供 Eino 等非 CallTool 路径在工具开始时落库。
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。 func (s *Server) BeginToolExecution(toolName string, args map[string]interface{}) string {
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if s == nil { if s == nil {
return "" return ""
} }
@@ -931,21 +930,73 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
args = map[string]interface{}{} args = map[string]interface{}{}
} }
executionID := uuid.New().String() executionID := uuid.New().String()
now := time.Now() execution := &ToolExecution{
failed := invokeErr != nil
exec := &ToolExecution{
ID: executionID, ID: executionID,
ToolName: toolName, ToolName: toolName,
Arguments: args, Arguments: args,
StartTime: now, Status: "running",
EndTime: &now, StartTime: time.Now(),
Duration: 0,
} }
s.mu.Lock()
s.executions[executionID] = execution
s.cleanupOldExecutions()
s.mu.Unlock()
if s.storage != nil {
if err := s.storage.SaveToolExecution(execution); err != nil {
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
}
}
return executionID
}
// FinishToolExecution 完成先前 BeginToolExecution 创建的记录;executionID 为空时等同 RecordCompletedToolInvocation。
func (s *Server) FinishToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if s == nil {
return ""
}
if args == nil {
args = map[string]interface{}{}
}
id := strings.TrimSpace(executionID)
if id == "" {
return s.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
}
now := time.Now()
failed := invokeErr != nil
var finalResult *ToolResult
s.mu.Lock()
exec, inMem := s.executions[id]
if !inMem || exec == nil {
exec = &ToolExecution{
ID: id,
ToolName: toolName,
Arguments: args,
StartTime: now,
}
s.executions[id] = exec
} else if toolName != "" {
exec.ToolName = toolName
}
if len(args) > 0 {
exec.Arguments = args
}
exec.EndTime = &now
if exec.StartTime.IsZero() {
exec.StartTime = now
}
exec.Duration = now.Sub(exec.StartTime)
if failed { if failed {
exec.Status = "failed" st, msg := executionStatusAndMessage(invokeErr)
exec.Error = invokeErr.Error() exec.Status = st
exec.Error = msg
if strings.TrimSpace(resultText) != "" { if strings.TrimSpace(resultText) != "" {
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}} finalResult = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
exec.Result = finalResult
} }
} else { } else {
exec.Status = "completed" exec.Status = "completed"
@@ -953,15 +1004,31 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
if strings.TrimSpace(text) == "" { if strings.TrimSpace(text) == "" {
text = "(无输出)" text = "(无输出)"
} }
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}} finalResult = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
exec.Result = finalResult
} }
s.mu.Unlock()
if s.storage != nil { if s.storage != nil {
if err := s.storage.SaveToolExecution(exec); err != nil { if err := s.storage.SaveToolExecution(exec); err != nil {
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err)) s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
} }
} }
s.updateStats(toolName, failed)
return executionID s.updateStats(exec.ToolName, failed)
if s.storage != nil {
s.mu.Lock()
delete(s.executions, id)
s.mu.Unlock()
}
return id
}
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
return s.FinishToolExecution("", toolName, args, resultText, invokeErr)
} }
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。 // UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
@@ -1103,6 +1170,23 @@ func (s *Server) CancelToolExecution(id string) bool {
return s.CancelToolExecutionWithNote(id, "") return s.CancelToolExecutionWithNote(id, "")
} }
// ActiveRunningExecutionIDs 返回当前进程内仍登记 cancel 的 executionId 快照。
func (s *Server) ActiveRunningExecutionIDs() map[string]struct{} {
if s == nil {
return nil
}
s.runningCancelsMu.Lock()
defer s.runningCancelsMu.Unlock()
if len(s.runningCancels) == 0 {
return nil
}
out := make(map[string]struct{}, len(s.runningCancels))
for id := range s.runningCancels {
out[id] = struct{}{}
}
return out
}
// initDefaultPrompts 初始化默认提示词模板 // initDefaultPrompts 初始化默认提示词模板
func (s *Server) initDefaultPrompts() { func (s *Server) initDefaultPrompts() {
s.mu.Lock() s.mu.Lock()
+2
View File
@@ -199,6 +199,8 @@ type ToolExecution struct {
StartTime time.Time `json:"startTime"` StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"` EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"` Duration time.Duration `json:"duration,omitempty"`
// ConversationID 仅 API 展示用(进行中的 Agent 任务),不写入 tool_executions 表。
ConversationID string `json:"conversationId,omitempty"`
} }
// ToolStats 工具统计信息 // ToolStats 工具统计信息
+101
View File
@@ -0,0 +1,101 @@
package monitor
import (
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
const (
staleRunningMinAge = 45 * time.Second
staleRunningReconcileGap = 2 * time.Minute
)
// ExecutionReconciler 在启动或运行期将无对应协程的 running 执行记录收尾为 cancelled。
type ExecutionReconciler struct {
db *database.DB
mcpServer *mcp.Server
externalMgr *mcp.ExternalMCPManager
logger *zap.Logger
}
// NewExecutionReconciler creates a reconciler for orphaned MCP tool executions.
func NewExecutionReconciler(db *database.DB, mcpServer *mcp.Server, externalMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ExecutionReconciler {
return &ExecutionReconciler{
db: db,
mcpServer: mcpServer,
externalMgr: externalMgr,
logger: logger,
}
}
// ReconcileOnStartup marks every persisted running row as cancelled (safe right after process start).
func (r *ExecutionReconciler) ReconcileOnStartup() {
if r == nil || r.db == nil {
return
}
now := time.Now()
n, err := r.db.CancelOrphanedRunningToolExecutions(now, "执行已中断(服务重启)")
if err != nil {
if r.logger != nil {
r.logger.Warn("启动时清理孤儿 running 工具执行记录失败", zap.Error(err))
}
return
}
if n > 0 && r.logger != nil {
r.logger.Info("启动时已收尾孤儿 running 工具执行记录", zap.Int64("count", n))
}
}
func (r *ExecutionReconciler) activeExecutionIDs() map[string]struct{} {
ids := make(map[string]struct{})
if r.mcpServer != nil {
for id := range r.mcpServer.ActiveRunningExecutionIDs() {
ids[id] = struct{}{}
}
}
if r.externalMgr != nil {
for id := range r.externalMgr.ActiveRunningExecutionIDs() {
ids[id] = struct{}{}
}
}
return ids
}
// ReconcileStaleRunning finalizes running rows that are not tracked in-memory and older than staleRunningMinAge.
func (r *ExecutionReconciler) ReconcileStaleRunning() {
if r == nil || r.db == nil {
return
}
now := time.Now()
n, err := r.db.FinalizeStaleRunningToolExecutions(now, staleRunningMinAge, r.activeExecutionIDs(), "执行已中断(会话已结束)")
if err != nil {
if r.logger != nil {
r.logger.Warn("定期收尾 stale running 工具执行记录失败", zap.Error(err))
}
return
}
if n > 0 && r.logger != nil {
r.logger.Info("已收尾 stale running 工具执行记录", zap.Int64("count", n))
}
}
// StartStaleRunningReconcileLoop periodically reconciles orphaned running tool executions.
func StartStaleRunningReconcileLoop(r *ExecutionReconciler, logger *zap.Logger) {
if r == nil {
return
}
go func() {
ticker := time.NewTicker(staleRunningReconcileGap)
defer ticker.Stop()
for range ticker.C {
r.ReconcileStaleRunning()
if logger != nil {
logger.Debug("monitor stale running reconcile tick completed")
}
}
}()
}
+38
View File
@@ -0,0 +1,38 @@
package monitor
import (
"path/filepath"
"testing"
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
func TestExecutionReconciler_ReconcileOnStartup(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
if err := db.SaveToolExecution(&mcp.ToolExecution{
ID: "run-1", ToolName: "hydra", Status: "running", StartTime: time.Now().Add(-time.Hour),
}); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
r := NewExecutionReconciler(db, mcp.NewServer(zap.NewNop()), nil, zap.NewNop())
r.ReconcileOnStartup()
got, err := db.GetToolExecution("run-1")
if err != nil {
t.Fatalf("GetToolExecution: %v", err)
}
if got.Status != "cancelled" {
t.Fatalf("expected cancelled after startup reconcile, got %s", got.Status)
}
}
+16
View File
@@ -0,0 +1,16 @@
package multiagent
import (
"fmt"
"github.com/cloudwego/eino/adk"
)
// InitADK configures global Eino ADK settings. Call once at process startup before
// any ADK middleware or agents are created.
func InitADK() error {
if err := adk.SetLanguage(adk.LanguageChinese); err != nil {
return fmt.Errorf("adk set language: %w", err)
}
return nil
}
+145 -55
View File
@@ -18,6 +18,7 @@ import (
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/einoobserve" "cyberstrike-ai/internal/einoobserve"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
@@ -90,7 +91,7 @@ type einoADKRunLoopArgs struct {
FilesystemMonitorRecord einomcp.ExecutionRecorder FilesystemMonitorRecord einomcp.ExecutionRecorder
MCPExecutionBinder *MCPExecutionBinder MCPExecutionBinder *MCPExecutionBinder
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 SetMCP 桥 Fire 以补全 tool_result。 // ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Setexecute/MCP 桥 Fire 时立即推送 tool_resultADK 晚到经 toolResultSent 去重)
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
DA adk.Agent DA adk.Agent
@@ -196,6 +197,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
pendingByID[tc.ToolCallID] = tc pendingByID[tc.ToolCallID] = tc
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID) pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
} }
markPendingWithMonitor := func(tc toolCallPendingInfo) {
markPending(tc)
beginEinoADKFilesystemToolMonitor(
args.FilesystemMonitorAgent,
args.FilesystemMonitorRecord,
args.MCPExecutionBinder,
tc.ToolCallID,
tc.ToolName,
)
}
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) { popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
pendingMu.Lock() pendingMu.Lock()
defer pendingMu.Unlock() defer pendingMu.Unlock()
@@ -288,6 +299,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
var toolResultSent sync.Map // toolCallID -> struct{}ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文) var toolResultSent sync.Map // toolCallID -> struct{}ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) { tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
// 仅由 ADK schema.Tool 事件调用;MCP/execute 桥在 reduction 前的 ToolInvokeNotify 不得推送 tool_result
// 否则全量输出会先占位并触发 toolResultSent 去重,导致 UI/监控展示与 agent 实际收到的截断正文不一致。
if progress == nil { if progress == nil {
return return
} }
@@ -305,6 +318,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"isError": isErr, "isError": isErr,
"result": content, "result": content,
"resultPreview": preview, "resultPreview": preview,
"agentFacing": true, // 与 reduction 后送入 ChatModel 的正文一致,供前端展示
"conversationId": conversationID, "conversationId": conversationID,
"einoAgent": agentName, "einoAgent": agentName,
"einoRole": einoRoleTag(agentName), "einoRole": einoRoleTag(agentName),
@@ -331,7 +345,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
toolCallID = tid toolCallID = tid
} }
recordPendingExecuteStdoutDup(toolName, content, isErr) recordPendingExecuteStdoutDup(toolName, content, isErr)
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr) recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, args.MCPExecutionBinder, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil { if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" { if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content) args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
@@ -339,12 +353,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data) progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
} }
if args.ToolInvokeNotify != nil {
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
removePendingByID(strings.TrimSpace(toolCallID))
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
})
}
if args.EinoCallbacks != nil { if args.EinoCallbacks != nil {
ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{ ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{
@@ -539,6 +547,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
return true, nil return true, nil
} }
// 仅在退避重试后真正收到数据/完成一步时清零,避免重启后首个无错 ADK 事件误把计数打回 0。
confirmTransientRetryRecovery := func() {
if transientRetrier.attempt() > 0 {
transientRetrier.reset()
}
}
takePartial := func(runErr error) (*RunResult, error) { takePartial := func(runErr error) (*RunResult, error) {
if len(runAccumulatedMsgs) <= baseAccumulatedCount { if len(runAccumulatedMsgs) <= baseAccumulatedCount {
return nil, runErr return nil, runErr
@@ -551,10 +566,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
for { for {
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中" // iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending
select { ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
case <-ctx.Done(): if iterCtxErr != nil {
flushAllPendingAsFailed(ctx.Err()) flushAllPendingAsFailed(iterCtxErr)
if progress != nil { if progress != nil {
if isInterruptContinue(ctx) { if isInterruptContinue(ctx) {
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{ progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
@@ -563,17 +578,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"kind": "interrupt_continue", "kind": "interrupt_continue",
}) })
} else { } else {
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{ progress("error", iterCtxErr.Error(), map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"source": "eino", "source": "eino",
}) })
} }
} }
return takePartial(ctx.Err()) return takePartial(iterCtxErr)
default:
} }
ev, ok := iter.Next()
if !ok { if !ok {
// iter 结束并不总是“正常完成”: // iter 结束并不总是“正常完成”:
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。 // 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
@@ -627,8 +639,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if restarted { if restarted {
continue continue
} }
} else {
transientRetrier.reset()
} }
if ev.AgentName != "" && progress != nil { if ev.AgentName != "" && progress != nil {
iterEinoAgent := orchestratorName iterEinoAgent := orchestratorName
@@ -691,34 +701,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool { if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
toolName := strings.TrimSpace(mv.ToolName) toolName := strings.TrimSpace(mv.ToolName)
var toolBuf strings.Builder content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
streamToolCallID := "" isErr := einoToolResultIsError(toolName, content)
var toolStreamRecvErr error content = einoToolResultBody(content)
for {
chunk, rerr := mv.MessageStream.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
toolStreamRecvErr = rerr
break
}
if chunk == nil {
continue
}
if chunk.Content != "" {
toolBuf.WriteString(chunk.Content)
}
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
streamToolCallID = tid
}
}
content := toolBuf.String()
isErr := false
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
isErr = true
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
if streamToolCallID != "" { if streamToolCallID != "" {
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)} opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...)) runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
@@ -730,6 +715,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
zap.String("agent", ev.AgentName), zap.String("agent", ev.AgentName),
zap.String("tool", toolName)) zap.String("tool", toolName))
} }
if toolStreamRecvErr == nil {
confirmTransientRetryRecovery()
}
continue continue
} }
@@ -977,7 +965,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 { if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged}) lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
} }
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending) tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。 // 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 { if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls)) runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
@@ -1001,6 +989,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if restarted { if restarted {
continue continue
} }
} else {
confirmTransientRetryRecovery()
} }
continue continue
} }
@@ -1010,7 +1000,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue continue
} }
runAccumulatedMsgs = append(runAccumulatedMsgs, msg) runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending) tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
if mv.Role == schema.Assistant { if mv.Role == schema.Assistant {
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" { if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
@@ -1085,15 +1075,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
content := msg.Content content := msg.Content
isErr := false isErr := einoToolResultIsError(toolName, content)
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) { content = einoToolResultBody(content)
isErr = true
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
toolCallID := strings.TrimSpace(msg.ToolCallID) toolCallID := strings.TrimSpace(msg.ToolCallID)
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName) tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
} }
confirmTransientRetryRecovery()
} }
mcpIDsMu.Lock() mcpIDsMu.Lock()
@@ -1121,17 +1109,119 @@ func einoPartialRunLastOutputHint() string {
"[Run ended abnormally; continue from the trace above without repeating completed steps.]" "[Run ended abnormally; continue from the trace above without repeating completed steps.]"
} }
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本 // friendlyEinoExecuteInvokeTail 将 Eino execute 超时/中断/流异常转为简短提示
// 命令非零退出(ExecuteExitError)已有 exec 对齐的正文,不再追加「执行未正常结束」。
func friendlyEinoExecuteInvokeTail(invokeErr error) string { func friendlyEinoExecuteInvokeTail(invokeErr error) string {
if invokeErr == nil { if invokeErr == nil {
return "" return ""
} }
var exitErr *ExecuteExitError
if errors.As(invokeErr, &exitErr) {
return ""
}
if errors.Is(invokeErr, context.DeadlineExceeded) { if errors.Is(invokeErr, context.DeadlineExceeded) {
return einoExecuteTimeoutUserHint() return einoExecuteTimeoutUserHint()
} }
if errors.Is(invokeErr, context.Canceled) {
return ""
}
if strings.Contains(invokeErr.Error(), "shell inactivity timeout") {
return ""
}
return "[执行未正常结束] " + invokeErr.Error() return "[执行未正常结束] " + invokeErr.Error()
} }
// einoToolResultIsError 统一判断 Eino 工具结果是否应标记为错误(与 MCP exec 的 IsError 对齐)。
func einoToolResultIsError(toolName, content string) bool {
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
return true
}
if strings.TrimSpace(toolName) == "execute" && security.IsCommandFailureResult(content) {
return true
}
return false
}
// einoToolResultBody 去掉工具错误前缀,返回展示/持久化正文。
func einoToolResultBody(content string) string {
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
return strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
return content
}
// nextAgentEventWithContext 在 ctx 取消时不再无限阻塞于 iter.Next()(工具执行/模型推理期间常见)。
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
if iter == nil {
return nil, false, nil
}
type nextRes struct {
ev *adk.AgentEvent
ok bool
}
ch := make(chan nextRes, 1)
go func() {
e, o := iter.Next()
ch <- nextRes{e, o}
}()
select {
case <-ctx.Done():
return nil, false, ctx.Err()
case res := <-ch:
return res.ev, res.ok, nil
}
}
// recvSchemaMessageStream 消费 ADK Tool 流式结果;ctx 取消时立即返回,避免 amass 等无输出时永久阻塞。
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
if stream == nil {
return "", "", nil
}
type streamMsg struct {
chunk *schema.Message
err error
}
recvCh := make(chan streamMsg, 8)
go func() {
defer close(recvCh)
for {
ch, rerr := stream.Recv()
recvCh <- streamMsg{chunk: ch, err: rerr}
if rerr != nil {
return
}
}
}()
var buf strings.Builder
for {
select {
case <-ctx.Done():
return buf.String(), toolCallID, ctx.Err()
case sm, open := <-recvCh:
if !open {
return buf.String(), toolCallID, nil
}
rerr := sm.err
if errors.Is(rerr, io.EOF) {
return buf.String(), toolCallID, nil
}
if rerr != nil {
return buf.String(), toolCallID, rerr
}
chunk := sm.chunk
if chunk == nil {
continue
}
if chunk.Content != "" {
buf.WriteString(chunk.Content)
}
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
toolCallID = tid
}
}
}
}
func buildEinoRunResultFromAccumulated( func buildEinoRunResultFromAccumulated(
orchMode string, orchMode string,
runAccumulatedMsgs []adk.Message, runAccumulatedMsgs []adk.Message,
@@ -0,0 +1,74 @@
package multiagent
import (
"context"
"errors"
"io"
"testing"
"time"
"github.com/cloudwego/eino/schema"
)
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
_ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil)
sw.Close()
content, tid, err := recvSchemaMessageStream(context.Background(), sr)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if content != "hello" {
t.Fatalf("content=%q want hello", content)
}
if tid != "tc-1" {
t.Fatalf("toolCallID=%q want tc-1", tid)
}
}
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
t.Cleanup(func() { sw.Close() })
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(30 * time.Millisecond)
cancel()
}()
content, _, err := recvSchemaMessageStream(ctx, sr)
if !errors.Is(err, context.Canceled) {
t.Fatalf("want context.Canceled, got %v content=%q", err, content)
}
}
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
want := errors.New("stream broken")
_ = sw.Send(nil, want)
sw.Close()
_, _, err := recvSchemaMessageStream(context.Background(), sr)
if !errors.Is(err, want) {
t.Fatalf("want %v, got %v", want, err)
}
}
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
content, tid, err := recvSchemaMessageStream(context.Background(), nil)
if err != nil || content != "" || tid != "" {
t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
}
}
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
_ = sw.Send(nil, io.EOF)
sw.Close()
_, _, err := recvSchemaMessageStream(context.Background(), sr)
if err != nil {
t.Fatalf("EOF should not surface as error, got %v", err)
}
}
@@ -0,0 +1,59 @@
package multiagent
import (
"strings"
"time"
"cyberstrike-ai/internal/config"
)
const defaultEmptyResponseContinueMaxAttempts = 5
// IsEinoEmptyResponseResult 判断 Run 是否以「未捕获助手正文」占位结束(非真实用户可见回复)。
func IsEinoEmptyResponseResult(result *RunResult) bool {
if result == nil {
return false
}
return isEinoEmptyResponseText(result.Response)
}
func isEinoEmptyResponseText(s string) bool {
s = strings.TrimSpace(s)
if s == "" {
return false
}
return strings.Contains(s, "no assistant text was captured") ||
strings.Contains(s, "未捕获到助手文本输出")
}
// HasEinoResumeTrace 轨迹非空,续跑才有上下文可恢复。
func HasEinoResumeTrace(result *RunResult) bool {
if result == nil {
return false
}
s := strings.TrimSpace(result.LastAgentTraceInput)
return s != "" && s != "[]" && s != "null"
}
// EmptyResponseContinueMaxAttemptsFromConfig 无助手正文时 Handler 层退避续跑上限;0=默认 5。
func EmptyResponseContinueMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
if mw != nil && mw.EmptyResponseContinueMaxAttempts > 0 {
return mw.EmptyResponseContinueMaxAttempts
}
return defaultEmptyResponseContinueMaxAttempts
}
// EmptyResponseContinueBackoff 与 run_retry 相同指数退避(2s, 4s, 8s… capped)。
func EmptyResponseContinueBackoff(attempt int, mw *config.MultiAgentEinoMiddlewareConfig) time.Duration {
maxBackoff := defaultEinoRunRetryMaxBackoff
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
}
return einoTransientRetryBackoff(attempt, maxBackoff)
}
// FormatEmptyResponseContinueUserMessage 系统自动续跑时注入的 user 轮次(不写入 messages 表气泡)。
func FormatEmptyResponseContinueUserMessage() string {
return strings.TrimSpace(`系统自动续跑 / Auto resume
上一轮 Eino 会话未产出可见助手正文可能流式中断或仅完成工具调用请基于已有轨迹与工具结果继续推进并给出阶段性总结勿重复已完成步骤`)
}
@@ -0,0 +1,38 @@
package multiagent
import "testing"
func TestIsEinoEmptyResponseResult(t *testing.T) {
empty := &RunResult{
Response: "(Eino ADK single-agent session completed but no assistant text was captured. Check process details or logs.) " +
"Eino ADK 单代理会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
}
if !IsEinoEmptyResponseResult(empty) {
t.Fatal("expected empty placeholder response")
}
ok := &RunResult{Response: "扫描完成,发现 2 个开放端口。"}
if IsEinoEmptyResponseResult(ok) {
t.Fatalf("expected real response, got placeholder match")
}
if IsEinoEmptyResponseResult(nil) {
t.Fatal("nil result should be false")
}
}
func TestHasEinoResumeTrace(t *testing.T) {
if HasEinoResumeTrace(nil) {
t.Fatal("nil")
}
if HasEinoResumeTrace(&RunResult{LastAgentTraceInput: "[]"}) {
t.Fatal("enable resume on empty trace")
}
if !HasEinoResumeTrace(&RunResult{LastAgentTraceInput: `[{"role":"user","content":"hi"}]`}) {
t.Fatal("expected resume trace")
}
}
func TestEmptyResponseContinueMaxAttemptsFromConfig(t *testing.T) {
if got := EmptyResponseContinueMaxAttemptsFromConfig(nil); got != defaultEmptyResponseContinueMaxAttempts {
t.Fatalf("default: got %d want %d", got, defaultEmptyResponseContinueMaxAttempts)
}
}
@@ -0,0 +1,114 @@
package multiagent
import (
"context"
"errors"
"io"
"strings"
"testing"
"cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/security"
"github.com/cloudwego/eino/adk/filesystem"
"github.com/cloudwego/eino/schema"
)
type mockStreamingShellExitFail struct {
output string
code int
}
func (m *mockStreamingShellExitFail) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
go func() {
defer outW.Close()
if m.output != "" {
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
}
code := m.code
_ = outW.Send(&filesystem.ExecuteResponse{ExitCode: &code}, nil)
}()
return outR, nil
}
func TestEinoStreamingShellWrap_CommandFailureFormat(t *testing.T) {
inner := &mockStreamingShellExitFail{
output: "sudo: a password is required\n",
code: 1,
}
notify := einomcp.NewToolInvokeNotifyHolder()
var firedBody string
var firedSuccess bool
var firedErr error
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
firedBody = content
firedSuccess = success
firedErr = invokeErr
})
wrap := &einoStreamingShellWrap{inner: inner, invokeNotify: notify}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var stream strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("recv: %v", rerr)
}
if resp != nil {
stream.WriteString(resp.Output)
}
}
if firedSuccess {
t.Fatal("expected success=false")
}
var exitErr *ExecuteExitError
if !errors.As(firedErr, &exitErr) || exitErr.Code != 1 {
t.Fatalf("expected ExecuteExitError code 1, got %v", firedErr)
}
if !strings.HasPrefix(firedBody, einomcp.ToolErrorPrefix) {
t.Fatalf("missing tool error prefix: %q", firedBody)
}
body := strings.TrimPrefix(firedBody, einomcp.ToolErrorPrefix)
if body != security.FormatCommandFailureResult(1, "sudo: a password is required\n") {
t.Fatalf("fire body = %q", body)
}
if !strings.Contains(stream.String(), "sudo:") {
t.Fatalf("stream missing sudo output: %q", stream.String())
}
if strings.Contains(stream.String(), "command exited with non-zero") {
t.Fatalf("stream has legacy noise: %q", stream.String())
}
if strings.Contains(stream.String(), "执行未正常结束") {
t.Fatalf("stream has abnormal tail: %q", stream.String())
}
if !security.IsCommandFailureResult(stream.String()) {
t.Fatalf("stream missing failure status line: %q", stream.String())
}
if tail := friendlyEinoExecuteInvokeTail(firedErr); tail != "" {
t.Fatalf("unexpected invoke tail: %q", tail)
}
if !einoToolResultIsError("execute", firedBody) {
t.Fatal("expected isError for execute failure")
}
}
func TestFriendlyEinoExecuteInvokeTail(t *testing.T) {
if friendlyEinoExecuteInvokeTail(&ExecuteExitError{Code: 1}) != "" {
t.Fatal("exit error should not get abnormal tail")
}
if !strings.Contains(friendlyEinoExecuteInvokeTail(context.DeadlineExceeded), "Timed out") {
t.Fatal("deadline should get timeout hint")
}
if friendlyEinoExecuteInvokeTail(errors.New("broken pipe")) == "" {
t.Fatal("unexpected error should get tail")
}
}
+22 -7
View File
@@ -7,11 +7,25 @@ import (
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
) )
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId) // newEinoExecuteMonitorCallbacks 在 Eino filesystem execute 开始/结束时写入 MCP 监控库并 recorder(executionId)
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片 // 与 CallTool 路径一致,使监控页能展示「执行中」状态
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) { func newEinoExecuteMonitorCallbacks(ag *agent.Agent, recorder einomcp.ExecutionRecorder) (
return func(toolCallID, command, stdout string, success bool, invokeErr error) { begin func(toolCallID, command string) string,
if ag == nil || recorder == nil { finish func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
) {
begin = func(toolCallID, command string) string {
if ag == nil {
return ""
}
args := map[string]interface{}{"command": command}
id := ag.BeginLocalToolExecution("execute", args)
if id != "" && recorder != nil {
recorder(id, toolCallID)
}
return id
}
finish = func(executionID, toolCallID, command, stdout string, success bool, invokeErr error) {
if ag == nil {
return return
} }
var err error var err error
@@ -23,9 +37,10 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
} }
} }
args := map[string]interface{}{"command": command} args := map[string]interface{}{"command": command}
id := ag.RecordLocalToolExecution("execute", args, stdout, err) id := ag.FinishLocalToolExecution(executionID, "execute", args, stdout, err)
if id != "" { if id != "" && recorder != nil && executionID == "" {
recorder(id, toolCallID) recorder(id, toolCallID)
} }
} }
return begin, finish
} }
@@ -51,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。 // 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
// //
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire // 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重) // run loop 收到 Fire 后立即推送 tool_resulttoolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」
// //
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire // 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。 // 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
@@ -63,8 +63,11 @@ type einoStreamingShellWrap struct {
outputChunk func(toolName, toolCallID, chunk string) outputChunk func(toolName, toolCallID, chunk string)
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。 // toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
toolTimeoutMinutes int toolTimeoutMinutes int
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致 // shellNoOutputTimeoutSec:无任何输出时的空闲秒数;0=关闭
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error) shellNoOutputTimeoutSec int
// beginMonitor 在 execute 开始时写入 running 状态;finishMonitor 在流结束后更新为 completed/failed。
beginMonitor func(toolCallID, command string) string
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error)
} }
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
@@ -76,15 +79,26 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
} }
req := *input req := *input
userCmd := strings.TrimSpace(req.Command) userCmd := strings.TrimSpace(req.Command)
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
agentTag := strings.TrimSpace(w.einoAgentName)
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround { if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
req.RunInBackendGround = true req.RunInBackendGround = true
} }
req.Command = prependPythonUnbufferedEnv(req.Command) req.Command = prependPythonUnbufferedEnv(req.Command)
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
agentTag := strings.TrimSpace(w.einoAgentName)
convID := mcp.MCPConversationIDFromContext(ctx) convID := mcp.MCPConversationIDFromContext(ctx)
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx) execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
var monitorExecID string
if w.beginMonitor != nil {
monitorExecID = w.beginMonitor(tid, userCmd)
}
if monitorExecID != "" && convID != "" {
if toolReg := mcp.ToolRunRegistryFromContext(ctx); toolReg != nil {
toolReg.RegisterRunningTool(convID, monitorExecID)
}
}
toolRunReg := mcp.ToolRunRegistryFromContext(ctx)
execCtx, execCancel := context.WithCancel(ctx) execCtx, execCancel := context.WithCancel(ctx)
var timeoutCancel context.CancelFunc var timeoutCancel context.CancelFunc
if w.toolTimeoutMinutes > 0 { if w.toolTimeoutMinutes > 0 {
@@ -104,23 +118,23 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
} }
if einoExecuteRecvErrIsToolTimeout(err, execCtx) { if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n" hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
if w.recordMonitor != nil { if w.finishMonitor != nil {
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded) w.finishMonitor(monitorExecID, tid, userCmd, hint, false, context.DeadlineExceeded)
} }
if w.invokeNotify != nil && tid != "" { if w.invokeNotify != nil && tid != "" {
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded) w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
} }
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
} }
if w.recordMonitor != nil { if w.finishMonitor != nil {
w.recordMonitor(tid, userCmd, "", false, err) w.finishMonitor(monitorExecID, tid, userCmd, "", false, err)
} }
if w.invokeNotify != nil && tid != "" { if w.invokeNotify != nil && tid != "" {
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err) w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
} }
return nil, err return nil, err
} }
if sr == nil || w.invokeNotify == nil { if sr == nil {
if timeoutCancel != nil { if timeoutCancel != nil {
timeoutCancel() timeoutCancel()
} }
@@ -132,7 +146,7 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32) outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) { go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry, toolReg mcp.ToolRunRegistry, execID string, toolCallID string, noOutputSec int) {
var innerCloseOnce sync.Once var innerCloseOnce sync.Once
closeInner := func() { closeInner := func() {
innerCloseOnce.Do(func() { inner.Close() }) innerCloseOnce.Do(func() { inner.Close() })
@@ -147,6 +161,9 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
if reg != nil && conversationID != "" { if reg != nil && conversationID != "" {
defer reg.UnregisterActiveEinoExecute(conversationID) defer reg.UnregisterActiveEinoExecute(conversationID)
} }
if toolReg != nil && conversationID != "" && execID != "" {
defer toolReg.UnregisterRunningTool(conversationID, execID)
}
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。 // ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
stopWatch := make(chan struct{}) stopWatch := make(chan struct{})
@@ -165,50 +182,103 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
exitCode := 0 exitCode := 0
hasExitCode := false hasExitCode := false
idleWatch := security.NewShellInactivityWatch(noOutputSec)
if idleWatch != nil {
defer idleWatch.Stop()
}
type execRecvMsg struct {
resp *filesystem.ExecuteResponse
err error
}
recvCh := make(chan execRecvMsg, 1)
go func() {
for {
resp, rerr := inner.Recv()
recvCh <- execRecvMsg{resp: resp, err: rerr}
if rerr != nil {
return
}
}
}()
fireInactivityTimeout := func() {
success = false
invokeErr = fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
msg := security.ShellNoOutputTimeoutMessage(idleWatch.Sec)
_ = outW.Send(&filesystem.ExecuteResponse{Output: msg}, nil)
sb.WriteString(msg)
if w.outputChunk != nil && toolCallID != "" {
w.outputChunk("execute", toolCallID, msg)
}
if cancel != nil {
cancel()
}
closeInner()
}
recvLoop:
for { for {
resp, rerr := inner.Recv() var idleCh <-chan struct{}
if errors.Is(rerr, io.EOF) { if idleWatch != nil {
break idleCh = idleWatch.Expired
} }
if rerr != nil { select {
success = false case <-idleCh:
invokeErr = rerr fireInactivityTimeout()
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。 break recvLoop
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) { case msg := <-recvCh:
invokeErr = context.DeadlineExceeded rerr := msg.err
break resp := msg.resp
if errors.Is(rerr, io.EOF) {
break recvLoop
} }
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) { if rerr != nil {
invokeErr = context.Canceled
break
}
_ = outW.Send(nil, rerr)
break
}
if resp != nil {
if resp.ExitCode != nil {
hasExitCode = true
exitCode = *resp.ExitCode
}
var appended string
if resp.Output != "" {
sb.WriteString(resp.Output)
appended = resp.Output
}
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
w.outputChunk("execute", tid, appended)
}
if outW.Send(resp, nil) {
success = false success = false
invokeErr = fmt.Errorf("execute stream closed by consumer") invokeErr = rerr
break if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
invokeErr = context.DeadlineExceeded
break recvLoop
}
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
invokeErr = context.Canceled
break recvLoop
}
_ = outW.Send(nil, rerr)
break recvLoop
}
if resp != nil {
if resp.ExitCode != nil {
hasExitCode = true
exitCode = *resp.ExitCode
continue
}
var appended string
if resp.Output != "" {
if security.IsLegacyShellExitNoise(resp.Output) {
continue
}
if idleWatch != nil {
idleWatch.Bump()
}
sb.WriteString(resp.Output)
appended = resp.Output
}
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
w.outputChunk("execute", toolCallID, appended)
}
if outW.Send(resp, nil) {
success = false
invokeErr = fmt.Errorf("execute stream closed by consumer")
break recvLoop
}
} }
} }
} }
if success && hasExitCode && exitCode != 0 { if success && hasExitCode && exitCode != 0 {
success = false success = false
invokeErr = fmt.Errorf("execute exited with code %d", exitCode) invokeErr = &ExecuteExitError{Code: exitCode}
} }
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。 // WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。 // 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
@@ -248,12 +318,24 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil) _ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
} }
} }
if w.recordMonitor != nil { rawOutput := sb.String()
w.recordMonitor(tid, command, sb.String(), success, invokeErr) fireBody := rawOutput
if !success && hasExitCode && exitCode != 0 {
statusLine := security.ExecuteFailureStatusLine(exitCode)
if !strings.Contains(rawOutput, "命令执行失败:") {
_ = outW.Send(&filesystem.ExecuteResponse{Output: statusLine}, nil)
sb.WriteString(statusLine)
}
fireBody = einomcp.ToolErrorPrefix + security.FormatCommandFailureResult(exitCode, rawOutput)
}
if w.finishMonitor != nil {
w.finishMonitor(execID, toolCallID, command, sb.String(), success, invokeErr)
}
if w.invokeNotify != nil {
w.invokeNotify.Fire(toolCallID, "execute", agentTag, success, fireBody, invokeErr)
} }
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
outW.Close() outW.Close()
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg) }(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg, toolRunReg, monitorExecID, tid, w.shellNoOutputTimeoutSec)
return outR, nil return outR, nil
} }
@@ -19,9 +19,15 @@ type mockStreamingShell struct {
immediateErr error immediateErr error
recvErr error recvErr error
output string output string
called bool
lastCommand string
} }
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
m.called = true
if input != nil {
m.lastCommand = input.Command
}
if m.immediateErr != nil { if m.immediateErr != nil {
return nil, m.immediateErr return nil, m.immediateErr
} }
@@ -38,6 +44,129 @@ func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesy
return outR, nil return outR, nil
} }
func TestEinoStreamingShellWrap_PreparesNonInteractiveCommand(t *testing.T) {
inner := &mockStreamingShell{output: "ok\n"}
wrap := &einoStreamingShellWrap{inner: inner}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo ok"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
for {
_, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("recv: %v", rerr)
}
}
if !strings.Contains(inner.lastCommand, "PYTHONUNBUFFERED=1") {
t.Fatalf("missing python unbuffer in inner command: %q", inner.lastCommand)
}
}
func TestEinoStreamingShellWrap_NoOutputTimeout(t *testing.T) {
inner := &mockStreamingShellHanging{}
notify := einomcp.NewToolInvokeNotifyHolder()
var fired string
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
fired = content
})
wrap := &einoStreamingShellWrap{
inner: inner,
invokeNotify: notify,
shellNoOutputTimeoutSec: 1,
}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("recv: %v", rerr)
}
if resp != nil {
got.WriteString(resp.Output)
}
}
if !inner.called {
t.Fatal("inner shell should run (no command blacklist)")
}
out := got.String()
if !strings.Contains(out, "没有新的输出") && !strings.Contains(out, "no new output") {
t.Fatalf("expected inactivity timeout message, got: %q notify=%q", out, fired)
}
}
type mockStreamingShellPartialThenHang struct {
called bool
}
func (m *mockStreamingShellPartialThenHang) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
m.called = true
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
go func() {
_ = outW.Send(&filesystem.ExecuteResponse{Output: "[sudo] password:\n"}, nil)
<-ctx.Done()
outW.Close()
}()
return outR, nil
}
func TestEinoStreamingShellWrap_InactivityAfterPartialOutput(t *testing.T) {
inner := &mockStreamingShellPartialThenHang{}
wrap := &einoStreamingShellWrap{
inner: inner,
shellNoOutputTimeoutSec: 1,
}
start := time.Now()
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("recv: %v", rerr)
}
if resp != nil {
got.WriteString(resp.Output)
}
}
if time.Since(start) > 5*time.Second {
t.Fatalf("expected inactivity timeout ~1s, took %v", time.Since(start))
}
if !strings.Contains(got.String(), "没有新的输出") && !strings.Contains(got.String(), "no new output") {
t.Fatalf("expected inactivity message, got: %q", got.String())
}
}
type mockStreamingShellHanging struct {
called bool
}
func (m *mockStreamingShellHanging) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
m.called = true
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
go func() {
<-ctx.Done()
outW.Close()
}()
return outR, nil
}
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) { func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel() defer cancel()
@@ -63,10 +63,43 @@ func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName
return map[string]interface{}{} return map[string]interface{}{}
} }
// beginEinoADKFilesystemToolMonitor 在 Eino ADK filesystem 工具开始调用时写入 running 状态。
func beginEinoADKFilesystemToolMonitor(
ag *agent.Agent,
rec einomcp.ExecutionRecorder,
binder *MCPExecutionBinder,
toolCallID, toolName string,
) {
if ag == nil || rec == nil {
return
}
name := strings.TrimSpace(toolName)
if name == "" || strings.EqualFold(name, "execute") {
return
}
if !isBuiltinEinoADKFilesystemToolName(name) {
return
}
tid := strings.TrimSpace(toolCallID)
if tid == "" {
return
}
storedName := "eino_fs::" + strings.ToLower(name)
id := ag.BeginLocalToolExecution(storedName, map[string]interface{}{})
if id == "" {
return
}
rec(id, tid)
if binder != nil {
binder.Bind(tid, id)
}
}
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。 // recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
func recordEinoADKFilesystemToolMonitor( func recordEinoADKFilesystemToolMonitor(
ag *agent.Agent, ag *agent.Agent,
rec einomcp.ExecutionRecorder, rec einomcp.ExecutionRecorder,
binder *MCPExecutionBinder,
toolName string, toolName string,
toolCallID string, toolCallID string,
msgs []adk.Message, msgs []adk.Message,
@@ -94,8 +127,12 @@ func recordEinoADKFilesystemToolMonitor(
invErr = errors.New(t) invErr = errors.New(t)
} }
} }
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr) execID := ""
if id != "" { if binder != nil {
execID = binder.ExecutionID(toolCallID)
}
id := ag.FinishLocalToolExecution(execID, storedName, args, resultText, invErr)
if id != "" && execID == "" {
rec(id, toolCallID) rec(id, toolCallID)
} }
} }
+3 -2
View File
@@ -81,7 +81,7 @@ func RunEinoSingleChatModelAgent(
} }
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder() toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder) einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
mainDefs := ag.ToolsForRole(roleTools) mainDefs := ag.ToolsForRole(roleTools)
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName) mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
if err != nil { if err != nil {
@@ -136,7 +136,7 @@ func RunEinoSingleChatModelAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil) fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr) return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
} }
@@ -184,6 +184,7 @@ func RunEinoSingleChatModelAgent(
Name: einoSingleAgentName, Name: einoSingleAgentName,
Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.", Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.",
Instruction: ins, Instruction: ins,
GenModelInput: literalInstructionGenModelInput,
Model: mainModel, Model: mainModel,
ToolsConfig: mainToolsCfg, ToolsConfig: mainToolsCfg,
MaxIterations: maxIter, MaxIterations: maxIter,
+27 -7
View File
@@ -9,6 +9,7 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/security"
localbk "github.com/cloudwego/eino-ext/adk/backend/local" localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -81,8 +82,10 @@ func subAgentFilesystemMiddleware(
loc *localbk.Local, loc *localbk.Local,
invokeNotify *einomcp.ToolInvokeNotifyHolder, invokeNotify *einomcp.ToolInvokeNotifyHolder,
einoAgentName string, einoAgentName string,
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error), beginMonitor func(toolCallID, command string) string,
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
toolTimeoutMinutes int, toolTimeoutMinutes int,
shellNoOutputTimeoutSec int,
outputChunk func(toolName, toolCallID, chunk string), outputChunk func(toolName, toolCallID, chunk string),
) (adk.ChatModelAgentMiddleware, error) { ) (adk.ChatModelAgentMiddleware, error) {
if loc == nil { if loc == nil {
@@ -91,12 +94,14 @@ func subAgentFilesystemMiddleware(
return filesystem.New(ctx, &filesystem.MiddlewareConfig{ return filesystem.New(ctx, &filesystem.MiddlewareConfig{
Backend: loc, Backend: loc,
StreamingShell: &einoStreamingShellWrap{ StreamingShell: &einoStreamingShellWrap{
inner: loc, inner: security.NewEinoStreamingShell(),
invokeNotify: invokeNotify, invokeNotify: invokeNotify,
einoAgentName: strings.TrimSpace(einoAgentName), einoAgentName: strings.TrimSpace(einoAgentName),
outputChunk: outputChunk, outputChunk: outputChunk,
recordMonitor: recordMonitor, beginMonitor: beginMonitor,
toolTimeoutMinutes: toolTimeoutMinutes, finishMonitor: finishMonitor,
toolTimeoutMinutes: toolTimeoutMinutes,
shellNoOutputTimeoutSec: shellNoOutputTimeoutSec,
}, },
}) })
} }
@@ -108,3 +113,18 @@ func agentToolTimeoutMinutes(cfg *config.Config) int {
} }
return cfg.Agent.ToolTimeoutMinutes return cfg.Agent.ToolTimeoutMinutes
} }
// agentShellNoOutputTimeoutSeconds0=默认 300s5 分钟);-1=关闭;>0=自定义秒数。
func agentShellNoOutputTimeoutSeconds(cfg *config.Config) int {
if cfg == nil {
return 300
}
v := cfg.Agent.ShellNoOutputTimeoutSeconds
if v < 0 {
return 0
}
if v == 0 {
return 300
}
return v
}
+31
View File
@@ -150,6 +150,7 @@ func newEinoSummarizationMiddleware(
} }
if appCfg != nil { if appCfg != nil {
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger) out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
out = refreshUserVerbatimAnchorInMessages(out, db, conversationID, appCfg.MultiAgent.UserVerbatimAnchorMaxRunesEffective(), logger)
} }
return out, nil return out, nil
}, },
@@ -413,6 +414,36 @@ func writeSummarizationTranscript(path string, msgs []adk.Message) error {
return nil return nil
} }
// refreshUserVerbatimAnchorInMessages 压缩后从 messages 表刷新 system 中的用户原文锚点。
func refreshUserVerbatimAnchorInMessages(msgs []adk.Message, db *database.DB, conversationID string, maxRunes int, logger *zap.Logger) []adk.Message {
if maxRunes < 0 || db == nil {
return msgs
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return msgs
}
rows, err := db.GetMessages(conversationID)
if err != nil {
if logger != nil {
logger.Warn("summarization: 刷新用户原文锚点失败",
zap.String("conversationId", conversationID),
zap.Error(err),
)
}
return msgs
}
block := project.BuildUserVerbatimAnchorBlockFromMessages(rows, maxRunes)
if block == "" {
return msgs
}
out := project.RefreshUserVerbatimAnchorInMessages(msgs, block)
if logger != nil {
logger.Info("summarization: 已刷新用户原文锚点", zap.String("conversationId", conversationID))
}
return out
}
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc { func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
tc := agent.NewTikTokenCounter() tc := agent.NewTikTokenCounter()
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) { return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
+2 -25
View File
@@ -1,35 +1,12 @@
package multiagent package multiagent
import ( import (
"github.com/bytedance/sonic" copenai "cyberstrike-ai/internal/openai"
) )
// stripReasoningFromSummarizationPayload removes thinking / reasoning fields from a // stripReasoningFromSummarizationPayload removes thinking / reasoning fields from a
// chat-completions JSON body. Applied only to summarization Generate calls via // chat-completions JSON body. Applied only to summarization Generate calls via
// model.ModelOptions on the shared ChatModel — main-agent requests are unchanged. // model.ModelOptions on the shared ChatModel — main-agent requests are unchanged.
func stripReasoningFromSummarizationPayload(rawBody []byte) ([]byte, error) { func stripReasoningFromSummarizationPayload(rawBody []byte) ([]byte, error) {
var payload map[string]any return copenai.StripReasoningFromChatCompletionBody(rawBody)
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
return rawBody, nil
}
changed := false
for _, key := range []string{
"thinking",
"reasoning_effort",
"output_config",
"reasoning",
} {
if _, ok := payload[key]; ok {
delete(payload, key)
changed = true
}
}
if !changed {
return rawBody, nil
}
out, err := sonic.Marshal(payload)
if err != nil {
return rawBody, err
}
return out, nil
} }
+5 -5
View File
@@ -409,9 +409,9 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
"需要写入请使用 upsert_project_fact。", "需要写入请使用 upsert_project_fact。",
project.FactIndexSectionEndMarker, project.FactIndexSectionEndMarker,
"", "",
"# Skills System", transcriptSkillsSystemMarker,
"**How to Use Skills**", "**如何使用 Skill(技能)(渐进式展示):**",
"Remember: Skills make you more capable", "记住:Skill 让你更加强大和稳定",
}, "\n") }, "\n")
out := sanitizeSystemContentForTranscript(system) out := sanitizeSystemContentForTranscript(system)
@@ -421,7 +421,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") { if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") {
t.Fatalf("static persona should be stripped: %q", out) t.Fatalf("static persona should be stripped: %q", out)
} }
if strings.Contains(out, "# Skills System") || strings.Contains(out, "How to Use Skills") { if strings.Contains(out, transcriptSkillsSystemMarker) || strings.Contains(out, "如何使用 Skill") {
t.Fatalf("skills boilerplate should be stripped: %q", out) t.Fatalf("skills boilerplate should be stripped: %q", out)
} }
if !strings.Contains(out, transcriptStaticSystemOmitNote) { if !strings.Contains(out, transcriptStaticSystemOmitNote) {
@@ -435,7 +435,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) { func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
t.Parallel() t.Parallel()
msgs := []adk.Message{ msgs := []adk.Message{
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n# Skills System\nboiler"), schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n" + transcriptSkillsSystemMarker + "\nboiler"),
schema.UserMessage("hello"), schema.UserMessage("hello"),
schema.AssistantMessage("reply", nil), schema.AssistantMessage("reply", nil),
} }
@@ -20,7 +20,9 @@ const (
transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]" transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]"
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引" transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
transcriptPersonaStartMarker = "你是CyberStrikeAI" transcriptPersonaStartMarker = "你是CyberStrikeAI"
transcriptSkillsSystemMarker = "# Skills System" // ADK LanguageChinese injects skill middleware prompt with this header (see eino adk/middlewares/skill/prompt.go).
transcriptSkillsSystemMarker = "# Skill 系统"
transcriptSkillsSystemMarkerEnglish = "# Skills System"
) )
type transcriptToolCall struct { type transcriptToolCall struct {
@@ -86,13 +88,23 @@ func stripToolNamesIndexFromSystem(s string) string {
} }
func stripSkillsSystemBoilerplate(s string) string { func stripSkillsSystemBoilerplate(s string) string {
idx := strings.Index(s, transcriptSkillsSystemMarker) idx := indexFirstSubstring(s, transcriptSkillsSystemMarker, transcriptSkillsSystemMarkerEnglish)
if idx < 0 { if idx < 0 {
return strings.TrimSpace(s) return strings.TrimSpace(s)
} }
return strings.TrimSpace(s[:idx]) return strings.TrimSpace(s[:idx])
} }
func indexFirstSubstring(s string, markers ...string) int {
first := -1
for _, m := range markers {
if i := strings.Index(s, m); i >= 0 && (first < 0 || i < first) {
first = i
}
}
return first
}
func extractProjectBlackboardSection(s string) string { func extractProjectBlackboardSection(s string) string {
start := strings.Index(s, project.FactIndexSectionStartMarker) start := strings.Index(s, project.FactIndexSectionStartMarker)
if start < 0 { if start < 0 {
@@ -46,6 +46,10 @@ func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, too
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n") sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
sb.WriteString("3) 不要臆造不存在的工具名。\n\n") sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
} }
if s := strings.TrimSpace(injectShellToolGuidance("", names)); s != "" {
sb.WriteString(s)
sb.WriteString("\n\n")
}
if s := strings.TrimSpace(instruction); s != "" { if s := strings.TrimSpace(instruction); s != "" {
sb.WriteString(s) sb.WriteString(s)
} }
+13 -16
View File
@@ -143,7 +143,7 @@ func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts } func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。 // reset 在退避重试后成功推进(流/消息完整接收)时清零计数,使后续临时错误从第 1 次退避重新开始。
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 } func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int { func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
@@ -190,29 +190,26 @@ func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated [
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
} }
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。 // adkMessagesHasUserContent reports whether the conversation tail is already a user turn
// with the given content. Only the last message counts: matching text in an earlier round
// (e.g. user repeats the same prompt after an assistant reply) must not suppress appending
// the new user turn — Claude 4.6+ rejects requests whose final message is assistant.
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool { func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
want = strings.TrimSpace(want) want = strings.TrimSpace(want)
if want == "" { if want == "" {
return true return true
} }
for i := len(msgs) - 1; i >= 0; i-- { if len(msgs) == 0 {
m := msgs[i] return false
if m == nil {
continue
}
if m.Role == schema.User {
return strings.TrimSpace(m.Content) == want
}
if m.Role == schema.Assistant || m.Role == schema.Tool {
continue
}
break
} }
return false last := msgs[len(msgs)-1]
if last == nil || last.Role != schema.User {
return false
}
return strings.TrimSpace(last.Content) == want
} }
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。 // appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当尾部已是相同 user 句)。
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message { func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) { if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
return msgs return msgs
@@ -105,6 +105,32 @@ func TestEinoTransientRunRetrierReset(t *testing.T) {
} }
} }
func TestEinoTransientRunRetrierConsecutiveFailures(t *testing.T) {
t.Parallel()
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
ctx := context.Background()
runErr := errors.New("internal server error")
args := &einoADKRunLoopArgs{}
base := []adk.Message{schema.UserMessage("hi")}
for want := 1; want <= 3; want++ {
restarted, _, _, _, err := r.tryRetry(ctx, runErr, args, base, nil, len(base))
if err != nil {
t.Fatalf("tryRetry attempt %d: %v", want, err)
}
if !restarted {
t.Fatalf("tryRetry attempt %d: want restarted", want)
}
if got := r.attempt(); got != want {
t.Fatalf("after failure %d: attempt=%d, want %d", want, got, want)
}
}
r.reset()
if r.attempt() != 0 {
t.Fatalf("after successful recovery reset: attempt=%d, want 0", r.attempt())
}
}
func TestAppendUserMessageIfNeeded(t *testing.T) { func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Parallel() t.Parallel()
msgs := []adk.Message{schema.UserMessage("old task")} msgs := []adk.Message{schema.UserMessage("old task")}
@@ -117,3 +143,18 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Fatalf("should not duplicate user message: len=%d", len(dup)) t.Fatalf("should not duplicate user message: len=%d", len(dup))
} }
} }
func TestAppendUserMessageIfNeeded_repeatPromptAfterAssistant(t *testing.T) {
t.Parallel()
msgs := []adk.Message{
schema.UserMessage("扫描 example.com"),
schema.AssistantMessage("开始扫描...", nil),
}
out := appendUserMessageIfNeeded(msgs, "扫描 example.com")
if len(out) != 3 {
t.Fatalf("should append new user turn after assistant reply: len=%d", len(out))
}
if out[2].Role != schema.User || out[2].Content != "扫描 example.com" {
t.Fatalf("tail should be repeated user prompt, got role=%s content=%q", out[2].Role, out[2].Content)
}
}
+15
View File
@@ -0,0 +1,15 @@
package multiagent
import "fmt"
// ExecuteExitError 表示 execute 命令非零退出(预期失败,非超时/中断/流异常)。
type ExecuteExitError struct {
Code int
}
func (e *ExecuteExitError) Error() string {
if e == nil {
return "exit status unknown"
}
return fmt.Sprintf("exit status %d", e.Code)
}
+23
View File
@@ -0,0 +1,23 @@
package multiagent
import (
"context"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
// literalInstructionGenModelInput passes Instruction through as a system message without
// FString template formatting. Eino defaultGenModelInput formats instruction whenever
// SessionValues exist; prompts with literal curly braces (project blackboard "{关系边: ...}",
// JSON examples, link syntax) then fail with "could not find key".
//
// Matches eino/adk/prebuilt/deep genModelInput — the supported fix per Eino docs.
func literalInstructionGenModelInput(ctx context.Context, instruction string, input *adk.AgentInput) ([]adk.Message, error) {
msgs := make([]adk.Message, 0, len(input.Messages)+1)
if instruction != "" {
msgs = append(msgs, schema.SystemMessage(instruction))
}
msgs = append(msgs, input.Messages...)
return msgs, nil
}
@@ -0,0 +1,33 @@
package multiagent
import (
"context"
"strings"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func TestLiteralInstructionGenModelInput_PreservesLiteralCurlyBraces(t *testing.T) {
t.Parallel()
instruction := "- [finding/x] summary {关系边: discovered_on←target/dev}\n" +
"如 finding 上 {from:target/*, type:discovered_on}"
msgs, err := literalInstructionGenModelInput(context.Background(), instruction, &adk.AgentInput{
Messages: []adk.Message{schema.UserMessage("继续")},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(msgs) != 2 {
t.Fatalf("expected 2 messages, got %d", len(msgs))
}
if msgs[0].Role != schema.System {
t.Fatalf("first message must be system, got %s", msgs[0].Role)
}
for _, want := range []string{"{关系边:", "{from:target/*, type:discovered_on}"} {
if !strings.Contains(msgs[0].Content, want) {
t.Fatalf("system content missing %q: %q", want, msgs[0].Content)
}
}
}
+3 -5
View File
@@ -3,7 +3,6 @@ package multiagent
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings" "strings"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -75,8 +74,8 @@ func hitlInvokableToolCallMiddleware() compose.InvokableToolMiddleware {
if err != nil { if err != nil {
if IsHumanRejectError(err) { if IsHumanRejectError(err) {
// Human rejection should be a soft tool result so the model can continue iterating. // Human rejection should be a soft tool result so the model can continue iterating.
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.", // tool_search 须保持 JSON,否则 Eino toolsearch 中间件解析历史时会硬崩 ChatModel。
input.Name, strings.TrimSpace(err.Error())) msg := HitlRejectToolResult(input.Name, err.Error())
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END, // transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具, // 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。 // 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
@@ -103,8 +102,7 @@ func hitlStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
edited, err := fn(ctx, input.Name, input.Arguments) edited, err := fn(ctx, input.Name, input.Arguments)
if err != nil { if err != nil {
if IsHumanRejectError(err) { if IsHumanRejectError(err) {
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.", msg := HitlRejectToolResult(input.Name, err.Error())
input.Name, strings.TrimSpace(err.Error()))
hitlClearReturnDirectlyIfTransfer(ctx, input.Name) hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
return &compose.StreamToolOutput{ return &compose.StreamToolOutput{
Result: schema.StreamReaderFromArray([]string{msg}), Result: schema.StreamReaderFromArray([]string{msg}),
@@ -0,0 +1,85 @@
package multiagent
import (
"encoding/json"
"fmt"
"strings"
)
const toolSearchToolName = "tool_search"
// HitlExemptMetaTools 为编排/元工具:不直接执行攻击动作,但会阻塞 agent 控制流。
// tool_search 必须免审批,否则其 HITL 拒绝结果与 Eino toolsearch 中间件不兼容(会硬崩 ChatModel)。
var HitlExemptMetaTools = []string{
toolSearchToolName,
"skill",
"task",
"write_todos",
"transfer_to_agent",
"exit",
"TaskCreate",
"TaskGet",
"TaskUpdate",
"TaskList",
}
// IsToolSearchTool reports whether name is the Eino dynamictool tool_search meta-tool.
func IsToolSearchTool(name string) bool {
return strings.EqualFold(strings.TrimSpace(name), toolSearchToolName)
}
// MergeHitlExemptMetaTools unions configured whitelist with built-in meta-tool exemptions.
func MergeHitlExemptMetaTools(configured []string) []string {
merged := make([]string, 0, len(configured)+len(HitlExemptMetaTools))
seen := make(map[string]struct{}, len(configured)+len(HitlExemptMetaTools))
add := func(name string) {
n := strings.ToLower(strings.TrimSpace(name))
if n == "" {
return
}
if _, ok := seen[n]; ok {
return
}
seen[n] = struct{}{}
merged = append(merged, strings.TrimSpace(name))
}
for _, t := range configured {
add(t)
}
for _, t := range HitlExemptMetaTools {
add(t)
}
return merged
}
type toolSearchHitlRejectPayload struct {
SelectedTools []string `json:"selectedTools"`
HitlRejected bool `json:"_hitlRejected"`
Reason string `json:"reason"`
}
// HitlRejectToolResult returns a tool result body safe for downstream consumers.
// tool_search must stay JSON-shaped so toolsearch.extractSelectedTools does not terminate the graph.
func HitlRejectToolResult(toolName, reason string) string {
reason = strings.TrimSpace(reason)
if !IsToolSearchTool(toolName) {
if reason == "" {
reason = "rejected by reviewer"
}
return fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
strings.TrimSpace(toolName), reason)
}
payload := toolSearchHitlRejectPayload{
SelectedTools: []string{},
HitlRejected: true,
Reason: reason,
}
if payload.Reason == "" {
payload.Reason = "tool_search rejected by reviewer; no dynamic tools unlocked"
}
out, err := json.Marshal(payload)
if err != nil {
return `{"selectedTools":[],"_hitlRejected":true,"reason":"tool_search rejected by reviewer"}`
}
return string(out)
}
@@ -0,0 +1,48 @@
package multiagent
import (
"encoding/json"
"strings"
"testing"
)
func TestHitlRejectToolResult_toolSearchIsJSON(t *testing.T) {
raw := HitlRejectToolResult("tool_search", "rejected by user: timeout")
var payload toolSearchHitlRejectPayload
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(payload.SelectedTools) != 0 {
t.Fatalf("expected empty selectedTools, got %v", payload.SelectedTools)
}
if !payload.HitlRejected {
t.Fatal("expected _hitlRejected true")
}
if !strings.Contains(payload.Reason, "timeout") {
t.Fatalf("reason=%q", payload.Reason)
}
}
func TestHitlRejectToolResult_otherToolKeepsLegacyText(t *testing.T) {
raw := HitlRejectToolResult("nmap", "too risky")
if strings.HasPrefix(raw, "{") {
t.Fatalf("expected legacy text, got %q", raw)
}
if !strings.HasPrefix(raw, "[HITL Reject]") {
t.Fatalf("expected [HITL Reject] prefix, got %q", raw)
}
}
func TestMergeHitlExemptMetaTools_includesToolSearch(t *testing.T) {
merged := MergeHitlExemptMetaTools([]string{"read_file"})
found := false
for _, name := range merged {
if IsToolSearchTool(name) {
found = true
break
}
}
if !found {
t.Fatalf("tool_search missing from %v", merged)
}
}
@@ -6,6 +6,7 @@ import (
"cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/project"
"cyberstrike-ai/internal/projectprompt"
) )
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。 // DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
@@ -122,7 +123,9 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
## 表达 ## 表达
在调用工具或给出计划变更前 25 句中文说明当前决策依据与期望证据形态最终对用户交付结构化结论发现摘要证据风险下一步` 在调用工具或给出计划变更前 25 句中文说明当前决策依据与期望证据形态最终对用户交付结构化结论发现摘要证据风险下一步
` + projectprompt.ShellExecExecuteGuidanceSection()
} }
// DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。 // DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。
+44 -16
View File
@@ -20,6 +20,7 @@ import (
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/project"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
"cyberstrike-ai/internal/security"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai" einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -120,7 +121,7 @@ func RunDeepAgent(
mcpIDs = append(mcpIDs, id) mcpIDs = append(mcpIDs, id)
mcpIDsMu.Unlock() mcpIDsMu.Unlock()
} }
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder) einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。 // 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
snapshotMCPIDs := func() []string { snapshotMCPIDs := func() []string {
@@ -223,7 +224,7 @@ func RunDeepAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil) subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr) return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
} }
@@ -253,10 +254,11 @@ func RunDeepAgent(
) )
} }
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
Name: id, Name: id,
Description: desc, Description: desc,
Instruction: subInstrFinal, Instruction: subInstrFinal,
Model: subModel, GenModelInput: literalInstructionGenModelInput,
Model: subModel,
ToolsConfig: adk.ToolsConfig{ ToolsConfig: adk.ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{
Tools: subToolsForCfg, Tools: subToolsForCfg,
@@ -358,19 +360,28 @@ func RunDeepAgent(
if einoLoc != nil && einoFSTools { if einoLoc != nil && einoFSTools {
deepBackend = einoLoc deepBackend = einoLoc
deepShell = &einoStreamingShellWrap{ deepShell = &einoStreamingShellWrap{
inner: einoLoc, inner: security.NewEinoStreamingShell(),
invokeNotify: toolInvokeNotify, invokeNotify: toolInvokeNotify,
einoAgentName: orchestratorName, einoAgentName: orchestratorName,
outputChunk: nil, outputChunk: nil,
recordMonitor: einoExecMonitor, beginMonitor: einoExecBegin,
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg), finishMonitor: einoExecFinish,
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
shellNoOutputTimeoutSec: agentShellNoOutputTimeoutSeconds(appCfg),
} }
} }
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。 // noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()} deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
taskEnrichExtra := systemPromptExtra var taskBlackboardSupplement string
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil { if appCfg.Project.Enabled && db != nil {
if pid := strings.TrimSpace(projectID); pid != "" {
if block, err := project.BuildFactIndexBlock(db, pid, appCfg.Project); err == nil {
taskBlackboardSupplement = strings.TrimSpace(block)
}
}
}
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunesEffective(), taskBlackboardSupplement); mw != nil {
deepHandlers = append(deepHandlers, mw) deepHandlers = append(deepHandlers, mw)
} }
if len(mainOrchestratorPre) > 0 { if len(mainOrchestratorPre) > 0 {
@@ -421,6 +432,22 @@ func RunDeepAgent(
var da adk.Agent var da adk.Agent
switch orchMode { switch orchMode {
case "plan_execute": case "plan_execute":
plannerModelCfg := &einoopenai.ChatModelConfig{
APIKey: appCfg.OpenAI.APIKey,
BaseURL: strings.TrimSuffix(appCfg.OpenAI.BaseURL, "/"),
Model: appCfg.OpenAI.Model,
HTTPClient: httpClient,
}
reasoning.ApplyPlanExecutePlannerModelConfig(plannerModelCfg, &appCfg.OpenAI)
peMainModel, perr := einoopenai.NewChatModel(ctx, plannerModelCfg)
if perr != nil {
return nil, fmt.Errorf("plan_execute 规划模型: %w", perr)
}
if logger != nil {
logger.Info("plan_execute: planner/replanner 使用无 reasoning 的独立 ChatModelToolChoiceForced 兼容)",
zap.String("model", appCfg.OpenAI.Model),
)
}
execModel, perr := einoopenai.NewChatModel(ctx, baseModelCfg) execModel, perr := einoopenai.NewChatModel(ctx, baseModelCfg)
if perr != nil { if perr != nil {
return nil, fmt.Errorf("plan_execute 执行器模型: %w", perr) return nil, fmt.Errorf("plan_execute 执行器模型: %w", perr)
@@ -428,13 +455,13 @@ func RunDeepAgent(
// 构建 filesystem 中间件(与 Deep sub-agent 一致) // 构建 filesystem 中间件(与 Deep sub-agent 一致)
var peFsMw adk.ChatModelAgentMiddleware var peFsMw adk.ChatModelAgentMiddleware
if einoSkillMW != nil && einoFSTools && einoLoc != nil { if einoSkillMW != nil && einoFSTools && einoLoc != nil {
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil) peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err) return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
} }
} }
peRoot, perr := NewPlanExecuteRoot(ctx, &PlanExecuteRootArgs{ peRoot, perr := NewPlanExecuteRoot(ctx, &PlanExecuteRootArgs{
MainToolCallingModel: mainModel, MainToolCallingModel: peMainModel,
ExecModel: execModel, ExecModel: execModel,
OrchInstruction: orchInstruction, OrchInstruction: orchInstruction,
ToolsCfg: mainToolsCfg, ToolsCfg: mainToolsCfg,
@@ -469,6 +496,7 @@ func RunDeepAgent(
Name: orchestratorName, Name: orchestratorName,
Description: orchDescription, Description: orchDescription,
Instruction: supInstr, Instruction: supInstr,
GenModelInput: literalInstructionGenModelInput,
Model: mainModel, Model: mainModel,
ToolsConfig: mainToolsCfg, ToolsConfig: mainToolsCfg,
MaxIterations: deepMaxIter, MaxIterations: deepMaxIter,
@@ -0,0 +1,33 @@
package multiagent
import (
"strings"
"cyberstrike-ai/internal/projectprompt"
)
func shellToolsPresent(toolNames []string) bool {
for _, n := range toolNames {
switch strings.ToLower(strings.TrimSpace(n)) {
case "exec", "execute":
return true
}
}
return false
}
// injectShellToolGuidance 在系统提示末尾追加 exec/execute 分工(仅当工具列表含 exec 或 execute)。
func injectShellToolGuidance(instruction string, toolNames []string) string {
if !shellToolsPresent(toolNames) {
return instruction
}
block := strings.TrimSpace(projectprompt.ShellExecExecuteGuidanceSection())
if block == "" {
return instruction
}
s := strings.TrimSpace(instruction)
if s == "" {
return block
}
return s + "\n\n" + block
}
@@ -0,0 +1,17 @@
package multiagent
import (
"strings"
"testing"
)
func TestInjectShellToolGuidance(t *testing.T) {
got := injectShellToolGuidance("base", []string{"nmap"})
if got != "base" {
t.Fatalf("expected unchanged, got %q", got)
}
got = injectShellToolGuidance("base", []string{"exec", "nmap"})
if !strings.Contains(got, "exec/execute") || !strings.Contains(got, "base") {
t.Fatalf("expected shell guidance appended, got %q", got)
}
}
+12 -9
View File
@@ -3,6 +3,7 @@ package multiagent
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
@@ -11,7 +12,7 @@ import (
"github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/components/tool"
) )
const defaultSubAgentUserContextMaxRunes = 2000 const userContextSupplementHeader = "\n\n## 用户历史输入(原文,子代理必读)\n"
// taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator // taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator
// and appends the user's original conversation messages to the task description. // and appends the user's original conversation messages to the task description.
@@ -30,13 +31,14 @@ type taskContextEnrichMiddleware struct {
// newTaskContextEnrichMiddleware returns a middleware that enriches task // newTaskContextEnrichMiddleware returns a middleware that enriches task
// descriptions with user conversation context. Returns nil if disabled // descriptions with user conversation context. Returns nil if disabled
// (maxRunes < 0) or no user messages exist. // (maxRunes < 0) or no user messages exist.
// projectBlackboard 仅传项目黑板索引块(BuildFactIndexBlock);勿传完整 systemPromptExtra。
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware { func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
supplement := buildUserContextSupplement(userMessage, history, maxRunes) supplement := buildUserContextSupplement(userMessage, history, maxRunes)
if bb := strings.TrimSpace(projectBlackboard); bb != "" { if bb := strings.TrimSpace(projectBlackboard); bb != "" {
if supplement != "" { if supplement != "" {
supplement += "\n\n## 项目黑板索引\n" + bb supplement += "\n\n" + bb
} else { } else {
supplement = "\n\n## 项目黑板索引\n" + bb supplement = "\n\n" + bb
} }
} }
if supplement == "" { if supplement == "" {
@@ -86,9 +88,6 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
if maxRunes < 0 { if maxRunes < 0 {
return "" return ""
} }
if maxRunes == 0 {
maxRunes = defaultSubAgentUserContextMaxRunes
}
var userMsgs []string var userMsgs []string
for _, h := range history { for _, h := range history {
@@ -107,12 +106,16 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
return "" return ""
} }
joined := strings.Join(userMsgs, "\n---\n") lines := make([]string, 0, len(userMsgs))
if len([]rune(joined)) > maxRunes { for i, msg := range userMsgs {
lines = append(lines, fmt.Sprintf("[第%d轮] %s", i+1, msg))
}
joined := strings.Join(lines, "\n")
if maxRunes > 0 && len([]rune(joined)) > maxRunes {
joined = truncateKeepFirstLast(userMsgs, maxRunes) joined = truncateKeepFirstLast(userMsgs, maxRunes)
} }
return "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" + joined return userContextSupplementHeader + joined
} }
// truncateKeepFirstLast keeps the first and last user messages, giving each // truncateKeepFirstLast keeps the first and last user messages, giving each
@@ -74,7 +74,7 @@ func TestBuildUserContextSupplement_DisabledByNegative(t *testing.T) {
func TestBuildUserContextSupplement_CustomMaxRunes(t *testing.T) { func TestBuildUserContextSupplement_CustomMaxRunes(t *testing.T) {
msg := strings.Repeat("A", 200) msg := strings.Repeat("A", 200)
result := buildUserContextSupplement(msg, nil, 50) result := buildUserContextSupplement(msg, nil, 50)
header := "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" header := userContextSupplementHeader
body := strings.TrimPrefix(result, header) body := strings.TrimPrefix(result, header)
if len([]rune(body)) > 50 { if len([]rune(body)) > 50 {
t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body))) t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body)))
@@ -89,7 +89,7 @@ func TestBuildUserContextSupplement_TruncateKeepsFirstAndLast(t *testing.T) {
history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)}) history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)})
} }
last := "最后一条指令" last := "最后一条指令"
result := buildUserContextSupplement(last, history, 0) result := buildUserContextSupplement(last, history, 800)
if !strings.Contains(result, "http://target.com") { if !strings.Contains(result, "http://target.com") {
t.Error("first message (target URL) should survive truncation") t.Error("first message (target URL) should survive truncation")
} }
+6 -3
View File
@@ -806,10 +806,12 @@ func isClaudeProvider(cfg *config.OpenAIConfig) bool {
// Eino HTTP Client Bridge // Eino HTTP Client Bridge
// ============================================================ // ============================================================
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含层 transport 包装: // NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含层 transport 包装:
// 1. 当 cfg.Provider 为 claude 时,最内层套 claudeRoundTripper,把 OpenAI /chat/completions 透明 // 1. 当 cfg.Provider 为 claude 时,套 claudeRoundTripper,把 OpenAI /chat/completions 透明
// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。 // 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。
// 2. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行 // 2. reasoningToolChoiceCompatRoundTrippertool_choice=required/object 时剥离 thinking 字段,避免
// plan_execute replanner 等强制工具调用与推理模式冲突(部分网关返回 400)。
// 3. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai // (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai
// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。 // SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。
// //
@@ -825,6 +827,7 @@ func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client
if transport == nil { if transport == nil {
transport = http.DefaultTransport transport = http.DefaultTransport
} }
transport = &reasoningToolChoiceCompatRoundTripper{base: transport}
if isClaudeProvider(cfg) { if isClaudeProvider(cfg) {
transport = &claudeRoundTripper{ transport = &claudeRoundTripper{
base: transport, base: transport,
+79
View File
@@ -0,0 +1,79 @@
package openai
import (
"github.com/bytedance/sonic"
)
// reasoningPayloadKeys are OpenAI-compatible root fields that enable "thinking" /
// extended-reasoning modes on gateways such as DashScope/Qwen and MiniMax.
var reasoningPayloadKeys = []string{
"thinking",
"reasoning_effort",
"output_config",
"reasoning",
}
// StripReasoningFromChatCompletionBody removes thinking / reasoning fields from a
// chat-completions JSON body.
func StripReasoningFromChatCompletionBody(rawBody []byte) ([]byte, error) {
var payload map[string]any
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
return rawBody, nil
}
if !stripReasoningFields(payload) {
return rawBody, nil
}
out, err := sonic.Marshal(payload)
if err != nil {
return rawBody, err
}
return out, nil
}
// StripReasoningIfForcedToolChoice removes thinking / reasoning fields when the
// request sets tool_choice to "required" or an object. Several providers reject
// that combination (e.g. DashScope: "tool_choice does not support being set to
// required or object in thinking mode").
func StripReasoningIfForcedToolChoice(rawBody []byte) ([]byte, error) {
var payload map[string]any
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
return rawBody, nil
}
if !forcedToolChoiceIncompatibleWithThinking(payload) {
return rawBody, nil
}
if !stripReasoningFields(payload) {
return rawBody, nil
}
out, err := sonic.Marshal(payload)
if err != nil {
return rawBody, err
}
return out, nil
}
func stripReasoningFields(payload map[string]any) bool {
changed := false
for _, key := range reasoningPayloadKeys {
if _, ok := payload[key]; ok {
delete(payload, key)
changed = true
}
}
return changed
}
func forcedToolChoiceIncompatibleWithThinking(payload map[string]any) bool {
tc, ok := payload["tool_choice"]
if !ok || tc == nil {
return false
}
switch v := tc.(type) {
case string:
return v == "required"
case map[string]any:
return true
default:
return false
}
}
+120
View File
@@ -0,0 +1,120 @@
package openai
import (
"io"
"net/http"
"strings"
"testing"
)
func TestStripReasoningFromChatCompletionBody(t *testing.T) {
in := []byte(`{"model":"deepseek-chat","messages":[],"thinking":{"type":"enabled"},"reasoning_effort":"high"}`)
out, err := StripReasoningFromChatCompletionBody(in)
if err != nil {
t.Fatal(err)
}
s := string(out)
if strings.Contains(s, "thinking") || strings.Contains(s, "reasoning_effort") {
t.Fatalf("expected reasoning fields stripped, got %s", s)
}
if !strings.Contains(s, `"model":"deepseek-chat"`) {
t.Fatalf("expected model preserved, got %s", s)
}
plain := []byte(`{"model":"gpt-4o","messages":[]}`)
out2, err := StripReasoningFromChatCompletionBody(plain)
if err != nil {
t.Fatal(err)
}
if string(out2) != string(plain) {
t.Fatalf("expected unchanged payload, got %s", out2)
}
}
func TestStripReasoningIfForcedToolChoice(t *testing.T) {
cases := []struct {
name string
in string
strip bool
contain string
}{
{
name: "required strips thinking",
in: `{"model":"minimax","messages":[],"thinking":{"type":"enabled"},"tool_choice":"required","tools":[]}`,
strip: true,
},
{
name: "object tool_choice strips thinking",
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"},"tool_choice":{"type":"function","function":{"name":"respond"}}}`,
strip: true,
},
{
name: "auto keeps thinking",
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"},"tool_choice":"auto"}`,
strip: false,
contain: "thinking",
},
{
name: "no tool_choice keeps thinking",
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"}}`,
strip: false,
contain: "thinking",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
out, err := StripReasoningIfForcedToolChoice([]byte(tc.in))
if err != nil {
t.Fatal(err)
}
s := string(out)
hasThinking := strings.Contains(s, "thinking")
if tc.strip && hasThinking {
t.Fatalf("expected thinking stripped, got %s", s)
}
if !tc.strip && tc.contain != "" && !strings.Contains(s, tc.contain) {
t.Fatalf("expected %q in %s", tc.contain, s)
}
if !tc.strip && string(out) != tc.in {
t.Fatalf("expected unchanged payload, got %s", s)
}
})
}
}
func TestReasoningToolChoiceCompatRoundTripper(t *testing.T) {
var gotBody string
rt := &reasoningToolChoiceCompatRoundTripper{
base: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
b, _ := io.ReadAll(req.Body)
gotBody = string(b)
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)),
Header: http.Header{"Content-Type": []string{"application/json"}},
}, nil
}),
}
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1/chat/completions", strings.NewReader(
`{"model":"m","thinking":{"type":"enabled"},"tool_choice":"required","messages":[]}`,
))
if err != nil {
t.Fatal(err)
}
_, err = rt.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if strings.Contains(gotBody, "thinking") {
t.Fatalf("expected thinking stripped in transit, got %s", gotBody)
}
if !strings.Contains(gotBody, `"tool_choice":"required"`) {
t.Fatalf("expected tool_choice preserved, got %s", gotBody)
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
@@ -0,0 +1,43 @@
package openai
import (
"bytes"
"io"
"net/http"
"strconv"
"strings"
)
// reasoningToolChoiceCompatRoundTripper strips thinking/reasoning fields from
// chat/completions requests that force tool_choice, which some gateways reject
// when thinking mode is enabled on the same request.
type reasoningToolChoiceCompatRoundTripper struct {
base http.RoundTripper
}
func (rt *reasoningToolChoiceCompatRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if rt == nil || rt.base == nil || req == nil || req.Body == nil {
if rt != nil && rt.base != nil {
return rt.base.RoundTrip(req)
}
return http.DefaultTransport.RoundTrip(req)
}
if req.Method != http.MethodPost || !strings.HasSuffix(req.URL.Path, "/chat/completions") {
return rt.base.RoundTrip(req)
}
body, err := io.ReadAll(req.Body)
_ = req.Body.Close()
if err != nil {
return nil, err
}
patched, perr := StripReasoningIfForcedToolChoice(body)
if perr != nil {
patched = body
}
req.Body = io.NopCloser(bytes.NewReader(patched))
req.ContentLength = int64(len(patched))
req.Header.Set("Content-Length", strconv.Itoa(len(patched)))
return rt.base.RoundTrip(req)
}
+170
View File
@@ -0,0 +1,170 @@
package project
import (
"fmt"
"strings"
"cyberstrike-ai/internal/database"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
const (
// UserVerbatimSectionHeading 用户原文锚点可读标题(块内保留,供 Agent 阅读)。
UserVerbatimSectionHeading = "## 用户历史输入(原文保留,勿省略或改写)"
// UserVerbatimSectionStartMarker / EndMarkerHTML 注释边界,供程序化替换;对模型无指令语义。
UserVerbatimSectionStartMarker = "<!-- user-verbatim-start -->"
UserVerbatimSectionEndMarker = "<!-- user-verbatim-end -->"
)
// ExtractUserContentsFromMessages 按时间顺序提取 user 角色消息的原文(跳过空白)。
func ExtractUserContentsFromMessages(msgs []database.Message) []string {
out := make([]string, 0, len(msgs))
for i := range msgs {
if !strings.EqualFold(strings.TrimSpace(msgs[i].Role), "user") {
continue
}
content := strings.TrimSpace(msgs[i].Content)
if content == "" {
continue
}
out = append(out, content)
}
return out
}
// BuildUserVerbatimAnchorBlockFromMessages 从 messages 表行构建用户原文锚点块。
// maxRunes: 0 = 不截断;>0 = 总 rune 上限(仍保留每一轮,仅对超长单条做尾部截断提示)。
func BuildUserVerbatimAnchorBlockFromMessages(msgs []database.Message, maxRunes int) string {
return BuildUserVerbatimAnchorBlock(ExtractUserContentsFromMessages(msgs), maxRunes)
}
// BuildUserVerbatimAnchorBlock 将各轮用户原文格式化为 system prompt 锚点块。
func BuildUserVerbatimAnchorBlock(userContents []string, maxRunes int) string {
if len(userContents) == 0 {
return ""
}
lines := make([]string, 0, len(userContents))
for _, content := range userContents {
content = strings.TrimSpace(content)
if content == "" {
continue
}
lines = append(lines, fmt.Sprintf("[第%d轮] %s", len(lines)+1, content))
}
if len(lines) == 0 {
return ""
}
body := strings.Join(lines, "\n")
if maxRunes > 0 {
body = capUserVerbatimBody(body, maxRunes)
}
return wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n" + body)
}
func capUserVerbatimBody(body string, maxRunes int) string {
rs := []rune(body)
if len(rs) <= maxRunes {
return body
}
suffix := "\n\n...(用户原文锚点已达配置上限,更早轮次可能被截断;完整原文见 messages 表)..."
suffixRunes := []rune(suffix)
keep := maxRunes - len(suffixRunes)
if keep <= 0 {
return string(rs[:maxRunes])
}
return string(rs[:keep]) + suffix
}
func wrapUserVerbatimBlock(content string) string {
content = strings.TrimSpace(content)
if content == "" {
return ""
}
return UserVerbatimSectionStartMarker + "\n" + content + "\n" + UserVerbatimSectionEndMarker + "\n"
}
// ReplaceUserVerbatimAnchorSection 用 freshBlock 替换 content 中已有的用户原文锚点段。
func ReplaceUserVerbatimAnchorSection(content, freshBlock string) (string, bool) {
content = strings.TrimSpace(content)
freshBlock = strings.TrimSpace(freshBlock)
if freshBlock == "" {
return content, false
}
start, ok := userVerbatimSectionStart(content)
if !ok {
return content, false
}
end, ok := userVerbatimSectionEnd(content, start)
if !ok {
return content, false
}
return strings.TrimSpace(content[:start] + freshBlock + content[end:]), true
}
func userVerbatimSectionStart(content string) (int, bool) {
idx := strings.Index(content, UserVerbatimSectionStartMarker)
if idx < 0 {
return 0, false
}
return idx, true
}
func userVerbatimSectionEnd(content string, start int) (int, bool) {
if start < 0 || start >= len(content) {
return 0, false
}
tail := content[start:]
idx := strings.LastIndex(tail, UserVerbatimSectionEndMarker)
if idx < 0 {
return 0, false
}
return start + idx + len(UserVerbatimSectionEndMarker), true
}
// RefreshUserVerbatimAnchorInMessages 在 summarization 等压缩后,用 freshBlock 刷新 system 中的用户原文锚点。
// 若尚无锚点段,则追加到首条 system 消息;若无 system 消息则在开头插入一条。
func RefreshUserVerbatimAnchorInMessages(msgs []adk.Message, freshBlock string) []adk.Message {
freshBlock = strings.TrimSpace(freshBlock)
if freshBlock == "" || len(msgs) == 0 {
return msgs
}
out := make([]adk.Message, len(msgs))
changed := false
for i, msg := range msgs {
if msg == nil || msg.Role != schema.System {
out[i] = msg
continue
}
newContent, ok := ReplaceUserVerbatimAnchorSection(msg.Content, freshBlock)
if !ok {
out[i] = msg
continue
}
cloned := *msg
cloned.Content = newContent
out[i] = &cloned
changed = true
}
if changed {
return out
}
for i, msg := range msgs {
if msg == nil || msg.Role != schema.System {
continue
}
cloned := *msg
cloned.Content = AppendSystemPromptBlock(cloned.Content, freshBlock)
out[i] = &cloned
return out
}
prefix := make([]adk.Message, 0, len(msgs)+1)
prefix = append(prefix, schema.SystemMessage(freshBlock))
return append(prefix, msgs...)
}
@@ -0,0 +1,96 @@
package project
import (
"strings"
"testing"
"cyberstrike-ai/internal/database"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func TestBuildUserVerbatimAnchorBlock_MultiTurn(t *testing.T) {
msgs := []database.Message{
{Role: "user", Content: "目标 https://a.com 仅测 /api"},
{Role: "assistant", Content: "好的"},
{Role: "user", Content: "用 admin:test 登录"},
}
block := BuildUserVerbatimAnchorBlockFromMessages(msgs, 0)
if block == "" {
t.Fatal("expected non-empty block")
}
if !strings.Contains(block, UserVerbatimSectionStartMarker) {
t.Error("missing start marker")
}
if !strings.Contains(block, "[第1轮]") || !strings.Contains(block, "https://a.com") {
t.Error("missing first user turn")
}
if !strings.Contains(block, "[第2轮]") || !strings.Contains(block, "admin:test") {
t.Error("missing second user turn")
}
if strings.Contains(block, "好的") {
t.Error("assistant content should not appear")
}
}
func TestReplaceUserVerbatimAnchorSection(t *testing.T) {
old := "prefix\n\n" + wrapUserVerbatimBlock("## old\n\n[第1轮] a") + "\nsuffix"
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] b\n[第2轮] c")
out, ok := ReplaceUserVerbatimAnchorSection(old, newBlock)
if !ok {
t.Fatal("expected replace ok")
}
if !strings.Contains(out, "[第2轮] c") {
t.Errorf("expected new block, got %q", out)
}
if !strings.HasPrefix(strings.TrimSpace(out), "prefix") {
t.Error("prefix should remain")
}
if !strings.Contains(out, "suffix") {
t.Error("suffix should remain")
}
}
func TestRefreshUserVerbatimAnchorInMessages_ReplaceExisting(t *testing.T) {
oldBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] old")
msgs := []adk.Message{
schema.SystemMessage("instr\n\n" + oldBlock),
schema.UserMessage("hi"),
}
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] new")
out := RefreshUserVerbatimAnchorInMessages(msgs, newBlock)
if len(out) != 2 {
t.Fatalf("message count: got %d", len(out))
}
if !strings.Contains(out[0].Content, "[第1轮] new") {
t.Errorf("system content: %q", out[0].Content)
}
if strings.Contains(out[0].Content, "[第1轮] old") {
t.Error("old anchor should be replaced")
}
}
func TestRefreshUserVerbatimAnchorInMessages_InsertWhenMissing(t *testing.T) {
msgs := []adk.Message{
schema.SystemMessage("base instruction"),
schema.UserMessage("hi"),
}
block := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] anchor")
out := RefreshUserVerbatimAnchorInMessages(msgs, block)
if !strings.Contains(out[0].Content, "[第1轮] anchor") {
t.Errorf("expected appended anchor, got %q", out[0].Content)
}
}
func TestBuildUserVerbatimAnchorBlock_MaxRunes(t *testing.T) {
long := strings.Repeat("字", 200)
block := BuildUserVerbatimAnchorBlock([]string{long}, 50)
body := block
if idx := strings.Index(body, UserVerbatimSectionStartMarker); idx >= 0 {
body = strings.TrimPrefix(body[idx+len(UserVerbatimSectionStartMarker):], "\n")
}
if len([]rune(body)) > 120 {
t.Errorf("expected capped body, got %d runes", len([]rune(body)))
}
}
+68
View File
@@ -0,0 +1,68 @@
package project
import (
"fmt"
"os"
"path/filepath"
"strings"
)
func sanitizeWorkspacePathSegment(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return "default"
}
s = strings.ReplaceAll(s, string(filepath.Separator), "-")
s = strings.ReplaceAll(s, "/", "-")
s = strings.ReplaceAll(s, "\\", "-")
s = strings.ReplaceAll(s, "..", "__")
if len(s) > 180 {
s = s[:180]
}
return s
}
// WorkspaceRootDir returns the relative workspace root for downloads and local analysis.
// Project-bound sessions share projects/<id>/; otherwise conversations/<id>/.
func WorkspaceRootDir(configuredBase, projectID, conversationID string) string {
base := strings.TrimSpace(configuredBase)
if base == "" {
base = filepath.Join("tmp", "workspace")
}
if pid := strings.TrimSpace(projectID); pid != "" {
return filepath.Join(base, "projects", sanitizeWorkspacePathSegment(pid))
}
conv := strings.TrimSpace(conversationID)
if conv == "" {
conv = "default"
}
return filepath.Join(base, "conversations", sanitizeWorkspacePathSegment(conv))
}
// EnsureWorkspace creates the workspace directory and returns its absolute path.
func EnsureWorkspace(root string) (string, error) {
abs, err := filepath.Abs(strings.TrimSpace(root))
if err != nil {
return "", fmt.Errorf("workspace abs: %w", err)
}
if err := os.MkdirAll(abs, 0o755); err != nil {
return "", fmt.Errorf("workspace mkdir: %w", err)
}
return abs, nil
}
// BuildWorkspaceBlock instructs the agent to use the session workspace instead of /tmp.
func BuildWorkspaceBlock(absPath string) string {
absPath = strings.TrimSpace(absPath)
if absPath == "" {
return ""
}
return fmt.Sprintf(`## 会话工作目录下载与本地分析
**必须使用以下目录**保存 curl/wget 下载的文件临时 HTML/JS以及 read_file/glob/grep 的检索范围
`+"`%s`"+`
- **禁止**使用系统 `+"`/tmp`"+` 或其它全局临时目录多项目/多会话会互窜遗留文件
- 下载示例`+"`curl -o '%s/page.html' 'https://target/'`"+`exec 时可将 `+"`workdir`"+` 设为该目录。
- 读取前用 glob/grep/read_file **限定在该目录**下搜索勿在 `+"`/tmp`"+` 盲目检索`, absPath, absPath)
}
+52
View File
@@ -0,0 +1,52 @@
package project
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestWorkspaceRootDirProjectScoped(t *testing.T) {
got := WorkspaceRootDir("", "proj-1", "conv-1")
want := filepath.Join("tmp", "workspace", "projects", "proj-1")
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestWorkspaceRootDirConversationScoped(t *testing.T) {
got := WorkspaceRootDir("/data/ws", "", "conv-abc")
want := filepath.Join("/data/ws", "conversations", "conv-abc")
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestEnsureWorkspaceCreatesDir(t *testing.T) {
root := filepath.Join(t.TempDir(), "nested", "workspace")
abs, err := EnsureWorkspace(root)
if err != nil {
t.Fatalf("EnsureWorkspace: %v", err)
}
st, err := os.Stat(abs)
if err != nil {
t.Fatalf("Stat: %v", err)
}
if !st.IsDir() {
t.Fatal("expected directory")
}
}
func TestBuildWorkspaceBlockMentionsPath(t *testing.T) {
block := BuildWorkspaceBlock("/opt/csai/tmp/workspace/projects/p1")
if block == "" {
t.Fatal("expected non-empty block")
}
if !strings.Contains(block, "/opt/csai/tmp/workspace/projects/p1") {
t.Fatalf("block missing path: %s", block)
}
if !strings.Contains(block, "/tmp") {
t.Fatalf("block should warn about /tmp: %s", block)
}
}
+11
View File
@@ -0,0 +1,11 @@
package projectprompt
// ShellExecExecuteGuidanceSection 供单代理/多代理系统提示追加:exec 与 execute 分工(尽量短)。
func ShellExecExecuteGuidanceSection() string {
return `Shellexec/execute):有专用 MCP 工具时优先专用工具;系统命令(管道、workdir、后台 &)用 execskills/ 内脚本(配合 read_file、skill)用 execute;多步扫描分拆调用,禁止一条 shell 串多个扫描器。下载/临时文件须写入系统提示中的「会话工作目录」,禁止用 /tmp。`
}
// ShellExecExecuteGuidanceReconSuffix 侦察子代理可选追加(一行)。
func ShellExecExecuteGuidanceReconSuffix() string {
return `枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。`
}
+29
View File
@@ -26,6 +26,35 @@ const (
wireOutputConfig wireOutputConfig
) )
// ApplyPlanExecutePlannerModelConfig configures the plan_execute planner/replanner
// ChatModel. Those Eino agents call WithToolChoice(Forced); several gateways reject
// thinking / reasoning fields on the same request (tool_choice required/object).
// Executor should keep the normal ApplyToEinoChatModelConfig path.
func ApplyPlanExecutePlannerModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig) {
if cfg == nil || oa == nil {
return
}
offOA := *oa
offReasoning := oa.Reasoning
offReasoning.Mode = "off"
offOA.Reasoning = offReasoning
ApplyToEinoChatModelConfig(cfg, &offOA, nil)
clearReasoningFromChatModelConfig(cfg)
}
func clearReasoningFromChatModelConfig(cfg *einoopenai.ChatModelConfig) {
if cfg == nil {
return
}
cfg.ReasoningEffort = ""
if cfg.ExtraFields != nil {
for _, key := range []string{"thinking", "reasoning_effort", "output_config", "reasoning"} {
delete(cfg.ExtraFields, key)
}
}
applyThinkingDisabled(cfg)
}
// ApplyToEinoChatModelConfig merges reasoning-related options into cfg. // ApplyToEinoChatModelConfig merges reasoning-related options into cfg.
// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set. // Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set.
func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) { func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) {
+24
View File
@@ -49,6 +49,30 @@ func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
} }
} }
func TestApplyPlanExecutePlannerModelConfig_stripsReasoningWhenGlobalOn(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
BaseURL: "https://antchat.example.com/v1",
Model: "minimax-m3",
Reasoning: config.OpenAIReasoningConfig{
Profile: "openai_compat",
Mode: "on",
Effort: "high",
},
}
ApplyPlanExecutePlannerModelConfig(cfg, oa)
if cfg.ReasoningEffort != "" {
t.Fatalf("expected ReasoningEffort cleared, got %q", cfg.ReasoningEffort)
}
th, ok := cfg.ExtraFields["thinking"].(map[string]any)
if !ok || th["type"] != "disabled" {
t.Fatalf("expected thinking disabled, got %#v", cfg.ExtraFields)
}
if _, ok := cfg.ExtraFields["reasoning_effort"]; ok {
t.Fatalf("expected reasoning_effort stripped, got %#v", cfg.ExtraFields)
}
}
func TestApplyReasoningOff_disablesThinking(t *testing.T) { func TestApplyReasoningOff_disablesThinking(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{} cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{ oa := &config.OpenAIConfig{
@@ -0,0 +1,56 @@
package security
import (
"errors"
"fmt"
"os/exec"
"strings"
)
// FormatCommandFailureResult 与 exec 工具 ToolResult 文案一致(不含 ToolErrorPrefix)。
func FormatCommandFailureResult(exitCode int, output string) string {
output = strings.TrimSpace(output)
errMsg := fmt.Sprintf("exit status %d", exitCode)
if output == "" {
return fmt.Sprintf("命令执行失败: %s", errMsg)
}
if strings.HasPrefix(output, "命令执行失败:") {
return output
}
return fmt.Sprintf("命令执行失败: %s\n输出: %s", errMsg, output)
}
// FormatCommandFailureFromErr 根据 exec/execute 返回的 error 生成统一失败文案(IsError 正文)。
func FormatCommandFailureFromErr(err error, output string) string {
if err == nil {
return strings.TrimSpace(output)
}
var exitError *exec.ExitError
if errors.As(err, &exitError) {
return FormatCommandFailureResult(exitError.ExitCode(), output)
}
output = strings.TrimSpace(output)
if output == "" {
return fmt.Sprintf("命令执行失败: %v", err)
}
if strings.HasPrefix(output, "命令执行失败:") {
return output
}
return fmt.Sprintf("命令执行失败: %v\n输出: %s", err, output)
}
// ExecuteFailureStatusLine 流式 execute 结束时追加的单行状态(输出正文已在流中推送过)。
func ExecuteFailureStatusLine(exitCode int) string {
return fmt.Sprintf("\n命令执行失败: exit status %d", exitCode)
}
// IsCommandFailureResult 判断工具结果正文是否表示命令非零退出(用于 execute / exec 对齐 isError)。
func IsCommandFailureResult(content string) bool {
return strings.Contains(content, "命令执行失败:")
}
// IsLegacyShellExitNoise 过滤旧版 shell 流中冗余的 exit code 行。
func IsLegacyShellExitNoise(s string) bool {
trimmed := strings.TrimSpace(s)
return strings.HasPrefix(trimmed, "command exited with non-zero code ")
}
@@ -0,0 +1,54 @@
package security
import (
"errors"
"os/exec"
"strings"
"testing"
)
func TestFormatCommandFailureResult(t *testing.T) {
got := FormatCommandFailureResult(1, "sudo: password required")
want := "命令执行失败: exit status 1\n输出: sudo: password required"
if got != want {
t.Fatalf("got %q want %q", got, want)
}
if FormatCommandFailureResult(2, "") != "命令执行失败: exit status 2" {
t.Fatal("empty output format")
}
if FormatCommandFailureResult(1, "命令执行失败: exit status 1") != "命令执行失败: exit status 1" {
t.Fatal("should not double-wrap")
}
}
func TestIsCommandFailureResult(t *testing.T) {
if !IsCommandFailureResult("sudo: err\n命令执行失败: exit status 1") {
t.Fatal("expected true")
}
if IsCommandFailureResult("sudo: err only") {
t.Fatal("expected false")
}
}
func TestFormatCommandFailureFromErr(t *testing.T) {
cmd := exec.Command("sh", "-c", "exit 42")
err := cmd.Run()
got := FormatCommandFailureFromErr(err, "oops")
if got != "命令执行失败: exit status 42\n输出: oops" {
t.Fatalf("got %q", got)
}
timeoutErr := errors.New("shell inactivity timeout (300s)")
got2 := FormatCommandFailureFromErr(timeoutErr, "already timed out")
if !strings.Contains(got2, "shell inactivity timeout") || !strings.Contains(got2, "already timed out") {
t.Fatalf("got %q", got2)
}
}
func TestIsLegacyShellExitNoise(t *testing.T) {
if !IsLegacyShellExitNoise("command exited with non-zero code 1\n") {
t.Fatal("expected legacy noise")
}
if IsLegacyShellExitNoise("sudo: failed") {
t.Fatal("unexpected noise")
}
}
+139 -110
View File
@@ -32,10 +32,11 @@ var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
// Executor 安全工具执行器 // Executor 安全工具执行器
type Executor struct { type Executor struct {
config *config.SecurityConfig config *config.SecurityConfig
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找 toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
mcpServer *mcp.Server mcpServer *mcp.Server
logger *zap.Logger logger *zap.Logger
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300-1=关闭(见 SetShellNoOutputTimeoutSeconds
} }
// NewExecutor 创建新的执行器 // NewExecutor 创建新的执行器
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
return executor return executor
} }
// SetShellNoOutputTimeoutSeconds 配置 exec 工具无输出空闲终止(与 agent.shell_no_output_timeout_seconds 一致)。
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
e.shellNoOutputTimeoutSec = sec
}
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1) // buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
func (e *Executor) buildToolIndex() { func (e *Executor) buildToolIndex() {
e.toolIndex = make(map[string]*config.ToolConfig) e.toolIndex = make(map[string]*config.ToolConfig)
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
// 执行命令 // 执行命令
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...) cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
applyDefaultTerminalEnv(cmd) applyDefaultTerminalEnv(cmd)
attachNonInteractiveStdin(cmd)
_ = prepareShellCmdSession(cmd) _ = prepareShellCmdSession(cmd)
e.logger.Info("执行安全工具", e.logger.Info("执行安全工具",
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
var err error var err error
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。 // 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
output, err = streamCommandOutput(ctx, cmd, cb) output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
if err != nil && shouldRetryWithPTY(output) { if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试", e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
zap.String("tool", toolName), zap.String("tool", toolName),
@@ -155,9 +162,8 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
output, err = runCommandWithPTY(ctx, cmd2, cb) output, err = runCommandWithPTY(ctx, cmd2, cb)
} }
} else { } else {
outputBytes, err2 := cmd.CombinedOutput() // 非流式:内存缓冲 + ctx 取消杀进程组;行为对齐原 CombinedOutput,避免双流管道 fan-in 死锁。
output = string(outputBytes) output, err = combinedOutputCancellable(ctx, cmd)
err = err2
if err != nil && shouldRetryWithPTY(output) { if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试", e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
zap.String("tool", toolName), zap.String("tool", toolName),
@@ -685,83 +691,21 @@ func (e *Executor) formatParamValue(param config.ParameterConfig, value interfac
// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。 // IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。
// command1 & command2 不算完全后台(command2 仍在前台执行)。 // command1 & command2 不算完全后台(command2 仍在前台执行)。
func IsBackgroundShellCommand(command string) bool { func IsBackgroundShellCommand(command string) bool {
// 移除首尾空格
command = strings.TrimSpace(command) command = strings.TrimSpace(command)
if command == "" { if command == "" {
return false return false
} }
positions := findStandaloneAmpersandPositions(command)
// 检查命令中所有不在引号内的 & 符号 if len(positions) == 0 {
// 找到最后一个 & 符号,检查它是否在命令末尾
inSingleQuote := false
inDoubleQuote := false
escaped := false
lastAmpersandPos := -1
for i, r := range command {
if escaped {
escaped = false
continue
}
if r == '\\' {
escaped = true
continue
}
if r == '\'' && !inDoubleQuote {
inSingleQuote = !inSingleQuote
continue
}
if r == '"' && !inSingleQuote {
inDoubleQuote = !inDoubleQuote
continue
}
if r == '&' && !inSingleQuote && !inDoubleQuote {
// 检查 & 前后是否有空格或换行(确保是独立的 &,而不是变量名的一部分)
isStandalone := false
// 检查前面:空格、制表符、换行符,或者是命令开头
if i == 0 {
isStandalone = true
} else {
prev := command[i-1]
if prev == ' ' || prev == '\t' || prev == '\n' || prev == '\r' {
isStandalone = true
}
}
// 检查后面:空格、制表符、换行符,或者是命令末尾
if isStandalone {
if i == len(command)-1 {
// 在末尾,肯定是独立的 &
lastAmpersandPos = i
} else {
next := command[i+1]
if next == ' ' || next == '\t' || next == '\n' || next == '\r' {
// 后面有空格,是独立的 &
lastAmpersandPos = i
}
}
}
}
}
// 如果没有找到 & 符号,不是后台命令
if lastAmpersandPos == -1 {
return false return false
} }
last := positions[len(positions)-1]
// 检查最后一个 & 后面是否还有非空内容 afterAmpersand := strings.TrimSpace(command[last+1:])
afterAmpersand := strings.TrimSpace(command[lastAmpersandPos+1:]) if afterAmpersand != "" {
if afterAmpersand == "" { return false
// & 在末尾或后面只有空白字符,这是完全后台命令
// 检查 & 前面是否有内容
beforeAmpersand := strings.TrimSpace(command[:lastAmpersandPos])
return beforeAmpersand != ""
} }
beforeAmpersand := strings.TrimSpace(command[:last])
// 如果 & 后面还有非空内容,说明是 command1 & command2 的情况 return beforeAmpersand != ""
// 这种情况下,command2会在前台执行,所以不算完全后台命令
return false
} }
// executeSystemCommand 执行系统命令 // executeSystemCommand 执行系统命令
@@ -797,6 +741,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
zap.String("command", command), zap.String("command", command),
) )
command = PrepareShellCommandForExecute(command)
// 获取shell类型(可选,默认为sh) // 获取shell类型(可选,默认为sh)
shell := "sh" shell := "sh"
if s, ok := args["shell"].(string); ok && s != "" { if s, ok := args["shell"].(string); ok && s != "" {
@@ -820,8 +766,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
} else { } else {
cmd = exec.CommandContext(ctx, shell, "-c", command) cmd = exec.CommandContext(ctx, shell, "-c", command)
} }
applyDefaultTerminalEnv(cmd) ConfigureShellCmdForAgentExecute(cmd)
_ = prepareShellCmdSession(cmd)
// 执行命令 // 执行命令
e.logger.Info("执行系统命令", e.logger.Info("执行系统命令",
@@ -837,10 +782,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&") commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&")
commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand) commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand)
// 构建新命令:将用户命令置于独立重定向的后台作业,再 echo $pid // 构建新命令:后台作业重定向标准流后 echo $pid(与 RedirectBackgroundJobStdio 一致)
// 若子进程与 echo 共享同一 stdout 管道,且长时间不向 stdout 写入换行, pidCommand := RedirectBackgroundJobStdio(commandWithoutAmpersand+" &") + " pid=$!; echo $pid"
// bufio.ReadString('\n') 会永久阻塞(例如 beacon 持续写二进制/单行日志)。
pidCommand := fmt.Sprintf("%s </dev/null >/dev/null 2>&1 & pid=$!; echo $pid", commandWithoutAmpersand)
// 创建新命令来获取PID // 创建新命令来获取PID
var pidCmd *exec.Cmd var pidCmd *exec.Cmd
@@ -850,8 +793,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
} else { } else {
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand) pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
} }
applyDefaultTerminalEnv(pidCmd) ConfigureShellCmdForAgentExecute(pidCmd)
_ = prepareShellCmdSession(pidCmd)
// 获取stdout管道 // 获取stdout管道
stdout, err := pidCmd.StdoutPipe() stdout, err := pidCmd.StdoutPipe()
@@ -963,29 +905,25 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
var err error var err error
// 若上层提供工具输出增量回调,则边执行边流式读取。 // 若上层提供工具输出增量回调,则边执行边流式读取。
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
output, err = streamCommandOutput(ctx, cmd, cb) output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
if err != nil && shouldRetryWithPTY(output) { if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试") e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
cmd2 := exec.CommandContext(ctx, shell, "-c", command) cmd2 := exec.CommandContext(ctx, shell, "-c", command)
if workDir != "" { if workDir != "" {
cmd2.Dir = workDir cmd2.Dir = workDir
} }
applyDefaultTerminalEnv(cmd2) ConfigureShellCmdForAgentExecute(cmd2)
_ = prepareShellCmdSession(cmd2)
output, err = runCommandWithPTY(ctx, cmd2, cb) output, err = runCommandWithPTY(ctx, cmd2, cb)
} }
} else { } else {
outputBytes, err2 := cmd.CombinedOutput() output, err = combinedOutputCancellable(ctx, cmd)
output = string(outputBytes)
err = err2
if err != nil && shouldRetryWithPTY(output) { if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试") e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
cmd2 := exec.CommandContext(ctx, shell, "-c", command) cmd2 := exec.CommandContext(ctx, shell, "-c", command)
if workDir != "" { if workDir != "" {
cmd2.Dir = workDir cmd2.Dir = workDir
} }
applyDefaultTerminalEnv(cmd2) ConfigureShellCmdForAgentExecute(cmd2)
_ = prepareShellCmdSession(cmd2)
output, err = runCommandWithPTY(ctx, cmd2, nil) output, err = runCommandWithPTY(ctx, cmd2, nil)
} }
} }
@@ -999,7 +937,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
Content: []mcp.Content{ Content: []mcp.Content{
{ {
Type: "text", Type: "text",
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)), Text: FormatCommandFailureFromErr(err, output),
}, },
}, },
IsError: true, IsError: true,
@@ -1022,12 +960,58 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
}, nil }, nil
} }
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr // combinedOutputCancellable 行为对齐 cmd.CombinedOutputstdout/stderr 写入内存缓冲),
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。 // 但在 ctx 取消时 terminateCmdTree 终止整棵进程树。
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) { // 非流式路径不使用双流管道 fan-in,避免 stderr 撑满管道缓冲区时与 stdout 互相阻塞导致死锁。
if err := prepareShellCmdSession(cmd); err != nil { // 无输出空闲检测由上层 agent.tool_timeout_minutes 兜底,不改变原 CombinedOutput 语义。
func combinedOutputCancellable(ctx context.Context, cmd *exec.Cmd) (string, error) {
var stdoutBuf, stderrBuf strings.Builder
cmd.Stdout = &stdoutBuf
cmd.Stderr = &stderrBuf
session, err := StartShellSession(cmd)
if err != nil {
return "", err return "", err
} }
done := make(chan error, 1)
go func() {
done <- session.Wait()
}()
stopWatch := make(chan struct{})
go func() {
select {
case <-ctx.Done():
TerminateShellCmdSession(session)
case <-stopWatch:
}
}()
defer close(stopWatch)
var waitErr error
select {
case waitErr = <-done:
case <-ctx.Done():
waitErr = <-done
return joinCommandOutput(stdoutBuf.String(), stderrBuf.String()), ctx.Err()
}
return joinCommandOutput(stdoutBuf.String(), stderrBuf.String()), waitErr
}
func joinCommandOutput(stdout, stderr string) string {
if stderr == "" {
return stdout
}
if stdout == "" {
return stderr
}
return stdout + stderr
}
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
stdoutPipe, err := cmd.StdoutPipe() stdoutPipe, err := cmd.StdoutPipe()
if err != nil { if err != nil {
return "", err return "", err
@@ -1037,7 +1021,8 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
_ = stdoutPipe.Close() _ = stdoutPipe.Close()
return "", err return "", err
} }
if err := cmd.Start(); err != nil { session, err := StartShellSession(cmd)
if err != nil {
_ = stdoutPipe.Close() _ = stdoutPipe.Close()
_ = stderrPipe.Close() _ = stderrPipe.Close()
return "", err return "", err
@@ -1047,7 +1032,7 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
terminateCmdTree(cmd) TerminateShellCmdSession(session)
case <-stopWatch: case <-stopWatch:
} }
}() }()
@@ -1086,23 +1071,61 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
if deltaBuilder.Len() == 0 { if deltaBuilder.Len() == 0 {
return return
} }
cb(deltaBuilder.String()) if cb != nil {
cb(deltaBuilder.String())
}
deltaBuilder.Reset() deltaBuilder.Reset()
lastFlush = time.Now() lastFlush = time.Now()
} }
for chunk := range chunks { idleWatch := NewShellInactivityWatch(noOutputSec)
outBuilder.WriteString(chunk) if idleWatch != nil {
deltaBuilder.WriteString(chunk) defer idleWatch.Stop()
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次 }
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
fireInactivity := func() {
TerminateShellCmdSession(session)
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
outBuilder.WriteString(msg)
if cb != nil {
cb(msg)
}
_ = session.Wait()
}
chunksLoop:
for {
var idleCh <-chan struct{}
if idleWatch != nil {
idleCh = idleWatch.Expired
}
select {
case <-ctx.Done():
TerminateShellCmdSession(session)
flush() flush()
_ = session.Wait()
return outBuilder.String(), ctx.Err()
case <-idleCh:
fireInactivity()
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
case chunk, ok := <-chunks:
if !ok {
break chunksLoop
}
if chunk != "" && idleWatch != nil {
idleWatch.Bump()
}
outBuilder.WriteString(chunk)
deltaBuilder.WriteString(chunk)
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
flush()
}
} }
} }
flush() flush()
// 等待命令结束,返回最终退出状态 // 等待命令结束,返回最终退出状态
waitErr := cmd.Wait() waitErr := session.Wait()
return outBuilder.String(), waitErr return outBuilder.String(), waitErr
} }
@@ -1116,6 +1139,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
if cmd.Env == nil { if cmd.Env == nil {
cmd.Env = os.Environ() cmd.Env = os.Environ()
} }
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖 // 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
has := func(k string) bool { has := func(k string) bool {
prefix := k + "=" prefix := k + "="
@@ -1159,7 +1183,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
// PTY 方案为类 UnixWindows 走原逻辑 // PTY 方案为类 UnixWindows 走原逻辑
if cb != nil { if cb != nil {
return streamCommandOutput(ctx, cmd, cb) return streamCommandOutput(ctx, cmd, cb, 0)
} }
_ = prepareShellCmdSession(cmd) _ = prepareShellCmdSession(cmd)
out, err := cmd.CombinedOutput() out, err := cmd.CombinedOutput()
@@ -1173,13 +1197,18 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
} }
defer func() { _ = ptmx.Close() }() defer func() { _ = ptmx.Close() }()
rootPID := 0
if cmd.Process != nil {
rootPID = cmd.Process.Pid
}
// ctx 取消时尽快终止子进程 // ctx 取消时尽快终止子进程
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
_ = ptmx.Close() // 触发读退出 _ = ptmx.Close() // 触发读退出
terminateCmdTree(cmd) terminateProcessGroup(rootPID, cmd)
case <-done: case <-done:
} }
}() }()

Some files were not shown because too many files have changed in this diff Show More