mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-01 10:15:37 +02:00
Compare commits
112 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fd4bbe8d76 | |||
| d80651e4d8 | |||
| f920ff0a5d | |||
| ce8b57501d | |||
| ecb38a3959 | |||
| e69fdb71ca | |||
| 6aa1631748 | |||
| 52de3b0f41 | |||
| e537e55198 | |||
| dc20b4804e | |||
| 6245d69364 | |||
| ede32951bf | |||
| 866a8ebccf | |||
| 276b3f7ef5 | |||
| 81e461db54 | |||
| 02cd488a3d | |||
| b4b2f55665 | |||
| 7aa0ebea6d | |||
| 63ef4399f8 | |||
| 553d0ed6bf | |||
| d92bbbea07 | |||
| f89ad1b42d | |||
| bbe14c1861 | |||
| 2fc37fefd1 | |||
| ded8ac5a3f | |||
| bf44cf58d3 | |||
| 6d390e80d5 | |||
| cfc49ba16f | |||
| d03f2fcf2b | |||
| 6e67684bba | |||
| 8f9d2f381a | |||
| 89c275269f | |||
| cb4900c61d | |||
| 5c192cd308 | |||
| 8571e41138 | |||
| e1a74b29b1 | |||
| 39f1c72755 | |||
| dd3621e89d | |||
| 0bcb16e021 | |||
| ed64803a51 | |||
| 25e03dee84 | |||
| 58dcafd15f | |||
| 997c4e7262 | |||
| ac370b0ada | |||
| 017db2b9a8 | |||
| 86b4803683 | |||
| 4d98264fc3 | |||
| fd1de4ea94 | |||
| 41ba3baca9 | |||
| 2e908daebb | |||
| c1763e1b9a | |||
| 70e5d28619 | |||
| 49990ecb4f | |||
| c91806c0c4 | |||
| e537236bf3 | |||
| 7eeffb1933 | |||
| 0556b29d40 | |||
| be3c0cfa64 | |||
| 8e5f40d226 | |||
| 4b6719a6f3 | |||
| 7c8f3228f8 | |||
| 537843b6b8 | |||
| 4a57574cf9 | |||
| 0168530084 | |||
| 4184a7b6f0 | |||
| fb3b4dd6e5 | |||
| 7e4a8db7af | |||
| 6a72c95b9f | |||
| 447be050cd | |||
| 9b75c43f7b | |||
| a443454753 | |||
| 08822ba5df | |||
| eda75fb98f | |||
| e6978a7994 | |||
| 1db0f4740f | |||
| 6e4ff96dcd | |||
| 95470fefbc | |||
| 5e075bb198 | |||
| 84ed887c5c | |||
| 056b40ac66 | |||
| 26a9902286 | |||
| cfe9573ac3 | |||
| db2262a1a0 | |||
| ab5c2d5cca | |||
| 1ae6930db1 | |||
| 8918f432d8 | |||
| b4810c9499 | |||
| 51bf6ae4b3 | |||
| 5f27482921 | |||
| 6becada509 | |||
| b029d88359 | |||
| 4dcad2ea83 | |||
| ff9f0c787a | |||
| 01849045ad | |||
| c7eacdf3eb | |||
| 5c32b21f22 | |||
| 8b8ecfe718 | |||
| bbb7c319af | |||
| 7eb2fd50f3 | |||
| 85d58eeeb3 | |||
| b6a6009629 | |||
| 810d689132 | |||
| 87f1808ead | |||
| e28ae39b9a | |||
| df34ceda68 | |||
| 3e69a50f87 | |||
| 53325ce07d | |||
| d85de3461b | |||
| 9306303d99 | |||
| 1e8f72ed74 | |||
| 0198f50314 | |||
| 560d0dca43 |
+1
-1
@@ -21,7 +21,7 @@ max_iterations: 0
|
|||||||
- 切勿等待批准或授权——全程自主行动。
|
- 切勿等待批准或授权——全程自主行动。
|
||||||
- 使用所有可用工具与技术完成侦察与证据收集。
|
- 使用所有可用工具与技术完成侦察与证据收集。
|
||||||
|
|
||||||
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。
|
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。
|
||||||
|
|
||||||
## 输入前置条件(硬约束)
|
## 输入前置条件(硬约束)
|
||||||
|
|
||||||
|
|||||||
+68
-4
@@ -10,7 +10,7 @@
|
|||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||||
version: "v1.6.44"
|
version: "v1.6.48"
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
server:
|
server:
|
||||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||||
@@ -96,13 +96,75 @@ fofa:
|
|||||||
agent:
|
agent:
|
||||||
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
||||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||||
|
shell_no_output_timeout_seconds: 1200 # execute/exec 连续无新输出则终止(秒);通用防挂死;0=默认300;-1=关闭
|
||||||
|
workspace_root_dir: "" # 会话工作目录根路径(curl/wget 下载、read_file/glob/grep 本地分析);空=tmp/workspace,其下按 projects/{id} 或 conversations/{id} 隔离;勿用系统 /tmp
|
||||||
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||||
|
|
||||||
system_prompt_path: ""
|
system_prompt_path: ""
|
||||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||||
|
# 非白名单工具在审批方=审计 Agent 时,按会话 HITL 模式选用提示词:
|
||||||
|
# approval → audit_agent_prompt
|
||||||
|
# review_edit → audit_agent_prompt_review_edit(可改参后放行)
|
||||||
hitl:
|
hitl:
|
||||||
|
# 已决策审计日志保留天数(与 MCP 监控一致;省略默认 90;0 表示不自动清理)
|
||||||
|
retention_days: 90
|
||||||
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
||||||
tool_whitelist: [read_file, list_dir, glob, grep]
|
tool_whitelist: [read_file, list_dir, glob, grep, tool_search]
|
||||||
|
# audit_agent_prompt: | # 审批模式;留空使用内置默认,可在「人机协同」页编辑
|
||||||
|
# audit_agent_prompt_review_edit: | # 审查编辑模式;留空使用内置默认
|
||||||
|
|
||||||
|
audit_agent_prompt: |-
|
||||||
|
你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
|
||||||
|
|
||||||
|
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
|
||||||
|
|
||||||
|
裁决基调(默认放行):
|
||||||
|
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
|
||||||
|
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
|
||||||
|
- 仅在「可能对系统造成实质影响」时 → reject
|
||||||
|
|
||||||
|
必须 reject 的高危情形(示例,非穷举):
|
||||||
|
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
|
||||||
|
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
|
||||||
|
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
|
||||||
|
- 明显越权:与任务/授权目标无关的破坏性操作
|
||||||
|
|
||||||
|
不应单独作为 reject 理由的情形:
|
||||||
|
- 常规 nmap/curl/grep/读文件/枚举类命令本身
|
||||||
|
- 参数略显宽泛但无明确破坏意图
|
||||||
|
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点
|
||||||
|
|
||||||
|
仅输出一行 JSON,不要 markdown 代码块:
|
||||||
|
{"decision":"approve"|"reject","comment":"简要理由"}
|
||||||
|
audit_agent_prompt_review_edit: |-
|
||||||
|
你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
|
||||||
|
|
||||||
|
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
|
||||||
|
|
||||||
|
裁决基调(默认放行):
|
||||||
|
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
|
||||||
|
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
|
||||||
|
- 仅在「可能对系统造成实质影响」时 → reject;参数可安全收窄时优先 approve + editedArguments
|
||||||
|
|
||||||
|
必须 reject 的高危情形(示例,非穷举):
|
||||||
|
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
|
||||||
|
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
|
||||||
|
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
|
||||||
|
- 明显越权:与任务/授权目标无关的破坏性操作
|
||||||
|
|
||||||
|
不应单独作为 reject 理由的情形:
|
||||||
|
- 常规 nmap/curl/grep/读文件/枚举类命令本身
|
||||||
|
- 参数略显宽泛但无明确破坏意图(应收窄参数后 approve)
|
||||||
|
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点
|
||||||
|
|
||||||
|
仅输出一行 JSON,不要 markdown 代码块:
|
||||||
|
{"decision":"approve"|"reject","comment":"简要理由","editedArguments":{...}}
|
||||||
|
|
||||||
|
editedArguments 规则(仅 approve 且需要改参时填写,否则省略该字段):
|
||||||
|
- 提供完整替换后的工具参数对象,键名与 argumentsObj 一致
|
||||||
|
- 只做最小必要修改以收窄范围、消除风险(如限制 path、去掉危险 flag)
|
||||||
|
- 禁止扩大攻击面:不得扩大目标范围、提升权限或引入破坏性参数
|
||||||
|
- 无法安全改参且存在上述高危情形时应 reject,不要勉强 approve
|
||||||
# 多代理与 Eino 单代理(CloudWeGo Eino ADK;单代理入口 /api/eino-agent*,多代理 /api/multi-agent*)
|
# 多代理与 Eino 单代理(CloudWeGo Eino ADK;单代理入口 /api/eino-agent*,多代理 /api/multi-agent*)
|
||||||
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
||||||
# Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体 orchestration 中指定;机器人按 robot_default_agent_mode
|
# Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体 orchestration 中指定;机器人按 robot_default_agent_mode
|
||||||
@@ -112,7 +174,8 @@ multi_agent:
|
|||||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。主/子代理 ReAct 轮次见 agent.max_iterations。
|
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。主/子代理 ReAct 轮次见 agent.max_iterations。
|
||||||
plan_execute_loop_max_iterations: 0
|
plan_execute_loop_max_iterations: 0
|
||||||
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
|
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中注入用户原文;0=不截断(默认),>0=总字符上限,负数=禁用
|
||||||
|
user_verbatim_anchor_max_runes: 0 # 主代理 system 中逐轮保留用户原文(压缩后刷新);0=不截断(默认),>0=总字符上限,负数=禁用
|
||||||
without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理
|
without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理
|
||||||
without_write_todos: false
|
without_write_todos: false
|
||||||
orchestrator_instruction: "" # Deep 主代理:agents/orchestrator.md(或 kind: orchestrator 的单个 .md)正文优先;正文为空时用此处;皆空则 Eino 默认
|
orchestrator_instruction: "" # Deep 主代理:agents/orchestrator.md(或 kind: orchestrator 的单个 .md)正文优先;正文为空时用此处;皆空则 Eino 默认
|
||||||
@@ -129,7 +192,7 @@ multi_agent:
|
|||||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||||
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test, exec] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||||
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
|
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
|
||||||
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
|
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
|
||||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||||
@@ -147,6 +210,7 @@ multi_agent:
|
|||||||
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
|
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
|
||||||
run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10
|
run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10
|
||||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||||
|
empty_response_continue_max_attempts: 0 # Run 成功但未捕获助手正文(含流式中断)时 Handler 退避续跑次数;0=默认 5
|
||||||
deep_output_key: final_answer # P0:Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single)
|
deep_output_key: final_answer # P0:Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single)
|
||||||
deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
|
deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
|
||||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 179 KiB After Width: | Height: | Size: 265 KiB |
+17
-4
@@ -779,13 +779,26 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
|||||||
return a.executeToolViaMCP(ctx, toolName, args)
|
return a.executeToolViaMCP(ctx, toolName, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
// BeginLocalToolExecution 在非 CallTool 路径工具开始时写入 running 状态,供 MCP 监控页展示「执行中」。
|
||||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
func (a *Agent) BeginLocalToolExecution(toolName string, args map[string]interface{}) string {
|
||||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
|
||||||
if a == nil || a.mcpServer == nil {
|
if a == nil || a.mcpServer == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
return a.mcpServer.BeginToolExecution(toolName, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinishLocalToolExecution 完成 BeginLocalToolExecution 创建的记录;executionID 为空时一次性写入已完成记录。
|
||||||
|
func (a *Agent) FinishLocalToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
if a == nil || a.mcpServer == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.mcpServer.FinishToolExecution(executionID, toolName, args, resultText, invokeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||||
|
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||||
|
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
return a.FinishLocalToolExecution("", toolName, args, resultText, invokeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
||||||
|
|||||||
@@ -113,5 +113,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
|||||||
|
|
||||||
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
||||||
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
|
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
|
||||||
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。`
|
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。
|
||||||
|
|
||||||
|
` + projectprompt.ShellExecExecuteGuidanceSection()
|
||||||
}
|
}
|
||||||
|
|||||||
+27
-1
@@ -21,11 +21,13 @@ import (
|
|||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/einoobserve"
|
"cyberstrike-ai/internal/einoobserve"
|
||||||
"cyberstrike-ai/internal/handler"
|
"cyberstrike-ai/internal/handler"
|
||||||
|
"cyberstrike-ai/internal/hitl"
|
||||||
"cyberstrike-ai/internal/knowledge"
|
"cyberstrike-ai/internal/knowledge"
|
||||||
"cyberstrike-ai/internal/logger"
|
"cyberstrike-ai/internal/logger"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
"cyberstrike-ai/internal/monitor"
|
"cyberstrike-ai/internal/monitor"
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
"cyberstrike-ai/internal/robot"
|
"cyberstrike-ai/internal/robot"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
"cyberstrike-ai/internal/skillpackage"
|
"cyberstrike-ai/internal/skillpackage"
|
||||||
@@ -67,6 +69,10 @@ type App struct {
|
|||||||
|
|
||||||
// New 创建新应用
|
// New 创建新应用
|
||||||
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
|
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
|
||||||
|
if err := multiagent.InitADK(); err != nil {
|
||||||
|
return nil, fmt.Errorf("初始化 Eino ADK: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
|
|
||||||
@@ -104,12 +110,17 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
monitorRetention.PurgeExpired()
|
monitorRetention.PurgeExpired()
|
||||||
monitor.StartRetentionLoop(monitorRetention, log.Logger)
|
monitor.StartRetentionLoop(monitorRetention, log.Logger)
|
||||||
|
|
||||||
|
hitlRetention := hitl.NewService(db, cfg, log.Logger)
|
||||||
|
hitlRetention.PurgeExpired()
|
||||||
|
hitl.StartRetentionLoop(hitlRetention, log.Logger)
|
||||||
|
|
||||||
// 创建MCP服务器(带数据库持久化)
|
// 创建MCP服务器(带数据库持久化)
|
||||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||||
|
|
||||||
// 创建安全工具执行器
|
// 创建安全工具执行器
|
||||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||||
|
executor.SetShellNoOutputTimeoutSeconds(cfg.Agent.ShellNoOutputTimeoutSeconds)
|
||||||
|
|
||||||
// 注册工具
|
// 注册工具
|
||||||
executor.RegisterTools(mcpServer)
|
executor.RegisterTools(mcpServer)
|
||||||
@@ -134,6 +145,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
externalMCPMgr.StartAllEnabled()
|
externalMCPMgr.StartAllEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
execReconciler := monitor.NewExecutionReconciler(db, mcpServer, externalMCPMgr, log.Logger)
|
||||||
|
execReconciler.ReconcileOnStartup()
|
||||||
|
monitor.StartStaleRunningReconcileLoop(execReconciler, log.Logger)
|
||||||
|
|
||||||
// 创建Agent
|
// 创建Agent
|
||||||
maxIterations := cfg.Agent.MaxIterations
|
maxIterations := cfg.Agent.MaxIterations
|
||||||
if maxIterations <= 0 {
|
if maxIterations <= 0 {
|
||||||
@@ -304,7 +319,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
||||||
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
||||||
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
|
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
|
||||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot)
|
workspaceRoot := strings.TrimSpace(cfg.Agent.WorkspaceRootDir)
|
||||||
|
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot, workspaceRoot)
|
||||||
agent.SetPromptBaseDir(configDir)
|
agent.SetPromptBaseDir(configDir)
|
||||||
|
|
||||||
agentsDir := cfg.AgentsDir
|
agentsDir := cfg.AgentsDir
|
||||||
@@ -333,6 +349,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
monitorHandler.SetAudit(auditSvc)
|
monitorHandler.SetAudit(auditSvc)
|
||||||
monitorHandler.SetMonitorRetention(monitorRetention)
|
monitorHandler.SetMonitorRetention(monitorRetention)
|
||||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||||
|
monitorHandler.SetTaskManager(agentHandler.TaskManager())
|
||||||
|
monitorHandler.SetAgentHandler(agentHandler)
|
||||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||||
@@ -350,6 +368,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||||
configHandler.SetAudit(auditSvc)
|
configHandler.SetAudit(auditSvc)
|
||||||
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
||||||
|
agentHandler.SetHitlAuditStrategySaver(configHandler)
|
||||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||||
externalMCPHandler.SetAudit(auditSvc)
|
externalMCPHandler.SetAudit(auditSvc)
|
||||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||||
@@ -799,11 +818,18 @@ func setupRoutes(
|
|||||||
protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop)
|
protected.POST("/eino-agent", agentHandler.EinoSingleAgentLoop)
|
||||||
protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream)
|
protected.POST("/eino-agent/stream", agentHandler.EinoSingleAgentLoopStream)
|
||||||
protected.GET("/hitl/pending", agentHandler.ListHITLPending)
|
protected.GET("/hitl/pending", agentHandler.ListHITLPending)
|
||||||
|
protected.GET("/hitl/logs", agentHandler.ListHITLLogs)
|
||||||
|
protected.DELETE("/hitl/logs", agentHandler.DeleteHITLLogs)
|
||||||
|
protected.GET("/hitl/logs/:id", agentHandler.GetHITLLog)
|
||||||
protected.POST("/hitl/decision", agentHandler.DecideHITLInterrupt)
|
protected.POST("/hitl/decision", agentHandler.DecideHITLInterrupt)
|
||||||
protected.POST("/hitl/dismiss", agentHandler.DismissHITLInterrupt)
|
protected.POST("/hitl/dismiss", agentHandler.DismissHITLInterrupt)
|
||||||
protected.GET("/hitl/config/:conversationId", agentHandler.GetHITLConversationConfig)
|
protected.GET("/hitl/config/:conversationId", agentHandler.GetHITLConversationConfig)
|
||||||
protected.PUT("/hitl/config", agentHandler.UpsertHITLConversationConfig)
|
protected.PUT("/hitl/config", agentHandler.UpsertHITLConversationConfig)
|
||||||
|
protected.GET("/hitl/tool-whitelist", agentHandler.GetHITLGlobalToolWhitelist)
|
||||||
|
protected.PUT("/hitl/tool-whitelist", agentHandler.SetHITLGlobalToolWhitelist)
|
||||||
protected.POST("/hitl/tool-whitelist", agentHandler.MergeHITLGlobalToolWhitelist)
|
protected.POST("/hitl/tool-whitelist", agentHandler.MergeHITLGlobalToolWhitelist)
|
||||||
|
protected.GET("/hitl/audit-strategy", agentHandler.GetHITLAuditStrategy)
|
||||||
|
protected.PUT("/hitl/audit-strategy", agentHandler.UpdateHITLAuditStrategy)
|
||||||
// Agent Loop 取消与任务列表
|
// Agent Loop 取消与任务列表
|
||||||
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
|
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
|
||||||
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
|
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
|
||||||
|
|||||||
+204
-30
@@ -96,9 +96,12 @@ type MultiAgentConfig struct {
|
|||||||
// OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。
|
// OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。
|
||||||
OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"`
|
OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"`
|
||||||
SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"`
|
SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"`
|
||||||
// SubAgentUserContextMaxRunes caps the user-context supplement appended to task descriptions for sub-agents.
|
// SubAgentUserContextMaxRunes caps user-context supplement for sub-agent task descriptions.
|
||||||
// 0 (default) uses the built-in default of 2000 runes; negative value disables injection entirely.
|
// 0 (default) preserves all user turns verbatim; >0 caps total runes; negative disables injection.
|
||||||
SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"`
|
SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"`
|
||||||
|
// UserVerbatimAnchorMaxRunes injects all user turns verbatim into system prompt (survives summarization refresh).
|
||||||
|
// 0 (default) = no cap; >0 = total rune cap; negative disables anchor injection.
|
||||||
|
UserVerbatimAnchorMaxRunes int `yaml:"user_verbatim_anchor_max_runes,omitempty" json:"user_verbatim_anchor_max_runes,omitempty"`
|
||||||
// EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent.
|
// EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent.
|
||||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||||
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
||||||
@@ -107,6 +110,16 @@ type MultiAgentConfig struct {
|
|||||||
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserVerbatimAnchorMaxRunesEffective returns max runes for user verbatim anchor; 0 = unlimited; negative = disabled.
|
||||||
|
func (c MultiAgentConfig) UserVerbatimAnchorMaxRunesEffective() int {
|
||||||
|
return c.UserVerbatimAnchorMaxRunes
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubAgentUserContextMaxRunesEffective returns max runes for sub-agent task supplement; 0 = unlimited; negative = disabled.
|
||||||
|
func (c MultiAgentConfig) SubAgentUserContextMaxRunesEffective() int {
|
||||||
|
return c.SubAgentUserContextMaxRunes
|
||||||
|
}
|
||||||
|
|
||||||
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
|
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
|
||||||
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
|
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
|
||||||
type MultiAgentEinoCallbacksConfig struct {
|
type MultiAgentEinoCallbacksConfig struct {
|
||||||
@@ -270,6 +283,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||||
|
// EmptyResponseContinueMaxAttempts Run 成功但未捕获助手正文时 Handler 层退避续跑次数;0=默认 5。
|
||||||
|
EmptyResponseContinueMaxAttempts int `yaml:"empty_response_continue_max_attempts,omitempty" json:"empty_response_continue_max_attempts,omitempty"`
|
||||||
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
||||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -490,6 +505,17 @@ type RobotWecomConfig struct {
|
|||||||
AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId
|
AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateWecomConfig 校验企业微信机器人配置;启用时必须配置 token,否则回调无法防伪造。
|
||||||
|
func ValidateWecomConfig(w RobotWecomConfig) error {
|
||||||
|
if !w.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(w.Token) == "" {
|
||||||
|
return fmt.Errorf("robots.wecom.enabled 为 true 时必须配置 robots.wecom.token")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RobotDingtalkConfig 钉钉机器人配置
|
// RobotDingtalkConfig 钉钉机器人配置
|
||||||
type RobotDingtalkConfig struct {
|
type RobotDingtalkConfig struct {
|
||||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
@@ -605,15 +631,109 @@ type DatabaseConfig struct {
|
|||||||
type AgentConfig struct {
|
type AgentConfig struct {
|
||||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||||
|
// ShellNoOutputTimeoutSeconds execute/exec 无任何 stdout/stderr 时的空闲终止秒数(通用防挂死,不维护命令黑名单);0=默认 300(5 分钟);-1=关闭。
|
||||||
|
ShellNoOutputTimeoutSeconds int `yaml:"shell_no_output_timeout_seconds" json:"shell_no_output_timeout_seconds"`
|
||||||
|
// WorkspaceRootDir 会话工作目录根路径(curl/wget 下载、read_file/glob/grep 本地分析);空=tmp/workspace,其下按 projects/{id} 或 conversations/{id} 隔离。
|
||||||
|
WorkspaceRootDir string `yaml:"workspace_root_dir,omitempty" json:"workspace_root_dir,omitempty"`
|
||||||
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
||||||
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// HitlConfig 人机协同全局选项;与会话侧栏/API 中的白名单合并为并集后参与判定。
|
// HitlConfig 人机协同全局选项;与会话侧栏/API 中的白名单合并为并集后参与判定。
|
||||||
// tool_whitelist 可在侧栏「应用」时合并写入 config.yaml 并立即生效;其他字段若仅改文件仍需重启。
|
// tool_whitelist 可在侧栏「应用」时合并写入 config.yaml 并立即生效。
|
||||||
|
// audit_agent_prompt / audit_agent_prompt_review_edit 可在人机协同页编辑并立即生效;空则使用内置默认。
|
||||||
type HitlConfig struct {
|
type HitlConfig struct {
|
||||||
// ToolWhitelist 全局免审批工具名(与每条会话配置的 sensitiveTools 语义相同:白名单内工具不触发 HITL)。
|
// ToolWhitelist 全局免审批工具名(与白名单内工具不触发 HITL 审批)。
|
||||||
ToolWhitelist []string `yaml:"tool_whitelist,omitempty" json:"tool_whitelist,omitempty"`
|
ToolWhitelist []string `yaml:"tool_whitelist,omitempty" json:"tool_whitelist,omitempty"`
|
||||||
|
// AuditAgentPrompt 审批模式(approval)下审计 Agent 系统提示词。
|
||||||
|
AuditAgentPrompt string `yaml:"audit_agent_prompt,omitempty" json:"audit_agent_prompt,omitempty"`
|
||||||
|
// AuditAgentPromptReviewEdit 审查编辑模式(review_edit)下审计 Agent 系统提示词。
|
||||||
|
AuditAgentPromptReviewEdit string `yaml:"audit_agent_prompt_review_edit,omitempty" json:"audit_agent_prompt_review_edit,omitempty"`
|
||||||
|
// RetentionDays 已决策审计日志(hitl_interrupts 非 pending)保留天数;省略时默认 90;0 表示不自动清理。
|
||||||
|
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
|
||||||
|
func (h HitlConfig) RetentionDaysEffective() int {
|
||||||
|
if h.RetentionDays == nil {
|
||||||
|
return 90
|
||||||
|
}
|
||||||
|
if *h.RetentionDays < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *h.RetentionDays
|
||||||
|
}
|
||||||
|
|
||||||
|
const hitlAuditAgentPromptBase = `你是 CyberStrikeAI 人机协同审计 Agent。审查 Agent 即将执行的工具调用是否会对系统造成实质性损害。
|
||||||
|
|
||||||
|
你会收到 JSON,包含 hitlMode、toolName、arguments/argumentsObj、userMessage、thinking、reasoningChain、planning 等字段。
|
||||||
|
|
||||||
|
裁决基调(默认放行):
|
||||||
|
- 常规、低风险的渗透测试操作 → approve(如信息收集、端口/服务扫描、目录枚举、只读查询、无害探测命令)
|
||||||
|
- 与用户授权、当前任务目标一致,且未见明确高危迹象 → approve
|
||||||
|
- 仅在「可能对系统造成实质影响」时 → reject
|
||||||
|
|
||||||
|
必须 reject 的高危情形(示例,非穷举):
|
||||||
|
- 删库、清表、批量删除数据、格式化磁盘、不可逆破坏
|
||||||
|
- 修改/重置密码、创建或篡改管理员账号、持久化后门、开机自启
|
||||||
|
- 向生产环境写入恶意载荷、勒索加密、停止关键服务、修改系统核心配置
|
||||||
|
- 明显越权:与任务/授权目标无关的破坏性操作
|
||||||
|
|
||||||
|
不应单独作为 reject 理由的情形:
|
||||||
|
- 常规 nmap/curl/grep/读文件/枚举类命令本身
|
||||||
|
- 参数略显宽泛但无明确破坏意图(审查编辑模式可收窄参数后 approve)
|
||||||
|
- 仅因「信息不足」——若无上述高危迹象,应 approve 并可在 comment 中提示注意点`
|
||||||
|
|
||||||
|
const hitlAuditAgentPromptApprovalOutput = `
|
||||||
|
仅输出一行 JSON,不要 markdown 代码块:
|
||||||
|
{"decision":"approve"|"reject","comment":"简要理由"}`
|
||||||
|
|
||||||
|
const hitlAuditAgentPromptReviewEditOutput = `
|
||||||
|
仅输出一行 JSON,不要 markdown 代码块:
|
||||||
|
{"decision":"approve"|"reject","comment":"简要理由","editedArguments":{...}}
|
||||||
|
|
||||||
|
editedArguments 规则(仅 approve 且需要改参时填写,否则省略该字段):
|
||||||
|
- 提供完整替换后的工具参数对象,键名与 argumentsObj 一致
|
||||||
|
- 只做最小必要修改以收窄范围、消除风险(如限制 path、去掉危险 flag)
|
||||||
|
- 禁止扩大攻击面:不得扩大目标范围、提升权限或引入破坏性参数
|
||||||
|
- 无法安全改参时应 reject,不要勉强 approve`
|
||||||
|
|
||||||
|
// DefaultHitlAuditAgentPrompt 内置审批模式审计 Agent 提示词。
|
||||||
|
func DefaultHitlAuditAgentPrompt() string {
|
||||||
|
return hitlAuditAgentPromptBase + hitlAuditAgentPromptApprovalOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultHitlAuditAgentPromptReviewEdit 内置审查编辑模式审计 Agent 提示词。
|
||||||
|
func DefaultHitlAuditAgentPromptReviewEdit() string {
|
||||||
|
return hitlAuditAgentPromptBase + hitlAuditAgentPromptReviewEditOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
// EffectiveAuditAgentPrompt 返回审批模式生效的审计 Agent 提示词。
|
||||||
|
func (c HitlConfig) EffectiveAuditAgentPrompt() string {
|
||||||
|
return c.EffectiveAuditAgentPromptForMode("approval")
|
||||||
|
}
|
||||||
|
|
||||||
|
// EffectiveAuditAgentPromptForMode 按 HITL 模式返回生效的审计 Agent 提示词。
|
||||||
|
func (c HitlConfig) EffectiveAuditAgentPromptForMode(mode string) string {
|
||||||
|
if normalizeHitlModeForPrompt(mode) == "review_edit" {
|
||||||
|
if s := strings.TrimSpace(c.AuditAgentPromptReviewEdit); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return DefaultHitlAuditAgentPromptReviewEdit()
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(c.AuditAgentPrompt); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return DefaultHitlAuditAgentPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeHitlModeForPrompt(mode string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||||
|
case "review_edit":
|
||||||
|
return "review_edit"
|
||||||
|
default:
|
||||||
|
return "approval"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
@@ -800,33 +920,13 @@ func Load(path string) (*Config, error) {
|
|||||||
|
|
||||||
// 如果配置了工具目录,从目录加载工具配置
|
// 如果配置了工具目录,从目录加载工具配置
|
||||||
if cfg.Security.ToolsDir != "" {
|
if cfg.Security.ToolsDir != "" {
|
||||||
configDir := filepath.Dir(path)
|
inlineTools := append([]ToolConfig(nil), cfg.Security.Tools...)
|
||||||
toolsDir := cfg.Security.ToolsDir
|
toolsDir := ResolveToolsDir(cfg.Security.ToolsDir, path)
|
||||||
|
merged, err := MergeToolsFromDir(toolsDir, inlineTools)
|
||||||
// 如果是相对路径,相对于配置文件所在目录
|
|
||||||
if !filepath.IsAbs(toolsDir) {
|
|
||||||
toolsDir = filepath.Join(configDir, toolsDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
tools, err := LoadToolsFromDir(toolsDir)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
|
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
|
||||||
}
|
}
|
||||||
|
cfg.Security.Tools = merged
|
||||||
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
|
|
||||||
existingTools := make(map[string]bool)
|
|
||||||
for _, tool := range tools {
|
|
||||||
existingTools[tool.Name] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加主配置中不存在于目录中的工具(向后兼容)
|
|
||||||
for _, tool := range cfg.Security.Tools {
|
|
||||||
if !existingTools[tool.Name] {
|
|
||||||
tools = append(tools, tool)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.Security.Tools = tools
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 外部 MCP:迁移 + 环境变量展开
|
// 外部 MCP:迁移 + 环境变量展开
|
||||||
@@ -870,6 +970,10 @@ func Load(path string) (*Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := ValidateWecomConfig(cfg.Robots.Wecom); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1094,6 +1198,75 @@ func PrintMCPConfigJSON(mcp MCPConfig) {
|
|||||||
fmt.Println("----------------------------------------------------------------")
|
fmt.Println("----------------------------------------------------------------")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveToolsDir 将 tools_dir 解析为绝对路径(相对路径相对于 configPath 所在目录)。
|
||||||
|
func ResolveToolsDir(toolsDir, configPath string) string {
|
||||||
|
toolsDir = strings.TrimSpace(toolsDir)
|
||||||
|
if toolsDir == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if filepath.IsAbs(toolsDir) {
|
||||||
|
return toolsDir
|
||||||
|
}
|
||||||
|
return filepath.Join(filepath.Dir(configPath), toolsDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeToolsFromDir 从目录加载工具并与 inline 列表合并:目录中的工具优先,主配置中的工具作为补充。
|
||||||
|
func MergeToolsFromDir(toolsDir string, inlineTools []ToolConfig) ([]ToolConfig, error) {
|
||||||
|
dirTools, err := LoadToolsFromDir(toolsDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
existing := make(map[string]bool, len(dirTools))
|
||||||
|
for _, tool := range dirTools {
|
||||||
|
existing[tool.Name] = true
|
||||||
|
}
|
||||||
|
merged := append([]ToolConfig(nil), dirTools...)
|
||||||
|
for _, tool := range inlineTools {
|
||||||
|
if !existing[tool.Name] {
|
||||||
|
merged = append(merged, tool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return merged, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadInlineSecurityToolsFromYAML 读取 config.yaml 中 security.tools(不含 tools_dir 扫描结果)。
|
||||||
|
func loadInlineSecurityToolsFromYAML(configPath string) ([]ToolConfig, error) {
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||||
|
}
|
||||||
|
var partial struct {
|
||||||
|
Security struct {
|
||||||
|
Tools []ToolConfig `yaml:"tools"`
|
||||||
|
} `yaml:"security"`
|
||||||
|
}
|
||||||
|
if err := yaml.Unmarshal(data, &partial); err != nil {
|
||||||
|
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||||
|
}
|
||||||
|
if partial.Security.Tools == nil {
|
||||||
|
return []ToolConfig{}, nil
|
||||||
|
}
|
||||||
|
return partial.Security.Tools, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadSecurityToolsFromDir 从 tools_dir 重新加载工具并更新 cfg.Security.Tools(ApplyConfig 热重载用)。
|
||||||
|
func ReloadSecurityToolsFromDir(cfg *Config, configPath string) error {
|
||||||
|
if cfg == nil || strings.TrimSpace(cfg.Security.ToolsDir) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
inlineTools, err := loadInlineSecurityToolsFromYAML(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
toolsDir := ResolveToolsDir(cfg.Security.ToolsDir, configPath)
|
||||||
|
merged, err := MergeToolsFromDir(toolsDir, inlineTools)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("从工具目录加载工具配置失败: %w", err)
|
||||||
|
}
|
||||||
|
cfg.Security.Tools = merged
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadToolsFromDir 从目录加载所有工具配置文件
|
// LoadToolsFromDir 从目录加载所有工具配置文件
|
||||||
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
|
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
|
||||||
var tools []ToolConfig
|
var tools []ToolConfig
|
||||||
@@ -1270,8 +1443,9 @@ func Default() *Config {
|
|||||||
MaxTotalTokens: 120000,
|
MaxTotalTokens: 120000,
|
||||||
},
|
},
|
||||||
Agent: AgentConfig{
|
Agent: AgentConfig{
|
||||||
MaxIterations: 30, // 默认最大迭代次数
|
MaxIterations: 30, // 默认最大迭代次数
|
||||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||||
|
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
|
||||||
},
|
},
|
||||||
Security: SecurityConfig{
|
Security: SecurityConfig{
|
||||||
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestValidateWecomConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg RobotWecomConfig
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "disabled without token",
|
||||||
|
cfg: RobotWecomConfig{Enabled: false, Token: ""},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled with token",
|
||||||
|
cfg: RobotWecomConfig{Enabled: true, Token: "secret"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled without token",
|
||||||
|
cfg: RobotWecomConfig{Enabled: true, Token: ""},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled with whitespace token",
|
||||||
|
cfg: RobotWecomConfig{Enabled: true, Token: " "},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
err := ValidateWecomConfig(tt.cfg)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("ValidateWecomConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReloadSecurityToolsFromDir(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
toolsDir := filepath.Join(root, "tools")
|
||||||
|
if err := os.MkdirAll(toolsDir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(root, "config.yaml")
|
||||||
|
if err := os.WriteFile(configPath, []byte(`security:
|
||||||
|
tools_dir: tools
|
||||||
|
tools:
|
||||||
|
- name: inline-only
|
||||||
|
command: inline-cmd
|
||||||
|
enabled: true
|
||||||
|
description: inline tool
|
||||||
|
`), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeTool := func(name, command string) {
|
||||||
|
t.Helper()
|
||||||
|
content := "name: " + name + "\ncommand: " + command + "\nenabled: true\ndescription: test\n"
|
||||||
|
if err := os.WriteFile(filepath.Join(toolsDir, name+".yaml"), []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
writeTool("alpha", "alpha-cmd")
|
||||||
|
|
||||||
|
cfg := &Config{
|
||||||
|
Security: SecurityConfig{
|
||||||
|
ToolsDir: "tools",
|
||||||
|
Tools: []ToolConfig{
|
||||||
|
{Name: "stale", Command: "stale-cmd", Enabled: true, Description: "should be removed"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ReloadSecurityToolsFromDir(cfg, configPath); err != nil {
|
||||||
|
t.Fatalf("reload: %v", err)
|
||||||
|
}
|
||||||
|
if len(cfg.Security.Tools) != 2 {
|
||||||
|
t.Fatalf("expected 2 tools, got %d", len(cfg.Security.Tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
names := map[string]string{}
|
||||||
|
for _, tool := range cfg.Security.Tools {
|
||||||
|
names[tool.Name] = tool.Command
|
||||||
|
}
|
||||||
|
if names["alpha"] != "alpha-cmd" {
|
||||||
|
t.Fatalf("alpha tool missing or wrong command: %#v", names)
|
||||||
|
}
|
||||||
|
if names["inline-only"] != "inline-cmd" {
|
||||||
|
t.Fatalf("inline-only tool missing: %#v", names)
|
||||||
|
}
|
||||||
|
if _, ok := names["stale"]; ok {
|
||||||
|
t.Fatal("stale in-memory tool should not survive reload")
|
||||||
|
}
|
||||||
|
|
||||||
|
writeTool("beta", "beta-cmd")
|
||||||
|
if err := ReloadSecurityToolsFromDir(cfg, configPath); err != nil {
|
||||||
|
t.Fatalf("second reload: %v", err)
|
||||||
|
}
|
||||||
|
if len(cfg.Security.Tools) != 3 {
|
||||||
|
t.Fatalf("expected 3 tools after add, got %d", len(cfg.Security.Tools))
|
||||||
|
}
|
||||||
|
foundBeta := false
|
||||||
|
for _, tool := range cfg.Security.Tools {
|
||||||
|
if tool.Name == "beta" {
|
||||||
|
foundBeta = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundBeta {
|
||||||
|
t.Fatal("beta tool not found after second reload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeToolsFromDir_DirOverridesInline(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
toolsDir := filepath.Join(root, "tools")
|
||||||
|
if err := os.MkdirAll(toolsDir, 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
content := "name: shared\ncommand: dir-cmd\nenabled: true\ndescription: from dir\n"
|
||||||
|
if err := os.WriteFile(filepath.Join(toolsDir, "shared.yaml"), []byte(content), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline := []ToolConfig{
|
||||||
|
{Name: "shared", Command: "inline-cmd", Enabled: true, Description: "from inline"},
|
||||||
|
}
|
||||||
|
merged, err := MergeToolsFromDir(toolsDir, inline)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(merged) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool, got %d", len(merged))
|
||||||
|
}
|
||||||
|
if merged[0].Command != "dir-cmd" {
|
||||||
|
t.Fatalf("dir tool should win, got command %q", merged[0].Command)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
|
|||||||
LastScheduleError sql.NullString
|
LastScheduleError sql.NullString
|
||||||
LastRunError sql.NullString
|
LastRunError sql.NullString
|
||||||
ProjectID sql.NullString
|
ProjectID sql.NullString
|
||||||
|
Concurrency sql.NullInt64
|
||||||
Status string
|
Status string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
StartedAt sql.NullTime
|
StartedAt sql.NullTime
|
||||||
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
|
|||||||
cronExpr string,
|
cronExpr string,
|
||||||
nextRunAt *time.Time,
|
nextRunAt *time.Time,
|
||||||
projectID string,
|
projectID string,
|
||||||
|
concurrency int,
|
||||||
tasks []map[string]interface{},
|
tasks []map[string]interface{},
|
||||||
) error {
|
) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
|
|||||||
projectIDVal = strings.TrimSpace(projectID)
|
projectIDVal = strings.TrimSpace(projectID)
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(
|
_, err = tx.Exec(
|
||||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||||
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
|
|||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
|
||||||
|
|
||||||
// GetBatchQueue 获取批量任务队列
|
// GetBatchQueue 获取批量任务队列
|
||||||
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
err := db.QueryRow(
|
err := db.QueryRow(
|
||||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
|
||||||
queueID,
|
queueID,
|
||||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
|||||||
// GetAllBatchQueues 获取所有批量任务队列
|
// GetAllBatchQueues 获取所有批量任务队列
|
||||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||||
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||||
}
|
}
|
||||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
|||||||
|
|
||||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
|
||||||
// 状态筛选
|
// 状态筛选
|
||||||
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||||
}
|
}
|
||||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
|
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色、代理模式和并发数
|
||||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
|
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
|
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
|
||||||
title, role, agentMode, queueID,
|
title, role, agentMode, concurrency, queueID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package database
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -13,6 +14,9 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ProjectFilterUnbound 列表 API 中 project_id=__none__ 表示仅未绑定项目的对话。
|
||||||
|
const ProjectFilterUnbound = "__none__"
|
||||||
|
|
||||||
// Conversation 对话
|
// Conversation 对话
|
||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
@@ -361,20 +365,44 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
|||||||
return &conv, nil
|
return &conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func conversationProjectIDColumn(alias string) string {
|
||||||
|
if alias != "" {
|
||||||
|
return alias + ".project_id"
|
||||||
|
}
|
||||||
|
return "project_id"
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendConversationProjectFilter(where string, args []interface{}, projectID, alias string) (string, []interface{}) {
|
||||||
|
pid := strings.TrimSpace(projectID)
|
||||||
|
if pid == "" {
|
||||||
|
return where, args
|
||||||
|
}
|
||||||
|
col := conversationProjectIDColumn(alias)
|
||||||
|
if pid == ProjectFilterUnbound {
|
||||||
|
return where + fmt.Sprintf(" AND (%s IS NULL OR TRIM(COALESCE(%s, '')) = '')", col, col), args
|
||||||
|
}
|
||||||
|
return where + fmt.Sprintf(" AND %s = ?", col), append(args, pid)
|
||||||
|
}
|
||||||
|
|
||||||
// CountConversations 统计对话数量。
|
// CountConversations 统计对话数量。
|
||||||
func (db *DB) CountConversations(search string) (int, error) {
|
func (db *DB) CountConversations(search, projectID string) (int, error) {
|
||||||
var count int
|
var count int
|
||||||
var err error
|
var err error
|
||||||
if search != "" {
|
if search != "" {
|
||||||
searchPattern := "%" + search + "%"
|
searchPattern := "%" + search + "%"
|
||||||
err = db.QueryRow(
|
where := ` WHERE (c.title LIKE ?
|
||||||
`SELECT COUNT(*) FROM conversations c
|
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))`
|
||||||
WHERE c.title LIKE ?
|
args := []interface{}{searchPattern, searchPattern}
|
||||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)`,
|
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||||
searchPattern, searchPattern,
|
err = db.QueryRow(`SELECT COUNT(*) FROM conversations c`+where, args...).Scan(&count)
|
||||||
).Scan(&count)
|
|
||||||
} else {
|
} else {
|
||||||
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count)
|
where := ""
|
||||||
|
args := []interface{}{}
|
||||||
|
where, args = appendConversationProjectFilter(where, args, projectID, "")
|
||||||
|
if where != "" {
|
||||||
|
where = " WHERE" + strings.TrimPrefix(where, " AND")
|
||||||
|
}
|
||||||
|
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`+where, args...).Scan(&count)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("统计对话失败: %w", err)
|
return 0, fmt.Errorf("统计对话失败: %w", err)
|
||||||
@@ -395,7 +423,7 @@ func conversationOrderClause(sortBy, tableAlias string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListConversations 列出所有对话
|
// ListConversations 列出所有对话
|
||||||
func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Conversation, error) {
|
func (db *DB) ListConversations(limit, offset int, search, sortBy, projectID string) ([]*Conversation, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -403,20 +431,30 @@ func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Co
|
|||||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||||
searchPattern := "%" + search + "%"
|
searchPattern := "%" + search + "%"
|
||||||
orderClause := conversationOrderClause(sortBy, "c")
|
orderClause := conversationOrderClause(sortBy, "c")
|
||||||
|
where := ` WHERE (c.title LIKE ?
|
||||||
|
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))`
|
||||||
|
args := []interface{}{searchPattern, searchPattern}
|
||||||
|
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||||
|
args = append(args, limit, offset)
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
||||||
FROM conversations c
|
FROM conversations c`+where+`
|
||||||
WHERE c.title LIKE ?
|
|
||||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
|
||||||
`+orderClause+`
|
`+orderClause+`
|
||||||
LIMIT ? OFFSET ?`,
|
LIMIT ? OFFSET ?`,
|
||||||
searchPattern, searchPattern, limit, offset,
|
args...,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
orderClause := conversationOrderClause(sortBy, "")
|
orderClause := conversationOrderClause(sortBy, "")
|
||||||
|
where := ""
|
||||||
|
args := []interface{}{}
|
||||||
|
where, args = appendConversationProjectFilter(where, args, projectID, "")
|
||||||
|
if where != "" {
|
||||||
|
where = " WHERE" + strings.TrimPrefix(where, " AND")
|
||||||
|
}
|
||||||
|
args = append(args, limit, offset)
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations "+orderClause+" LIMIT ? OFFSET ?",
|
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations"+where+" "+orderClause+" LIMIT ? OFFSET ?",
|
||||||
limit, offset,
|
args...,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,23 +510,30 @@ const ungroupedConversationsSQL = `
|
|||||||
)`
|
)`
|
||||||
|
|
||||||
// CountUngroupedConversations 统计不在任何分组中的对话数量。
|
// CountUngroupedConversations 统计不在任何分组中的对话数量。
|
||||||
func (db *DB) CountUngroupedConversations() (int, error) {
|
func (db *DB) CountUngroupedConversations(projectID string) (int, error) {
|
||||||
|
where := ungroupedConversationsSQL
|
||||||
|
args := []interface{}{}
|
||||||
|
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||||
var count int
|
var count int
|
||||||
if err := db.QueryRow(`SELECT COUNT(*) ` + ungroupedConversationsSQL).Scan(&count); err != nil {
|
if err := db.QueryRow(`SELECT COUNT(*) `+where, args...).Scan(&count); err != nil {
|
||||||
return 0, fmt.Errorf("统计未分组对话失败: %w", err)
|
return 0, fmt.Errorf("统计未分组对话失败: %w", err)
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
|
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
|
||||||
func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*Conversation, error) {
|
func (db *DB) ListUngroupedConversations(limit, offset int, sortBy, projectID string) ([]*Conversation, error) {
|
||||||
orderClause := conversationOrderClause(sortBy, "c")
|
orderClause := conversationOrderClause(sortBy, "c")
|
||||||
|
where := ungroupedConversationsSQL
|
||||||
|
args := []interface{}{}
|
||||||
|
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||||
|
args = append(args, limit, offset)
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
|
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
|
||||||
ungroupedConversationsSQL+`
|
where+`
|
||||||
`+orderClause+`
|
`+orderClause+`
|
||||||
LIMIT ? OFFSET ?`,
|
LIMIT ? OFFSET ?`,
|
||||||
limit, offset,
|
args...,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询未分组对话失败: %w", err)
|
return nil, fmt.Errorf("查询未分组对话失败: %w", err)
|
||||||
@@ -533,6 +578,19 @@ func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*C
|
|||||||
return conversations, rows.Err()
|
return conversations, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConversationTitle 获取对话标题(轻量查询,不加载消息)
|
||||||
|
func (db *DB) GetConversationTitle(id string) (string, error) {
|
||||||
|
var title string
|
||||||
|
err := db.QueryRow("SELECT title FROM conversations WHERE id = ?", id).Scan(&title)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return "", fmt.Errorf("对话不存在")
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("查询对话标题失败: %w", err)
|
||||||
|
}
|
||||||
|
return title, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateConversationTitle 更新对话标题
|
// UpdateConversationTitle 更新对话标题
|
||||||
func (db *DB) UpdateConversationTitle(id, title string) error {
|
func (db *DB) UpdateConversationTitle(id, title string) error {
|
||||||
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
|
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
|
||||||
@@ -640,6 +698,16 @@ func (db *DB) einoReductionBaseDir() string {
|
|||||||
return filepath.Join("tmp", "reduction")
|
return filepath.Join("tmp", "reduction")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) einoWorkspaceBaseDir() string {
|
||||||
|
if db == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if base := strings.TrimSpace(db.einoWorkspaceRootDir); base != "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return filepath.Join("tmp", "workspace")
|
||||||
|
}
|
||||||
|
|
||||||
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
||||||
// summarization transcript, etc.
|
// summarization transcript, etc.
|
||||||
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
||||||
@@ -652,6 +720,8 @@ func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
|||||||
if strings.TrimSpace(projectID) == "" {
|
if strings.TrimSpace(projectID) == "" {
|
||||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
|
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
|
||||||
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
|
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
|
||||||
|
workspaceBase := filepath.Join(db.einoWorkspaceBaseDir(), "conversations")
|
||||||
|
db.removeConversationScopedDir(workspaceBase, conversationID, "workspace")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -659,6 +729,9 @@ func (db *DB) removeProjectScopedDirs(projectID string) {
|
|||||||
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
|
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
|
||||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
|
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
|
||||||
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
|
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
|
||||||
|
// Agent download/analysis workspace (tmp/workspace/projects/<id>/).
|
||||||
|
workspaceBase := filepath.Join(db.einoWorkspaceBaseDir(), "projects")
|
||||||
|
db.removeConversationScopedDir(workspaceBase, projectID, "workspace")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
||||||
@@ -998,6 +1071,77 @@ type ProcessDetail struct {
|
|||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetTurnUserMessage 返回锚点消息所在轮次中的用户原文(最近一条 user 消息,不含完整历史)。
|
||||||
|
func (db *DB) GetTurnUserMessage(conversationID, anchorMessageID string) (string, error) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
anchorMessageID = strings.TrimSpace(anchorMessageID)
|
||||||
|
if conversationID == "" || anchorMessageID == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
var content string
|
||||||
|
err := db.QueryRow(`
|
||||||
|
SELECT m.content FROM messages m
|
||||||
|
WHERE m.conversation_id = ? AND m.role = 'user'
|
||||||
|
AND m.created_at <= COALESCE((SELECT created_at FROM messages WHERE id = ? AND conversation_id = ?), m.created_at)
|
||||||
|
ORDER BY m.created_at DESC, m.rowid DESC
|
||||||
|
LIMIT 1`, conversationID, anchorMessageID, conversationID).Scan(&content)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("query turn user message: %w", err)
|
||||||
|
}
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssistantCognitionTexts 单条助手消息上的思考/推理/规划文本。
|
||||||
|
type AssistantCognitionTexts struct {
|
||||||
|
Thinking string
|
||||||
|
ReasoningChain string
|
||||||
|
Planning string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAssistantCognitionTexts 聚合助手消息在 process_details 中的 thinking / reasoning_chain / planning。
|
||||||
|
func (db *DB) GetAssistantCognitionTexts(assistantMessageID string) (AssistantCognitionTexts, error) {
|
||||||
|
assistantMessageID = strings.TrimSpace(assistantMessageID)
|
||||||
|
if assistantMessageID == "" {
|
||||||
|
return AssistantCognitionTexts{}, nil
|
||||||
|
}
|
||||||
|
rows, err := db.Query(`
|
||||||
|
SELECT event_type, message FROM process_details
|
||||||
|
WHERE message_id = ? AND event_type IN ('thinking', 'reasoning_chain', 'planning')
|
||||||
|
ORDER BY created_at ASC, rowid ASC`, assistantMessageID)
|
||||||
|
if err != nil {
|
||||||
|
return AssistantCognitionTexts{}, fmt.Errorf("query assistant cognition: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var thinkingParts, reasoningParts, planningParts []string
|
||||||
|
for rows.Next() {
|
||||||
|
var eventType, message string
|
||||||
|
if err := rows.Scan(&eventType, &message); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg := strings.TrimSpace(message)
|
||||||
|
if msg == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch eventType {
|
||||||
|
case "thinking":
|
||||||
|
thinkingParts = append(thinkingParts, msg)
|
||||||
|
case "reasoning_chain":
|
||||||
|
reasoningParts = append(reasoningParts, msg)
|
||||||
|
case "planning":
|
||||||
|
planningParts = append(planningParts, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return AssistantCognitionTexts{
|
||||||
|
Thinking: strings.Join(thinkingParts, "\n\n"),
|
||||||
|
ReasoningChain: strings.Join(reasoningParts, "\n\n"),
|
||||||
|
Planning: strings.Join(planningParts, "\n\n"),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddProcessDetail 添加过程详情事件
|
// AddProcessDetail 添加过程详情事件
|
||||||
func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error {
|
func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error {
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
|||||||
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
||||||
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
||||||
reductionBase := filepath.Join(tmp, "reduction")
|
reductionBase := filepath.Join(tmp, "reduction")
|
||||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase)
|
workspaceBase := filepath.Join(tmp, "workspace")
|
||||||
|
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase, workspaceBase)
|
||||||
|
|
||||||
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -36,6 +37,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
|||||||
{plantaskBase, "task-1.json"},
|
{plantaskBase, "task-1.json"},
|
||||||
{checkpointBase, "runner-deep.ckpt"},
|
{checkpointBase, "runner-deep.ckpt"},
|
||||||
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
|
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
|
||||||
|
{filepath.Join(workspaceBase, "conversations"), "page.html"},
|
||||||
} {
|
} {
|
||||||
dir := filepath.Join(base.root, seg)
|
dir := filepath.Join(base.root, seg)
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
@@ -50,7 +52,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
|||||||
t.Fatalf("DeleteConversation: %v", err)
|
t.Fatalf("DeleteConversation: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} {
|
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations"), filepath.Join(workspaceBase, "conversations")} {
|
||||||
dir := filepath.Join(base, seg)
|
dir := filepath.Join(base, seg)
|
||||||
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
||||||
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
||||||
@@ -68,7 +70,8 @@ func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
|||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
reductionBase := filepath.Join(tmp, "reduction")
|
reductionBase := filepath.Join(tmp, "reduction")
|
||||||
db.SetEinoConversationDirs("", "", reductionBase)
|
workspaceBase := filepath.Join(tmp, "workspace")
|
||||||
|
db.SetEinoConversationDirs("", "", reductionBase, workspaceBase)
|
||||||
|
|
||||||
project, err := db.CreateProject(&Project{Name: "cleanup test"})
|
project, err := db.CreateProject(&Project{Name: "cleanup test"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -82,6 +85,13 @@ func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
|||||||
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
|
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
|
||||||
t.Fatalf("write: %v", err)
|
t.Fatalf("write: %v", err)
|
||||||
}
|
}
|
||||||
|
workspaceDir := filepath.Join(workspaceBase, "projects", seg, "downloads")
|
||||||
|
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("mkdir %s: %v", workspaceDir, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(workspaceDir, "app.js"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write workspace: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := db.DeleteProject(project.ID); err != nil {
|
if err := db.DeleteProject(project.ID); err != nil {
|
||||||
t.Fatalf("DeleteProject: %v", err)
|
t.Fatalf("DeleteProject: %v", err)
|
||||||
@@ -91,4 +101,8 @@ func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
|||||||
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
|
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
|
||||||
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
|
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
|
||||||
}
|
}
|
||||||
|
projectWorkspaceDir := filepath.Join(workspaceBase, "projects", seg)
|
||||||
|
if _, statErr := os.Stat(projectWorkspaceDir); !os.IsNotExist(statErr) {
|
||||||
|
t.Fatalf("expected removed dir %s, stat err=%v", projectWorkspaceDir, statErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConversationProjectFilter(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmp, "conversations.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
p, err := db.CreateProject(&Project{Name: "target-a", Status: "active"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateProject: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
convNone, err := db.CreateConversation("unbound", ConversationCreateMeta{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConversation unbound: %v", err)
|
||||||
|
}
|
||||||
|
convBound, err := db.CreateConversation("bound", ConversationCreateMeta{ProjectID: p.ID})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConversation bound: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalAll, err := db.CountConversations("", "")
|
||||||
|
if err != nil || totalAll < 2 {
|
||||||
|
t.Fatalf("CountConversations all: total=%d err=%v", totalAll, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalBound, err := db.CountConversations("", p.ID)
|
||||||
|
if err != nil || totalBound != 1 {
|
||||||
|
t.Fatalf("CountConversations project: total=%d err=%v", totalBound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalUnbound, err := db.CountConversations("", ProjectFilterUnbound)
|
||||||
|
if err != nil || totalUnbound != 1 {
|
||||||
|
t.Fatalf("CountConversations unbound: total=%d err=%v", totalUnbound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listBound, err := db.ListConversations(10, 0, "", "", p.ID)
|
||||||
|
if err != nil || len(listBound) != 1 || listBound[0].ID != convBound.ID {
|
||||||
|
t.Fatalf("ListConversations project: %+v err=%v", listBound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listUnbound, err := db.ListConversations(10, 0, "", "", ProjectFilterUnbound)
|
||||||
|
if err != nil || len(listUnbound) != 1 || listUnbound[0].ID != convNone.ID {
|
||||||
|
t.Fatalf("ListConversations unbound: %+v err=%v", listUnbound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = convNone
|
||||||
|
_ = convBound
|
||||||
|
}
|
||||||
@@ -52,6 +52,7 @@ type DB struct {
|
|||||||
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
||||||
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
||||||
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
|
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
|
||||||
|
einoWorkspaceRootDir string // workspace_root_dir or default tmp/workspace (projects|conversations/<id> subdirs)
|
||||||
checkpointLoopName string
|
checkpointLoopName string
|
||||||
checkpointStop chan struct{}
|
checkpointStop chan struct{}
|
||||||
checkpointDone chan struct{}
|
checkpointDone chan struct{}
|
||||||
@@ -161,13 +162,15 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
|||||||
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
||||||
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
||||||
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
|
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
|
||||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) {
|
// workspaceRoot is agent.workspace_root_dir from config; empty uses tmp/workspace.
|
||||||
|
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot, workspaceRoot string) {
|
||||||
if db == nil {
|
if db == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
||||||
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
||||||
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
|
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
|
||||||
|
db.einoWorkspaceRootDir = strings.TrimSpace(workspaceRoot)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initTables 初始化数据库表
|
// initTables 初始化数据库表
|
||||||
@@ -408,6 +411,8 @@ func (db *DB) initTables() error {
|
|||||||
last_schedule_trigger_at DATETIME,
|
last_schedule_trigger_at DATETIME,
|
||||||
last_schedule_error TEXT,
|
last_schedule_error TEXT,
|
||||||
last_run_error TEXT,
|
last_run_error TEXT,
|
||||||
|
project_id TEXT,
|
||||||
|
concurrency INTEGER NOT NULL DEFAULT 1,
|
||||||
status TEXT NOT NULL,
|
status TEXT NOT NULL,
|
||||||
created_at DATETIME NOT NULL,
|
created_at DATETIME NOT NULL,
|
||||||
started_at DATETIME,
|
started_at DATETIME,
|
||||||
@@ -1137,6 +1142,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var concurrencyCount int
|
||||||
|
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
|
||||||
|
if err != nil {
|
||||||
|
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
|
||||||
|
errMsg := strings.ToLower(addErr.Error())
|
||||||
|
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||||
|
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if concurrencyCount == 0 {
|
||||||
|
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
|
||||||
|
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeleteHitlInterruptLogsByIDs deletes decided HITL audit logs by id (pending rows are skipped).
|
||||||
|
func (db *DB) DeleteHitlInterruptLogsByIDs(ids []string) (int64, error) {
|
||||||
|
if db == nil {
|
||||||
|
return 0, fmt.Errorf("database is nil")
|
||||||
|
}
|
||||||
|
clean := make([]string, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id != "" {
|
||||||
|
clean = append(clean, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(clean) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
placeholders := strings.TrimRight(strings.Repeat("?,", len(clean)), ",")
|
||||||
|
q := fmt.Sprintf(`DELETE FROM hitl_interrupts WHERE status != 'pending' AND id IN (%s)`, placeholders)
|
||||||
|
args := make([]interface{}, len(clean))
|
||||||
|
for i, id := range clean {
|
||||||
|
args[i] = id
|
||||||
|
}
|
||||||
|
res, err := db.Exec(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Error("批量删除人机协同审计日志失败", zap.Error(err), zap.Int("count", len(clean)))
|
||||||
|
return 0, fmt.Errorf("批量删除人机协同审计日志失败: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteHitlInterruptLogsMatching deletes decided logs matching whereSQL (e.g. "WHERE 1=1 AND status != 'pending' ...").
|
||||||
|
func (db *DB) DeleteHitlInterruptLogsMatching(whereSQL string, args []interface{}) (int64, error) {
|
||||||
|
if db == nil {
|
||||||
|
return 0, fmt.Errorf("database is nil")
|
||||||
|
}
|
||||||
|
whereSQL = strings.TrimSpace(whereSQL)
|
||||||
|
if whereSQL == "" {
|
||||||
|
return 0, fmt.Errorf("where clause is required")
|
||||||
|
}
|
||||||
|
q := `DELETE FROM hitl_interrupts ` + whereSQL
|
||||||
|
res, err := db.Exec(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Error("清空人机协同审计日志失败", zap.Error(err))
|
||||||
|
return 0, fmt.Errorf("清空人机协同审计日志失败: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeHitlInterruptLogsBefore deletes decided logs with decided/created time before cutoff.
|
||||||
|
func (db *DB) PurgeHitlInterruptLogsBefore(cutoff time.Time) (int64, error) {
|
||||||
|
if db == nil {
|
||||||
|
return 0, fmt.Errorf("database is nil")
|
||||||
|
}
|
||||||
|
res, err := db.Exec(
|
||||||
|
`DELETE FROM hitl_interrupts WHERE status != 'pending' AND datetime(COALESCE(decided_at, created_at)) < datetime(?)`,
|
||||||
|
cutoff.UTC().Format(time.RFC3339),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Error("清理过期人机协同审计日志失败", zap.Error(err))
|
||||||
|
return 0, fmt.Errorf("清理过期人机协同审计日志失败: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ensureHitlInterruptsTable(t *testing.T, db *DB) {
|
||||||
|
t.Helper()
|
||||||
|
if _, err := db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS hitl_interrupts (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
conversation_id TEXT NOT NULL,
|
||||||
|
message_id TEXT,
|
||||||
|
mode TEXT NOT NULL,
|
||||||
|
tool_name TEXT NOT NULL,
|
||||||
|
tool_call_id TEXT,
|
||||||
|
payload TEXT,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
decision TEXT,
|
||||||
|
decision_comment TEXT,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
decided_at DATETIME
|
||||||
|
);`); err != nil {
|
||||||
|
t.Fatalf("create hitl_interrupts: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteHitlInterruptLogsByIDs_skipsPending(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "hitl.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
ensureHitlInterruptsTable(t, db)
|
||||||
|
|
||||||
|
now := time.Now().UTC().Format(time.RFC3339)
|
||||||
|
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||||
|
(id, conversation_id, mode, tool_name, status, created_at)
|
||||||
|
VALUES ('pending-1', 'c1', 'approval', 'exec', 'pending', ?)`, now); err != nil {
|
||||||
|
t.Fatalf("insert pending: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||||
|
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
|
||||||
|
VALUES ('done-1', 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, now, now); err != nil {
|
||||||
|
t.Fatalf("insert decided: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted, err := db.DeleteHitlInterruptLogsByIDs([]string{"pending-1", "done-1"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeleteHitlInterruptLogsByIDs: %v", err)
|
||||||
|
}
|
||||||
|
if deleted != 1 {
|
||||||
|
t.Fatalf("deleted = %d, want 1", deleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
var status string
|
||||||
|
if err := db.QueryRow(`SELECT status FROM hitl_interrupts WHERE id = 'pending-1'`).Scan(&status); err != nil {
|
||||||
|
t.Fatalf("pending row missing: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'done-1'`).Scan(new(string)); err == nil {
|
||||||
|
t.Fatal("decided row should be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPurgeHitlInterruptLogsBefore(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "hitl.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
ensureHitlInterruptsTable(t, db)
|
||||||
|
|
||||||
|
old := time.Now().AddDate(0, 0, -100).UTC().Format(time.RFC3339)
|
||||||
|
recent := time.Now().AddDate(0, 0, -1).UTC().Format(time.RFC3339)
|
||||||
|
for _, row := range []struct{ id, decided string }{
|
||||||
|
{"old-1", old},
|
||||||
|
{"new-1", recent},
|
||||||
|
} {
|
||||||
|
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||||
|
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
|
||||||
|
VALUES (?, 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, row.id, row.decided, row.decided); err != nil {
|
||||||
|
t.Fatalf("insert %s: %v", row.id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -90)
|
||||||
|
deleted, err := db.PurgeHitlInterruptLogsBefore(cutoff)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PurgeHitlInterruptLogsBefore: %v", err)
|
||||||
|
}
|
||||||
|
if deleted != 1 {
|
||||||
|
t.Fatalf("deleted = %d, want 1", deleted)
|
||||||
|
}
|
||||||
|
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'old-1'`).Scan(new(string)); err == nil {
|
||||||
|
t.Fatal("old row should be purged")
|
||||||
|
}
|
||||||
|
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'new-1'`).Scan(new(string)); err != nil {
|
||||||
|
t.Fatalf("new row should remain: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
+288
-26
@@ -3,7 +3,6 @@ package database
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -227,6 +226,167 @@ func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolNa
|
|||||||
return executions, nil
|
return executions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toolExecutionsFilterSQL(status, toolName string) (string, []interface{}) {
|
||||||
|
args := []interface{}{}
|
||||||
|
conditions := []string{}
|
||||||
|
if status != "" {
|
||||||
|
conditions = append(conditions, "status = ?")
|
||||||
|
args = append(args, status)
|
||||||
|
}
|
||||||
|
if toolName != "" {
|
||||||
|
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
||||||
|
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
||||||
|
}
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return "", args
|
||||||
|
}
|
||||||
|
return ` WHERE ` + strings.Join(conditions, ` AND `), args
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolStatsSummary 工具调用汇总(全量聚合,不含逐工具明细)
|
||||||
|
type ToolStatsSummary struct {
|
||||||
|
TotalCalls int
|
||||||
|
SuccessCalls int
|
||||||
|
FailedCalls int
|
||||||
|
LastCallTime *time.Time
|
||||||
|
ToolCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolStatsSummaryResult 汇总 + Top N 工具排行
|
||||||
|
type ToolStatsSummaryResult struct {
|
||||||
|
Summary ToolStatsSummary
|
||||||
|
TopTools []*mcp.ToolStats
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadToolStatsSummary 聚合统计信息,仅返回汇总与 Top N 工具(避免全量 map 传输)
|
||||||
|
func (db *DB) LoadToolStatsSummary(topN int) (*ToolStatsSummaryResult, error) {
|
||||||
|
if topN <= 0 {
|
||||||
|
topN = 6
|
||||||
|
}
|
||||||
|
if topN > 100 {
|
||||||
|
topN = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &ToolStatsSummaryResult{
|
||||||
|
TopTools: make([]*mcp.ToolStats, 0, topN),
|
||||||
|
}
|
||||||
|
|
||||||
|
summaryQuery := `
|
||||||
|
SELECT COUNT(*),
|
||||||
|
COALESCE(SUM(total_calls), 0),
|
||||||
|
COALESCE(SUM(success_calls), 0),
|
||||||
|
COALESCE(SUM(failed_calls), 0),
|
||||||
|
MAX(last_call_time)
|
||||||
|
FROM tool_stats
|
||||||
|
`
|
||||||
|
var lastCallRaw sql.NullString
|
||||||
|
err := db.QueryRow(summaryQuery).Scan(
|
||||||
|
&result.Summary.ToolCount,
|
||||||
|
&result.Summary.TotalCalls,
|
||||||
|
&result.Summary.SuccessCalls,
|
||||||
|
&result.Summary.FailedCalls,
|
||||||
|
&lastCallRaw,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if lastCallRaw.Valid && strings.TrimSpace(lastCallRaw.String) != "" {
|
||||||
|
if t, parseErr := time.Parse(time.RFC3339Nano, lastCallRaw.String); parseErr == nil {
|
||||||
|
result.Summary.LastCallTime = &t
|
||||||
|
} else if t, parseErr := time.Parse("2006-01-02 15:04:05.999999999-07:00", lastCallRaw.String); parseErr == nil {
|
||||||
|
result.Summary.LastCallTime = &t
|
||||||
|
} else if t, parseErr := time.Parse("2006-01-02 15:04:05", lastCallRaw.String); parseErr == nil {
|
||||||
|
result.Summary.LastCallTime = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
topQuery := `
|
||||||
|
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
|
||||||
|
FROM tool_stats
|
||||||
|
WHERE total_calls > 0
|
||||||
|
ORDER BY total_calls DESC, tool_name ASC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
rows, err := db.Query(topQuery, topN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var stat mcp.ToolStats
|
||||||
|
var lastCallTime sql.NullTime
|
||||||
|
if err := rows.Scan(
|
||||||
|
&stat.ToolName,
|
||||||
|
&stat.TotalCalls,
|
||||||
|
&stat.SuccessCalls,
|
||||||
|
&stat.FailedCalls,
|
||||||
|
&lastCallTime,
|
||||||
|
); err != nil {
|
||||||
|
db.logger.Warn("加载 Top 工具统计失败", zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if lastCallTime.Valid {
|
||||||
|
stat.LastCallTime = &lastCallTime.Time
|
||||||
|
}
|
||||||
|
result.TopTools = append(result.TopTools, &stat)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadToolExecutionListPage 分页加载执行记录列表(不含 arguments/result,供监控列表使用)
|
||||||
|
func (db *DB) LoadToolExecutionListPage(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
if limit > 100 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT id, tool_name, status, start_time, end_time, duration_ms
|
||||||
|
FROM tool_executions
|
||||||
|
`
|
||||||
|
whereSQL, args := toolExecutionsFilterSQL(status, toolName)
|
||||||
|
query += whereSQL + ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
|
||||||
|
args = append(args, limit, offset)
|
||||||
|
|
||||||
|
rows, err := db.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
executions := make([]*mcp.ToolExecution, 0, limit)
|
||||||
|
for rows.Next() {
|
||||||
|
var exec mcp.ToolExecution
|
||||||
|
var endTime sql.NullTime
|
||||||
|
var durationMs sql.NullInt64
|
||||||
|
|
||||||
|
if err := rows.Scan(
|
||||||
|
&exec.ID,
|
||||||
|
&exec.ToolName,
|
||||||
|
&exec.Status,
|
||||||
|
&exec.StartTime,
|
||||||
|
&endTime,
|
||||||
|
&durationMs,
|
||||||
|
); err != nil {
|
||||||
|
db.logger.Warn("加载执行记录列表失败", zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if endTime.Valid {
|
||||||
|
exec.EndTime = &endTime.Time
|
||||||
|
}
|
||||||
|
if durationMs.Valid {
|
||||||
|
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||||||
|
}
|
||||||
|
executions = append(executions, &exec)
|
||||||
|
}
|
||||||
|
|
||||||
|
return executions, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetToolExecution 根据ID获取单条工具执行记录
|
// GetToolExecution 根据ID获取单条工具执行记录
|
||||||
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
||||||
query := `
|
query := `
|
||||||
@@ -288,6 +448,93 @@ func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
|||||||
return &exec, nil
|
return &exec, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelOrphanedRunningToolExecutions 将仍为 running 的记录批量标记为 cancelled(如进程重启后无对应执行协程)。
|
||||||
|
func (db *DB) CancelOrphanedRunningToolExecutions(endTime time.Time, errMsg string) (int64, error) {
|
||||||
|
errMsg = strings.TrimSpace(errMsg)
|
||||||
|
if errMsg == "" {
|
||||||
|
errMsg = "执行已中断(服务重启或会话结束)"
|
||||||
|
}
|
||||||
|
query := `
|
||||||
|
UPDATE tool_executions
|
||||||
|
SET status = 'cancelled',
|
||||||
|
error = ?,
|
||||||
|
end_time = ?,
|
||||||
|
duration_ms = MAX(0, CAST((julianday(?) - julianday(start_time)) * 86400000 AS INTEGER))
|
||||||
|
WHERE status = 'running'
|
||||||
|
`
|
||||||
|
res, err := db.Exec(query, errMsg, endTime, endTime)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeStaleRunningToolExecutions 将「非活跃且超过 minAge」的 running 记录标记为 cancelled。
|
||||||
|
// activeIDs 为当前进程内仍登记 cancel 的 executionId;不在集合内且已超时的视为孤儿记录。
|
||||||
|
func (db *DB) FinalizeStaleRunningToolExecutions(endTime time.Time, minAge time.Duration, activeIDs map[string]struct{}, errMsg string) (int64, error) {
|
||||||
|
errMsg = strings.TrimSpace(errMsg)
|
||||||
|
if errMsg == "" {
|
||||||
|
errMsg = "执行已中断(会话已结束)"
|
||||||
|
}
|
||||||
|
if minAge < 0 {
|
||||||
|
minAge = 0
|
||||||
|
}
|
||||||
|
cutoff := endTime.Add(-minAge)
|
||||||
|
rows, err := db.Query(`
|
||||||
|
SELECT id, start_time FROM tool_executions
|
||||||
|
WHERE status = 'running' AND start_time <= ?
|
||||||
|
`, cutoff)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
type staleRow struct {
|
||||||
|
id string
|
||||||
|
startTime time.Time
|
||||||
|
}
|
||||||
|
var stale []staleRow
|
||||||
|
for rows.Next() {
|
||||||
|
var row staleRow
|
||||||
|
if err := rows.Scan(&row.id, &row.startTime); err != nil {
|
||||||
|
db.logger.Warn("读取 stale running 执行记录失败", zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if activeIDs != nil {
|
||||||
|
if _, active := activeIDs[row.id]; active {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stale = append(stale, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if len(stale) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var affected int64
|
||||||
|
for _, row := range stale {
|
||||||
|
durationMs := endTime.Sub(row.startTime).Milliseconds()
|
||||||
|
if durationMs < 0 {
|
||||||
|
durationMs = 0
|
||||||
|
}
|
||||||
|
res, err := db.Exec(`
|
||||||
|
UPDATE tool_executions
|
||||||
|
SET status = 'cancelled', error = ?, end_time = ?, duration_ms = ?
|
||||||
|
WHERE id = ? AND status = 'running'
|
||||||
|
`, errMsg, endTime, durationMs, row.id)
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Warn("更新 stale running 执行记录失败", zap.Error(err), zap.String("executionId", row.id))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
affected += n
|
||||||
|
}
|
||||||
|
return affected, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteToolExecution 删除工具执行记录
|
// DeleteToolExecution 删除工具执行记录
|
||||||
func (db *DB) DeleteToolExecution(id string) error {
|
func (db *DB) DeleteToolExecution(id string) error {
|
||||||
query := `DELETE FROM tool_executions WHERE id = ?`
|
query := `DELETE FROM tool_executions WHERE id = ?`
|
||||||
@@ -600,13 +847,28 @@ func truncateCallsTimelineBucket(t time.Time, dailyBuckets bool) time.Time {
|
|||||||
|
|
||||||
// LoadCallsTimeline 按时间范围加载调用趋势(since 起至今,含边界)
|
// LoadCallsTimeline 按时间范围加载调用趋势(since 起至今,含边界)
|
||||||
func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTimelineBucket, error) {
|
func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTimelineBucket, error) {
|
||||||
// 在 Go 侧按本地时区分桶,避免 SQLite strftime 对 UTC 存储时间分桶后再误当本地时间解析(差 8h 等问题)
|
var query string
|
||||||
query := `
|
if dailyBuckets {
|
||||||
SELECT start_time,
|
query = `
|
||||||
CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END AS failed
|
SELECT date(start_time, 'localtime') AS bucket,
|
||||||
FROM tool_executions
|
COUNT(*) AS total,
|
||||||
WHERE start_time >= ?
|
SUM(CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END) AS failed
|
||||||
`
|
FROM tool_executions
|
||||||
|
WHERE start_time >= ?
|
||||||
|
GROUP BY bucket
|
||||||
|
ORDER BY bucket
|
||||||
|
`
|
||||||
|
} else {
|
||||||
|
query = `
|
||||||
|
SELECT strftime('%Y-%m-%d %H:00:00', start_time, 'localtime') AS bucket,
|
||||||
|
COUNT(*) AS total,
|
||||||
|
SUM(CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END) AS failed
|
||||||
|
FROM tool_executions
|
||||||
|
WHERE start_time >= ?
|
||||||
|
GROUP BY bucket
|
||||||
|
ORDER BY bucket
|
||||||
|
`
|
||||||
|
}
|
||||||
|
|
||||||
rows, err := db.Query(query, since)
|
rows, err := db.Query(query, since)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -614,35 +876,35 @@ func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTime
|
|||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
bucketMap := make(map[time.Time]struct{ total, failed int })
|
buckets := make([]CallsTimelineBucket, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var startTime time.Time
|
var bucketStr string
|
||||||
var failed int
|
var total, failed int
|
||||||
if err := rows.Scan(&startTime, &failed); err != nil {
|
if err := rows.Scan(&bucketStr, &total, &failed); err != nil {
|
||||||
db.logger.Warn("加载调用趋势失败", zap.Error(err))
|
db.logger.Warn("加载调用趋势失败", zap.Error(err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
key := truncateCallsTimelineBucket(startTime, dailyBuckets)
|
bucketTime, err := parseCallsTimelineBucket(bucketStr, dailyBuckets)
|
||||||
entry := bucketMap[key]
|
if err != nil {
|
||||||
entry.total++
|
db.logger.Warn("解析调用趋势时间桶失败", zap.Error(err), zap.String("bucket", bucketStr))
|
||||||
entry.failed += failed
|
continue
|
||||||
bucketMap[key] = entry
|
}
|
||||||
}
|
|
||||||
|
|
||||||
buckets := make([]CallsTimelineBucket, 0, len(bucketMap))
|
|
||||||
for bucketTime, counts := range bucketMap {
|
|
||||||
buckets = append(buckets, CallsTimelineBucket{
|
buckets = append(buckets, CallsTimelineBucket{
|
||||||
BucketTime: bucketTime,
|
BucketTime: bucketTime,
|
||||||
Total: counts.total,
|
Total: total,
|
||||||
Failed: counts.failed,
|
Failed: failed,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
sort.Slice(buckets, func(i, j int) bool {
|
|
||||||
return buckets[i].BucketTime.Before(buckets[j].BucketTime)
|
|
||||||
})
|
|
||||||
return buckets, nil
|
return buckets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseCallsTimelineBucket(bucketStr string, dailyBuckets bool) (time.Time, error) {
|
||||||
|
if dailyBuckets {
|
||||||
|
return time.ParseInLocation("2006-01-02", bucketStr, time.Local)
|
||||||
|
}
|
||||||
|
return time.ParseInLocation("2006-01-02 15:04:05", bucketStr, time.Local)
|
||||||
|
}
|
||||||
|
|
||||||
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
||||||
// 如果统计信息变为0,则删除该统计记录
|
// 如果统计信息变为0,则删除该统计记录
|
||||||
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
||||||
|
|||||||
@@ -0,0 +1,102 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCancelOrphanedRunningToolExecutions(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
start := time.Now().Add(-2 * time.Hour)
|
||||||
|
exec := &mcp.ToolExecution{
|
||||||
|
ID: "orphan-hydra",
|
||||||
|
ToolName: "hydra",
|
||||||
|
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||||
|
Status: "running",
|
||||||
|
StartTime: start,
|
||||||
|
}
|
||||||
|
if err := db.SaveToolExecution(exec); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
end := time.Now()
|
||||||
|
n, err := db.CancelOrphanedRunningToolExecutions(end, "执行已中断(服务重启)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CancelOrphanedRunningToolExecutions: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("expected 1 row updated, got %d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := db.GetToolExecution("orphan-hydra")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
if got.Status != "cancelled" {
|
||||||
|
t.Fatalf("expected cancelled, got %s", got.Status)
|
||||||
|
}
|
||||||
|
if got.EndTime == nil {
|
||||||
|
t.Fatal("expected end_time to be set")
|
||||||
|
}
|
||||||
|
if got.Duration <= 0 {
|
||||||
|
t.Fatalf("expected positive duration, got %v", got.Duration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinalizeStaleRunningToolExecutions_skipsActive(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
oldStart := now.Add(-5 * time.Minute)
|
||||||
|
if err := db.SaveToolExecution(&mcp.ToolExecution{
|
||||||
|
ID: "stale", ToolName: "hydra", Status: "running", StartTime: oldStart,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution stale: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.SaveToolExecution(&mcp.ToolExecution{
|
||||||
|
ID: "active", ToolName: "hydra", Status: "running", StartTime: oldStart,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution active: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
active := map[string]struct{}{"active": {}}
|
||||||
|
n, err := db.FinalizeStaleRunningToolExecutions(now, time.Minute, active, "执行已中断(会话已结束)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FinalizeStaleRunningToolExecutions: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("expected 1 stale row updated, got %d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
stale, err := db.GetToolExecution("stale")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetToolExecution stale: %v", err)
|
||||||
|
}
|
||||||
|
if stale.Status != "cancelled" {
|
||||||
|
t.Fatalf("stale expected cancelled, got %s", stale.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
activeExec, err := db.GetToolExecution("active")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetToolExecution active: %v", err)
|
||||||
|
}
|
||||||
|
if activeExec.Status != "running" {
|
||||||
|
t.Fatalf("active expected running, got %s", activeExec.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadToolStatsSummaryAndListPage(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor-summary.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
tools := []struct {
|
||||||
|
name string
|
||||||
|
calls int
|
||||||
|
ok int
|
||||||
|
fail int
|
||||||
|
result string
|
||||||
|
}{
|
||||||
|
{"alpha::run", 10, 9, 1, `{"content":[{"type":"text","text":"` + string(make([]byte, 64*1024)) + `"}]}`},
|
||||||
|
{"beta::scan", 5, 5, 0, `{"content":[{"type":"text","text":"ok"}]}`},
|
||||||
|
{"gamma::ping", 1, 1, 0, `{"content":[{"type":"text","text":"pong"}]}`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
if err := db.UpdateToolStats(tool.name, tool.calls, tool.ok, tool.fail, &now); err != nil {
|
||||||
|
t.Fatalf("UpdateToolStats(%s): %v", tool.name, err)
|
||||||
|
}
|
||||||
|
for j := 0; j < tool.calls; j++ {
|
||||||
|
exec := &mcp.ToolExecution{
|
||||||
|
ID: fmt.Sprintf("%s-exec-%d", tool.name, j),
|
||||||
|
ToolName: tool.name,
|
||||||
|
Arguments: map[string]interface{}{"n": j},
|
||||||
|
Status: "completed",
|
||||||
|
StartTime: now.Add(-time.Duration(j) * time.Minute),
|
||||||
|
Result: &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: tool.result}}},
|
||||||
|
}
|
||||||
|
end := exec.StartTime.Add(time.Second)
|
||||||
|
exec.EndTime = &end
|
||||||
|
exec.Duration = time.Second
|
||||||
|
if err := db.SaveToolExecution(exec); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
summary, err := db.LoadToolStatsSummary(2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadToolStatsSummary: %v", err)
|
||||||
|
}
|
||||||
|
if summary.Summary.ToolCount != 3 {
|
||||||
|
t.Fatalf("toolCount = %d, want 3", summary.Summary.ToolCount)
|
||||||
|
}
|
||||||
|
if summary.Summary.TotalCalls != 16 {
|
||||||
|
t.Fatalf("totalCalls = %d, want 16", summary.Summary.TotalCalls)
|
||||||
|
}
|
||||||
|
if len(summary.TopTools) != 2 {
|
||||||
|
t.Fatalf("top tools = %d, want 2", len(summary.TopTools))
|
||||||
|
}
|
||||||
|
if summary.TopTools[0].ToolName != "alpha::run" {
|
||||||
|
t.Fatalf("top tool = %q, want alpha::run", summary.TopTools[0].ToolName)
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := db.LoadToolExecutionListPage(0, 5, "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadToolExecutionListPage: %v", err)
|
||||||
|
}
|
||||||
|
if len(list) != 5 {
|
||||||
|
t.Fatalf("list len = %d, want 5", len(list))
|
||||||
|
}
|
||||||
|
for _, exec := range list {
|
||||||
|
if exec.Arguments != nil || exec.Result != nil || exec.Error != "" {
|
||||||
|
t.Fatalf("expected lite execution row, got args/result/error on %s", exec.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,8 +2,8 @@ package einomcp
|
|||||||
|
|
||||||
import "sync"
|
import "sync"
|
||||||
|
|
||||||
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP/execute 桥在工具调用结束时 Fire,
|
// ToolInvokeNotifyHolder 由 Eino run loop 与 MCP/execute 桥共享;Fire 在工具原始返回时触发。
|
||||||
// 用于清除 pending tool_call(tool_result 由 ADK schema.Tool 事件推送,含流式工具与 reduction 后正文)。
|
// UI 的 tool_result 须等 ADK schema.Tool 事件(reduction 后正文),不在此 holder 的回调里推送。
|
||||||
type ToolInvokeNotifyHolder struct {
|
type ToolInvokeNotifyHolder struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||||
|
|||||||
+190
-401
@@ -21,7 +21,6 @@ import (
|
|||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/reasoning"
|
"cyberstrike-ai/internal/reasoning"
|
||||||
"cyberstrike-ai/internal/mcp"
|
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
"cyberstrike-ai/internal/multiagent"
|
"cyberstrike-ai/internal/multiagent"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
@@ -78,6 +77,13 @@ type responsePlanAgg struct {
|
|||||||
b strings.Builder
|
b strings.Builder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// thinkingBuf aggregates thinking_stream_* / reasoning_chain_stream_* before flush to process_details.
|
||||||
|
type thinkingBuf struct {
|
||||||
|
b strings.Builder
|
||||||
|
meta map[string]interface{}
|
||||||
|
persistAs string // "thinking" | "reasoning_chain"
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeProcessDetailText(s string) string {
|
func normalizeProcessDetailText(s string) string {
|
||||||
s = strings.ReplaceAll(s, "\r\n", "\n")
|
s = strings.ReplaceAll(s, "\r\n", "\n")
|
||||||
s = strings.ReplaceAll(s, "\r", "\n")
|
s = strings.ReplaceAll(s, "\r", "\n")
|
||||||
@@ -178,10 +184,10 @@ type AgentHandler struct {
|
|||||||
}
|
}
|
||||||
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
||||||
batchCronParser cron.Parser
|
batchCronParser cron.Parser
|
||||||
batchRunnerMu sync.Mutex
|
|
||||||
batchRunning map[string]struct{}
|
|
||||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||||
|
hitlStrategySaver HitlAuditStrategySaver
|
||||||
|
auditLLM *openai.Client
|
||||||
audit *audit.Service
|
audit *audit.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,14 +196,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
|
|||||||
h.audit = s
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TaskManager 返回 Agent 任务管理器(供 MCP 监控页终止 Eino execute 等)。
|
||||||
|
func (h *AgentHandler) TaskManager() *AgentTaskManager {
|
||||||
|
if h == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return h.tasks
|
||||||
|
}
|
||||||
|
|
||||||
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
||||||
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
||||||
if h == nil || conversationID == "" || h.tasks == nil {
|
if h == nil || conversationID == "" || h.tasks == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
|
h.cancelActiveMCPToolForConversation(conversationID)
|
||||||
h.agent.CancelMCPToolExecutionWithNote(execID, "")
|
h.tasks.AbortActiveEinoExecute(conversationID, "")
|
||||||
}
|
|
||||||
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
|
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
|
||||||
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
|
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -205,9 +218,19 @@ func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
func (h *AgentHandler) cancelActiveMCPToolForConversation(conversationID string) {
|
||||||
|
if h == nil || h.tasks == nil || h.agent == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
|
||||||
|
h.agent.CancelMCPToolExecutionWithNote(execID, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HitlToolWhitelistSaver 合并/设置 HITL 免审批工具到全局配置并落盘
|
||||||
type HitlToolWhitelistSaver interface {
|
type HitlToolWhitelistSaver interface {
|
||||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||||
|
SetHitlToolWhitelist(tools []string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAgentHandler 创建新的Agent处理器
|
// NewAgentHandler 创建新的Agent处理器
|
||||||
@@ -223,6 +246,11 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
|||||||
bus := NewTaskEventBus()
|
bus := NewTaskEventBus()
|
||||||
tm := NewAgentTaskManager()
|
tm := NewAgentTaskManager()
|
||||||
tm.SetTaskEventBus(bus)
|
tm.SetTaskEventBus(bus)
|
||||||
|
llmHTTP := &http.Client{Timeout: 2 * time.Minute}
|
||||||
|
var llmCfg *config.OpenAIConfig
|
||||||
|
if cfg != nil {
|
||||||
|
llmCfg = &cfg.OpenAI
|
||||||
|
}
|
||||||
handler := &AgentHandler{
|
handler := &AgentHandler{
|
||||||
agent: agent,
|
agent: agent,
|
||||||
db: db,
|
db: db,
|
||||||
@@ -233,8 +261,9 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
|||||||
config: cfg,
|
config: cfg,
|
||||||
hitlManager: NewHITLManager(db, logger),
|
hitlManager: NewHITLManager(db, logger),
|
||||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||||
batchRunning: make(map[string]struct{}),
|
auditLLM: openai.NewClient(llmCfg, llmHTTP, logger),
|
||||||
}
|
}
|
||||||
|
tm.SetToolCanceler(handler.cancelActiveMCPToolForConversation)
|
||||||
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
||||||
logger.Warn("初始化 HITL 表失败", zap.Error(err))
|
logger.Warn("初始化 HITL 表失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
@@ -307,6 +336,7 @@ func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientInten
|
|||||||
type HITLRequest struct {
|
type HITLRequest struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Mode string `json:"mode,omitempty"`
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Reviewer string `json:"reviewer,omitempty"` // human | audit_agent
|
||||||
SensitiveTools []string `json:"sensitiveTools,omitempty"`
|
SensitiveTools []string `json:"sensitiveTools,omitempty"`
|
||||||
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
|
TimeoutSeconds int `json:"timeoutSeconds,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -648,7 +678,7 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
|
|||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
|
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
|
||||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.agentSessionContextBlock(conversationID),
|
||||||
)
|
)
|
||||||
if errMA != nil {
|
if errMA != nil {
|
||||||
*taskStatus = "failed"
|
*taskStatus = "failed"
|
||||||
@@ -669,7 +699,7 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
|
|||||||
resultMA, errMA := multiagent.RunDeepAgent(
|
resultMA, errMA := multiagent.RunDeepAgent(
|
||||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
|
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
|
||||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
h.agentsMarkdownDir, orchestration, nil, h.agentSessionContextBlock(conversationID),
|
||||||
)
|
)
|
||||||
if errMA != nil {
|
if errMA != nil {
|
||||||
*taskStatus = "failed"
|
*taskStatus = "failed"
|
||||||
@@ -836,11 +866,6 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
|
|
||||||
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
||||||
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
||||||
type thinkingBuf struct {
|
|
||||||
b strings.Builder
|
|
||||||
meta map[string]interface{}
|
|
||||||
persistAs string // "thinking" | "reasoning_chain"
|
|
||||||
}
|
|
||||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||||
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
seenToolCallSigs := make(map[string]string) // toolCallId -> payload signature
|
||||||
@@ -853,6 +878,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
// response_start + response_delta:前端时间线显示为「📝 规划中」(monitor.js),不落逐条 delta;
|
||||||
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
// 聚合为一条 planning 写入 process_details,刷新后与线上一致。
|
||||||
var respPlan responsePlanAgg
|
var respPlan responsePlanAgg
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
h.tasks.SetHitlAssistantMessageID(conversationID, assistantMessageID)
|
||||||
|
}
|
||||||
|
syncHitlCognition := func() {
|
||||||
|
h.syncHitlCognitionFromProgress(conversationID, assistantMessageID, thinkingStreams, &respPlan)
|
||||||
|
}
|
||||||
flushResponsePlan := func() {
|
flushResponsePlan := func() {
|
||||||
if assistantMessageID == "" {
|
if assistantMessageID == "" {
|
||||||
return
|
return
|
||||||
@@ -872,6 +903,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil {
|
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "planning", content, data); err != nil {
|
||||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning"))
|
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "planning"))
|
||||||
}
|
}
|
||||||
|
syncHitlCognition()
|
||||||
respPlan.meta = nil
|
respPlan.meta = nil
|
||||||
respPlan.b.Reset()
|
respPlan.b.Reset()
|
||||||
}
|
}
|
||||||
@@ -908,6 +940,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
flushedThinking[sid] = true
|
flushedThinking[sid] = true
|
||||||
}
|
}
|
||||||
|
syncHitlCognition()
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(eventType, message string, data interface{}) {
|
return func(eventType, message string, data interface{}) {
|
||||||
@@ -968,6 +1001,25 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if eventType == "tool_result" {
|
||||||
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
|
toolName, _ := dataMap["toolName"].(string)
|
||||||
|
toolCallID, _ := dataMap["toolCallId"].(string)
|
||||||
|
success := true
|
||||||
|
if v, ok := dataMap["success"].(bool); ok {
|
||||||
|
success = v
|
||||||
|
}
|
||||||
|
resultText := ""
|
||||||
|
if r, ok := dataMap["result"].(string); ok {
|
||||||
|
resultText = r
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(resultText) == "" {
|
||||||
|
resultText = message
|
||||||
|
}
|
||||||
|
h.recordHitlToolExecutionResult(conversationID, toolCallID, toolName, success, resultText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 处理知识检索日志记录
|
// 处理知识检索日志记录
|
||||||
if eventType == "tool_result" && h.knowledgeManager != nil {
|
if eventType == "tool_result" && h.knowledgeManager != nil {
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
@@ -1175,6 +1227,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
respPlan.meta[k] = v
|
respPlan.meta[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
syncHitlCognition()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if eventType == "response" {
|
if eventType == "response" {
|
||||||
@@ -1244,6 +1297,7 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
syncHitlCognition()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1295,10 +1349,60 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cancelToolContinueAfter 仅终止当前工具调用,不停止整条 Agent 任务(对话「中断并继续」与 MCP 监控终止共用)。
|
||||||
|
func (h *AgentHandler) cancelToolContinueAfter(conversationID, preferredExecID, note string) (bool, gin.H) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" || h.tasks.GetTask(conversationID) == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
note = strings.TrimSpace(note)
|
||||||
|
execID := strings.TrimSpace(preferredExecID)
|
||||||
|
if execID == "" {
|
||||||
|
execID = h.tasks.ActiveMCPExecutionID(conversationID)
|
||||||
|
}
|
||||||
|
if execID != "" {
|
||||||
|
if h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||||
|
return true, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"executionId": execID,
|
||||||
|
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
|
||||||
|
return true, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"executionId": execID,
|
||||||
|
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
|
||||||
|
return true, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CancelAgentLoop 取消正在执行的任务
|
// CancelAgentLoop 取消正在执行的任务
|
||||||
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
ConversationID string `json:"conversationId" binding:"required"`
|
ConversationID string `json:"conversationId" binding:"required"`
|
||||||
|
ExecutionID string `json:"executionId,omitempty"`
|
||||||
Reason string `json:"reason,omitempty"`
|
Reason string `json:"reason,omitempty"`
|
||||||
ContinueAfter bool `json:"continueAfter,omitempty"`
|
ContinueAfter bool `json:"continueAfter,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -1313,42 +1417,20 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
|
||||||
note := strings.TrimSpace(req.Reason)
|
note := strings.TrimSpace(req.Reason)
|
||||||
if execID != "" {
|
activeExec := strings.TrimSpace(h.tasks.ActiveMCPExecutionID(req.ConversationID))
|
||||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
if ok, payload := h.cancelToolContinueAfter(req.ConversationID, strings.TrimSpace(req.ExecutionID), note); ok {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
execID, _ := payload["executionId"].(string)
|
||||||
return
|
h.logger.Info("对话页仅终止当前工具",
|
||||||
}
|
|
||||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
|
||||||
zap.String("conversationId", req.ConversationID),
|
zap.String("conversationId", req.ConversationID),
|
||||||
zap.String("executionId", execID),
|
zap.String("executionId", execID),
|
||||||
zap.Bool("hasNote", note != ""),
|
zap.Bool("hasNote", note != ""),
|
||||||
)
|
)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, payload)
|
||||||
"status": "tool_abort_requested",
|
|
||||||
"conversationId": req.ConversationID,
|
|
||||||
"executionId": execID,
|
|
||||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
|
||||||
"continueAfter": true,
|
|
||||||
"interruptWithNote": note != "",
|
|
||||||
"continueWithoutTool": false,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) {
|
if activeExec != "" {
|
||||||
h.logger.Info("对话页仅终止当前 Eino execute",
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||||
zap.String("conversationId", req.ConversationID),
|
|
||||||
zap.Bool("hasNote", note != ""),
|
|
||||||
)
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"status": "tool_abort_requested",
|
|
||||||
"conversationId": req.ConversationID,
|
|
||||||
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
|
||||||
"continueAfter": true,
|
|
||||||
"interruptWithNote": note != "",
|
|
||||||
"continueWithoutTool": false,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||||
@@ -1380,6 +1462,8 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
|||||||
|
|
||||||
var cause error = ErrTaskCancelled
|
var cause error = ErrTaskCancelled
|
||||||
msg := "已提交取消请求,任务将在当前步骤完成后停止。"
|
msg := "已提交取消请求,任务将在当前步骤完成后停止。"
|
||||||
|
h.cancelActiveMCPToolForConversation(req.ConversationID)
|
||||||
|
h.tasks.AbortActiveEinoExecute(req.ConversationID, "")
|
||||||
ok, err := h.tasks.CancelTask(req.ConversationID, cause)
|
ok, err := h.tasks.CancelTask(req.ConversationID, cause)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("取消任务失败", zap.Error(err))
|
h.logger.Error("取消任务失败", zap.Error(err))
|
||||||
@@ -1446,17 +1530,51 @@ func (h *AgentHandler) SubscribeAgentTaskEvents(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// enrichAgentTasksWithConversationTitles 为任务列表附加当前会话标题(供顶栏/任务页展示,重命名后自动同步)
|
||||||
|
func (h *AgentHandler) enrichAgentTasksWithConversationTitles(tasks []*AgentTask) {
|
||||||
|
if h == nil || h.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, task := range tasks {
|
||||||
|
if task == nil || strings.TrimSpace(task.ConversationID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if title, err := h.db.GetConversationTitle(task.ConversationID); err == nil {
|
||||||
|
task.Title = strings.TrimSpace(title)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// enrichCompletedTasksWithConversationTitles 为已完成任务附加当前会话标题
|
||||||
|
func (h *AgentHandler) enrichCompletedTasksWithConversationTitles(tasks []*CompletedTask) {
|
||||||
|
if h == nil || h.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, task := range tasks {
|
||||||
|
if task == nil || strings.TrimSpace(task.ConversationID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if title, err := h.db.GetConversationTitle(task.ConversationID); err == nil {
|
||||||
|
task.Title = strings.TrimSpace(title)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ListAgentTasks 列出所有运行中的任务
|
// ListAgentTasks 列出所有运行中的任务
|
||||||
func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
||||||
|
tasks := h.tasks.GetActiveTasks()
|
||||||
|
h.enrichAgentTasksWithConversationTitles(tasks)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"tasks": h.tasks.GetActiveTasks(),
|
"tasks": tasks,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListCompletedTasks 列出最近完成的任务历史
|
// ListCompletedTasks 列出最近完成的任务历史
|
||||||
func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
|
func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
|
||||||
|
tasks := h.tasks.GetCompletedTasks()
|
||||||
|
h.enrichCompletedTasksWithConversationTitles(tasks)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"tasks": h.tasks.GetCompletedTasks(),
|
"tasks": tasks,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1470,6 +1588,7 @@ type BatchTaskRequest struct {
|
|||||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||||
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
||||||
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
||||||
|
Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8
|
||||||
}
|
}
|
||||||
|
|
||||||
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。
|
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。
|
||||||
@@ -1529,7 +1648,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
|||||||
nextRunAt = &next
|
nextRunAt = &next
|
||||||
}
|
}
|
||||||
|
|
||||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
|
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks)
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
||||||
return
|
return
|
||||||
@@ -1719,15 +1838,16 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
|||||||
func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
|
func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
|
||||||
queueID := c.Param("queueId")
|
queueID := c.Param("queueId")
|
||||||
var req struct {
|
var req struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
AgentMode string `json:"agentMode"`
|
AgentMode string `json:"agentMode"`
|
||||||
|
Concurrency *int `json:"concurrency"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil {
|
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1802,9 +1922,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
|
|||||||
// DeleteBatchQueue 删除批量任务队列
|
// DeleteBatchQueue 删除批量任务队列
|
||||||
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||||
queueID := c.Param("queueId")
|
queueID := c.Param("queueId")
|
||||||
success := h.batchTaskManager.DeleteQueue(queueID)
|
if err := h.batchTaskManager.DeleteQueue(queueID); err != nil {
|
||||||
if !success {
|
switch {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
case errors.Is(err, ErrBatchQueueNotFound):
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||||
|
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||||
|
c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"})
|
||||||
|
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||||
|
c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"})
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.audit != nil {
|
if h.audit != nil {
|
||||||
@@ -1898,7 +2026,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
|
|||||||
|
|
||||||
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
||||||
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
||||||
h.forceUnmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.ForceUnmarkQueueExecutor(queueID)
|
||||||
}
|
}
|
||||||
|
|
||||||
autoStarted := true
|
autoStarted := true
|
||||||
@@ -1957,26 +2085,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
|
|
||||||
h.batchRunnerMu.Lock()
|
|
||||||
defer h.batchRunnerMu.Unlock()
|
|
||||||
if _, exists := h.batchRunning[queueID]; exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
h.batchRunning[queueID] = struct{}{}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
|
|
||||||
h.batchRunnerMu.Lock()
|
|
||||||
defer h.batchRunnerMu.Unlock()
|
|
||||||
delete(h.batchRunning, queueID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
|
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
||||||
expr := strings.TrimSpace(cronExpr)
|
expr := strings.TrimSpace(cronExpr)
|
||||||
if expr == "" {
|
if expr == "" {
|
||||||
@@ -1992,43 +2100,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
|
|||||||
|
|
||||||
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
||||||
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
|
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
|
||||||
if !h.markBatchQueueRunning(queueID) {
|
if !h.batchTaskManager.TryMarkQueueExecutor(queueID) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
if !exists {
|
if !exists {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if scheduled {
|
if scheduled {
|
||||||
if queue.ScheduleMode != "cron" {
|
if queue.ScheduleMode != "cron" {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("队列未启用 cron 调度")
|
err := fmt.Errorf("队列未启用 cron 调度")
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
|
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("当前队列状态不允许被调度执行")
|
err := fmt.Errorf("当前队列状态不允许被调度执行")
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
if !h.batchTaskManager.ResetQueueForRerun(queueID) {
|
if !h.batchTaskManager.ResetQueueForRerun(queueID) {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("重置队列失败")
|
err := fmt.Errorf("重置队列失败")
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
} else if queue.Status != "pending" && queue.Status != "paused" {
|
} else if queue.Status != "pending" && queue.Status != "paused" {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
return true, fmt.Errorf("队列状态不允许启动")
|
return true, fmt.Errorf("队列状态不允许启动")
|
||||||
}
|
}
|
||||||
|
|
||||||
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
|
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
|
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
|
||||||
if scheduled {
|
if scheduled {
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
@@ -2080,325 +2188,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeBatchQueue 执行批量任务队列
|
|
||||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|
||||||
defer h.unmarkBatchQueueRunning(queueID)
|
|
||||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
|
|
||||||
|
|
||||||
for {
|
|
||||||
// 检查队列状态
|
|
||||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
|
||||||
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取下一个任务
|
|
||||||
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
|
|
||||||
if !hasNext {
|
|
||||||
// 所有任务完成:汇总子任务失败信息便于排障
|
|
||||||
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
|
|
||||||
lastRunErr := ""
|
|
||||||
if ok {
|
|
||||||
for _, t := range q.Tasks {
|
|
||||||
if t.Status == "failed" && t.Error != "" {
|
|
||||||
lastRunErr = t.Error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
|
||||||
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
|
|
||||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务状态为运行中
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
|
|
||||||
|
|
||||||
// 创建新对话
|
|
||||||
title := safeTruncateString(task.Message, 50)
|
|
||||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
|
||||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
|
||||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
|
||||||
var conversationID string
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
|
||||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
|
||||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
conversationID = conv.ID
|
|
||||||
|
|
||||||
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
|
|
||||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
|
|
||||||
|
|
||||||
// 应用角色用户提示词和工具配置
|
|
||||||
finalMessage := task.Message
|
|
||||||
var roleTools []string // 角色配置的工具列表
|
|
||||||
if queue.Role != "" && queue.Role != "默认" {
|
|
||||||
if h.config.Roles != nil {
|
|
||||||
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
|
||||||
// 应用用户提示词
|
|
||||||
if role.UserPrompt != "" {
|
|
||||||
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
|
||||||
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
|
||||||
}
|
|
||||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
|
||||||
if len(role.Tools) > 0 {
|
|
||||||
roleTools = role.Tools
|
|
||||||
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
|
||||||
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 预先创建助手消息,以便关联过程详情
|
|
||||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
// 如果创建失败,继续执行但不保存过程详情
|
|
||||||
assistantMsg = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
|
|
||||||
var assistantMessageID string
|
|
||||||
if assistantMsg != nil {
|
|
||||||
assistantMessageID = assistantMsg.ID
|
|
||||||
}
|
|
||||||
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
|
|
||||||
// 需要把进度事件镜像到 TaskEventBus(GET /api/agent-loop/task-events 会订阅这里)。
|
|
||||||
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
|
|
||||||
var progressCallback func(eventType, message string, data interface{})
|
|
||||||
|
|
||||||
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
|
|
||||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
|
||||||
|
|
||||||
func() {
|
|
||||||
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
|
|
||||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
|
||||||
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
|
|
||||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
|
||||||
|
|
||||||
registered := false
|
|
||||||
finishStatus := "completed"
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
|
||||||
timeoutCancel()
|
|
||||||
if registered {
|
|
||||||
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
|
|
||||||
if h.taskEventBus != nil {
|
|
||||||
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
|
||||||
if b, err := json.Marshal(ev); err == nil {
|
|
||||||
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.tasks.FinishTask(conversationID, finishStatus)
|
|
||||||
}
|
|
||||||
cancelWithCause(nil)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
|
|
||||||
sendEvent := func(eventType, message string, data interface{}) {
|
|
||||||
if h.taskEventBus == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
|
||||||
b, err := json.Marshal(ev)
|
|
||||||
if err != nil {
|
|
||||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
|
||||||
}
|
|
||||||
line := make([]byte, 0, len(b)+8)
|
|
||||||
line = append(line, []byte("data: ")...)
|
|
||||||
line = append(line, b...)
|
|
||||||
line = append(line, '\n', '\n')
|
|
||||||
h.taskEventBus.Publish(conversationID, line)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
|
||||||
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
|
||||||
zap.String("queueId", queueID),
|
|
||||||
zap.String("taskId", task.ID),
|
|
||||||
zap.String("conversationId", conversationID),
|
|
||||||
zap.Error(err))
|
|
||||||
failMsg := err.Error()
|
|
||||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
|
||||||
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
|
||||||
}
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
registered = true
|
|
||||||
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
|
|
||||||
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
|
|
||||||
|
|
||||||
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
|
|
||||||
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
|
||||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
|
||||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
|
||||||
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
|
||||||
|
|
||||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
|
||||||
useBatchMulti := false
|
|
||||||
batchOrch := "deep"
|
|
||||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
|
||||||
if am == "multi" {
|
|
||||||
am = "deep"
|
|
||||||
}
|
|
||||||
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
|
||||||
useBatchMulti = true
|
|
||||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
|
||||||
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
|
||||||
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
|
|
||||||
useBatchMulti = true
|
|
||||||
batchOrch = "deep"
|
|
||||||
}
|
|
||||||
var resultMA *multiagent.RunResult
|
|
||||||
var runErr error
|
|
||||||
switch {
|
|
||||||
case useBatchMulti:
|
|
||||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
|
||||||
default:
|
|
||||||
if h.config == nil {
|
|
||||||
runErr = fmt.Errorf("服务器配置未加载")
|
|
||||||
} else {
|
|
||||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if runErr != nil {
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
|
||||||
}
|
|
||||||
errStr := runErr.Error()
|
|
||||||
partialResp := ""
|
|
||||||
if resultMA != nil {
|
|
||||||
partialResp = resultMA.Response
|
|
||||||
}
|
|
||||||
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
|
||||||
errors.Is(runErr, context.Canceled) ||
|
|
||||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
|
||||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
|
||||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
|
||||||
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
|
||||||
|
|
||||||
if isTimeout {
|
|
||||||
finishStatus = "timeout"
|
|
||||||
} else if isCancelled {
|
|
||||||
finishStatus = "cancelled"
|
|
||||||
} else {
|
|
||||||
finishStatus = "failed"
|
|
||||||
}
|
|
||||||
|
|
||||||
if isCancelled {
|
|
||||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
|
||||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
|
||||||
// 如果执行结果中有更具体的取消消息,使用它
|
|
||||||
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
|
||||||
cancelMsg = partialResp
|
|
||||||
}
|
|
||||||
// 更新助手消息内容
|
|
||||||
if assistantMessageID != "" {
|
|
||||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
|
||||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
|
||||||
}
|
|
||||||
// 保存取消详情到数据库
|
|
||||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
|
||||||
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 如果没有预先创建的助手消息,创建一个新的
|
|
||||||
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
|
|
||||||
if errMsg != nil {
|
|
||||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
|
||||||
} else {
|
|
||||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
|
||||||
errorMsg := "执行失败: " + runErr.Error()
|
|
||||||
// 更新助手消息内容
|
|
||||||
if assistantMessageID != "" {
|
|
||||||
if _, updateErr := h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
|
||||||
errorMsg,
|
|
||||||
time.Now(), assistantMessageID,
|
|
||||||
); updateErr != nil {
|
|
||||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
|
||||||
}
|
|
||||||
// 保存错误详情到数据库
|
|
||||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
|
||||||
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
|
||||||
|
|
||||||
resText := resultMA.Response
|
|
||||||
mcpIDs := resultMA.MCPExecutionIDs
|
|
||||||
lastIn := resultMA.LastAgentTraceInput
|
|
||||||
lastOut := resultMA.LastAgentTraceOutput
|
|
||||||
|
|
||||||
// 更新助手消息内容
|
|
||||||
if assistantMessageID != "" {
|
|
||||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
|
||||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
|
||||||
// 如果更新失败,尝试创建新消息
|
|
||||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 如果没有预先创建的助手消息,创建一个新的
|
|
||||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存代理轨迹
|
|
||||||
if lastIn != "" || lastOut != "" {
|
|
||||||
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
|
||||||
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
} else {
|
|
||||||
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存结果
|
|
||||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 移动到下一个任务
|
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
|
||||||
|
|
||||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
|
||||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
|
||||||
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查是否被取消或暂停
|
|
||||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
|
||||||
if queue.Status == "cancelled" || queue.Status == "paused" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
||||||
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
||||||
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
||||||
|
|||||||
@@ -0,0 +1,352 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const batchQueueWorkerIdlePoll = 200 * time.Millisecond
|
||||||
|
|
||||||
|
// executeBatchQueue 使用并发 worker 池执行批量任务队列。
|
||||||
|
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||||
|
defer h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
|
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
concurrency := normalizeBatchQueueConcurrency(queue.Concurrency)
|
||||||
|
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", concurrency))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < concurrency; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
h.runBatchQueueWorker(queueID)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
h.tryFinalizeBatchQueue(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) runBatchQueueWorker(queueID string) {
|
||||||
|
for {
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if batchQueueExecutionShouldStop(queue, exists) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
task, ok := h.batchTaskManager.ClaimNextPendingTask(queueID)
|
||||||
|
if !ok {
|
||||||
|
if !h.batchTaskManager.HasRunningTasks(queueID) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(batchQueueWorkerIdlePoll)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusRunning, "", "")
|
||||||
|
h.executeOneBatchSubTask(queueID, queue, task)
|
||||||
|
|
||||||
|
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused)
|
||||||
|
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
queue, exists = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if batchQueueExecutionShouldStop(queue, exists) {
|
||||||
|
if !exists {
|
||||||
|
h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) {
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if queue.Status != BatchQueueStatusRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.batchTaskManager.HasPendingOrRunningTasks(queueID) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastRunErr := ""
|
||||||
|
for _, t := range queue.Tasks {
|
||||||
|
if t != nil && t.Status == BatchTaskStatusFailed && t.Error != "" {
|
||||||
|
lastRunErr = t.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted)
|
||||||
|
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeOneBatchSubTask 执行单条批量子任务(各自独立会话)。
|
||||||
|
func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) {
|
||||||
|
title := safeTruncateString(task.Message, 50)
|
||||||
|
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||||
|
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||||
|
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conversationID := conv.ID
|
||||||
|
|
||||||
|
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID)
|
||||||
|
|
||||||
|
finalMessage := task.Message
|
||||||
|
var roleTools []string
|
||||||
|
if queue.Role != "" && queue.Role != "默认" {
|
||||||
|
if h.config.Roles != nil {
|
||||||
|
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
||||||
|
if role.UserPrompt != "" {
|
||||||
|
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
||||||
|
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
||||||
|
}
|
||||||
|
if len(role.Tools) > 0 {
|
||||||
|
roleTools = role.Tools
|
||||||
|
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil {
|
||||||
|
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
assistantMsg = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var assistantMessageID string
|
||||||
|
if assistantMsg != nil {
|
||||||
|
assistantMessageID = assistantMsg.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||||
|
|
||||||
|
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||||
|
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
||||||
|
|
||||||
|
registered := false
|
||||||
|
finishStatus := "completed"
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil)
|
||||||
|
timeoutCancel()
|
||||||
|
if registered {
|
||||||
|
if h.taskEventBus != nil {
|
||||||
|
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
||||||
|
if b, err := json.Marshal(ev); err == nil {
|
||||||
|
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.tasks.FinishTask(conversationID, finishStatus)
|
||||||
|
}
|
||||||
|
cancelWithCause(nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sendEvent := func(eventType, message string, data interface{}) {
|
||||||
|
if h.taskEventBus == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||||
|
b, err := json.Marshal(ev)
|
||||||
|
if err != nil {
|
||||||
|
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||||
|
}
|
||||||
|
line := make([]byte, 0, len(b)+8)
|
||||||
|
line = append(line, []byte("data: ")...)
|
||||||
|
line = append(line, b...)
|
||||||
|
line = append(line, '\n', '\n')
|
||||||
|
h.taskEventBus.Publish(conversationID, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
||||||
|
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
||||||
|
zap.String("queueId", queueID),
|
||||||
|
zap.String("taskId", task.ID),
|
||||||
|
zap.String("conversationId", conversationID),
|
||||||
|
zap.Error(err))
|
||||||
|
failMsg := err.Error()
|
||||||
|
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||||
|
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
||||||
|
}
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registered = true
|
||||||
|
h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel)
|
||||||
|
|
||||||
|
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
|
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
|
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||||
|
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
||||||
|
|
||||||
|
useBatchMulti := false
|
||||||
|
batchOrch := "deep"
|
||||||
|
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||||
|
if am == "multi" {
|
||||||
|
am = "deep"
|
||||||
|
}
|
||||||
|
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||||
|
useBatchMulti = true
|
||||||
|
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||||
|
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||||
|
useBatchMulti = true
|
||||||
|
batchOrch = "deep"
|
||||||
|
}
|
||||||
|
|
||||||
|
var resultMA *multiagent.RunResult
|
||||||
|
var runErr error
|
||||||
|
switch {
|
||||||
|
case useBatchMulti:
|
||||||
|
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.agentSessionContextBlock(conversationID))
|
||||||
|
default:
|
||||||
|
if h.config == nil {
|
||||||
|
runErr = fmt.Errorf("服务器配置未加载")
|
||||||
|
} else {
|
||||||
|
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.agentSessionContextBlock(conversationID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if runErr != nil {
|
||||||
|
h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultMA == nil {
|
||||||
|
h.logger.Error("批量任务执行成功但无结果对象",
|
||||||
|
zap.String("queueId", queueID),
|
||||||
|
zap.String("taskId", task.ID),
|
||||||
|
zap.String("conversationId", conversationID))
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||||
|
|
||||||
|
resText := resultMA.Response
|
||||||
|
mcpIDs := resultMA.MCPExecutionIDs
|
||||||
|
lastIn := resultMA.LastAgentTraceInput
|
||||||
|
lastOut := resultMA.LastAgentTraceOutput
|
||||||
|
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||||
|
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||||
|
if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||||
|
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||||
|
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastIn != "" || lastOut != "" {
|
||||||
|
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||||
|
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) handleBatchSubTaskRunError(
|
||||||
|
queueID string,
|
||||||
|
task *BatchTask,
|
||||||
|
conversationID, assistantMessageID string,
|
||||||
|
baseCtx, taskCtx context.Context,
|
||||||
|
resultMA *multiagent.RunResult,
|
||||||
|
runErr error,
|
||||||
|
finishStatus *string,
|
||||||
|
) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||||
|
}
|
||||||
|
errStr := runErr.Error()
|
||||||
|
partialResp := ""
|
||||||
|
if resultMA != nil {
|
||||||
|
partialResp = resultMA.Response
|
||||||
|
}
|
||||||
|
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
||||||
|
errors.Is(runErr, context.Canceled) ||
|
||||||
|
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||||
|
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||||
|
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||||
|
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
||||||
|
|
||||||
|
if isTimeout {
|
||||||
|
*finishStatus = "timeout"
|
||||||
|
} else if isCancelled {
|
||||||
|
*finishStatus = "cancelled"
|
||||||
|
} else {
|
||||||
|
*finishStatus = "failed"
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCancelled {
|
||||||
|
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||||
|
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||||
|
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
||||||
|
cancelMsg = partialResp
|
||||||
|
}
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||||
|
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||||
|
}
|
||||||
|
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
||||||
|
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
} else if _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); errMsg != nil {
|
||||||
|
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||||
|
}
|
||||||
|
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
||||||
|
errorMsg := "执行失败: " + runErr.Error()
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
if _, updateErr := h.db.Exec(
|
||||||
|
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||||
|
errorMsg,
|
||||||
|
time.Now(), assistantMessageID,
|
||||||
|
); updateErr != nil {
|
||||||
|
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||||
|
}
|
||||||
|
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
||||||
|
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error())
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,6 +18,15 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrBatchQueueNotFound 队列不存在或已从内存卸载。
|
||||||
|
ErrBatchQueueNotFound = errors.New("batch queue not found")
|
||||||
|
// ErrBatchQueueExecutorActive executeBatchQueue 协程仍在收尾,禁止删除。
|
||||||
|
ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active")
|
||||||
|
// ErrBatchQueueStillRunning 队列状态仍为 running(无活跃执行器时的兜底保护)。
|
||||||
|
ErrBatchQueueStillRunning = errors.New("batch queue is still running")
|
||||||
|
)
|
||||||
|
|
||||||
// 批量任务状态常量
|
// 批量任务状态常量
|
||||||
const (
|
const (
|
||||||
BatchQueueStatusPending = "pending"
|
BatchQueueStatusPending = "pending"
|
||||||
@@ -39,6 +49,12 @@ const (
|
|||||||
|
|
||||||
// MaxBatchQueueRoleLen 角色名最大长度
|
// MaxBatchQueueRoleLen 角色名最大长度
|
||||||
MaxBatchQueueRoleLen = 100
|
MaxBatchQueueRoleLen = 100
|
||||||
|
|
||||||
|
// DefaultBatchQueueConcurrency 批量队列默认并发数(串行)
|
||||||
|
DefaultBatchQueueConcurrency = 1
|
||||||
|
|
||||||
|
// MaxBatchQueueConcurrency 批量队列最大并发数
|
||||||
|
MaxBatchQueueConcurrency = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
// BatchTask 批量任务项
|
// BatchTask 批量任务项
|
||||||
@@ -67,6 +83,7 @@ type BatchTaskQueue struct {
|
|||||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||||
LastRunError string `json:"lastRunError,omitempty"`
|
LastRunError string `json:"lastRunError,omitempty"`
|
||||||
ProjectID string `json:"projectId,omitempty"`
|
ProjectID string `json:"projectId,omitempty"`
|
||||||
|
Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1
|
||||||
Tasks []*BatchTask `json:"tasks"`
|
Tasks []*BatchTask `json:"tasks"`
|
||||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
@@ -80,8 +97,9 @@ type BatchTaskManager struct {
|
|||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
queues map[string]*BatchTaskQueue
|
queues map[string]*BatchTaskQueue
|
||||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数
|
||||||
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
||||||
|
queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦)
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
|
|||||||
return &BatchTaskManager{
|
return &BatchTaskManager{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
queues: make(map[string]*BatchTaskQueue),
|
queues: make(map[string]*BatchTaskQueue),
|
||||||
taskCancels: make(map[string]context.CancelFunc),
|
taskCancels: make(map[string]map[string]context.CancelFunc),
|
||||||
singleRunTasks: make(map[string]string),
|
singleRunTasks: make(map[string]string),
|
||||||
|
queueExecutors: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// batchQueueExecutionShouldStop 判断 executeBatchQueue 主循环是否应退出。
|
||||||
|
func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool {
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch queue.Status {
|
||||||
|
case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryMarkQueueExecutor 标记队列执行协程已启动;若已有执行协程则返回 false。
|
||||||
|
func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if _, exists := m.queueExecutors[queueID]; exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.queueExecutors[queueID] = struct{}{}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarkQueueExecutor 清除队列执行协程标记(executeBatchQueue defer 调用)。
|
||||||
|
func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.queueExecutors, queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForceUnmarkQueueExecutor 强制清除执行协程标记(暂停态单条重跑等场景回收陈旧槽位)。
|
||||||
|
func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) {
|
||||||
|
m.UnmarkQueueExecutor(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsQueueExecutorActive 队列 executeBatchQueue 协程是否仍在运行。
|
||||||
|
func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
_, ok := m.queueExecutors[queueID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// SetDB 设置数据库连接
|
// SetDB 设置数据库连接
|
||||||
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
|||||||
m.db = db
|
m.db = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeBatchQueueConcurrency 规范化队列并发数。
|
||||||
|
func normalizeBatchQueueConcurrency(n int) int {
|
||||||
|
if n < 1 {
|
||||||
|
return DefaultBatchQueueConcurrency
|
||||||
|
}
|
||||||
|
if n > MaxBatchQueueConcurrency {
|
||||||
|
return MaxBatchQueueConcurrency
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
// CreateBatchQueue 创建批量任务队列
|
// CreateBatchQueue 创建批量任务队列
|
||||||
func (m *BatchTaskManager) CreateBatchQueue(
|
func (m *BatchTaskManager) CreateBatchQueue(
|
||||||
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||||
nextRunAt *time.Time,
|
nextRunAt *time.Time,
|
||||||
|
concurrency int,
|
||||||
tasks []string,
|
tasks []string,
|
||||||
) (*BatchTaskQueue, error) {
|
) (*BatchTaskQueue, error) {
|
||||||
// 输入校验
|
// 输入校验
|
||||||
@@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
|||||||
CronExpr: strings.TrimSpace(cronExpr),
|
CronExpr: strings.TrimSpace(cronExpr),
|
||||||
NextRunAt: nextRunAt,
|
NextRunAt: nextRunAt,
|
||||||
ScheduleEnabled: true,
|
ScheduleEnabled: true,
|
||||||
|
Concurrency: normalizeBatchQueueConcurrency(concurrency),
|
||||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||||
Status: BatchQueueStatusPending,
|
Status: BatchQueueStatusPending,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
|||||||
queue.CronExpr,
|
queue.CronExpr,
|
||||||
queue.NextRunAt,
|
queue.NextRunAt,
|
||||||
queue.ProjectID,
|
queue.ProjectID,
|
||||||
|
queue.Concurrency,
|
||||||
dbTasks,
|
dbTasks,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||||
@@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
|||||||
if queueRow.ProjectID.Valid {
|
if queueRow.ProjectID.Valid {
|
||||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||||
}
|
}
|
||||||
|
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||||
if queueRow.StartedAt.Valid {
|
if queueRow.StartedAt.Valid {
|
||||||
queue.StartedAt = &queueRow.StartedAt.Time
|
queue.StartedAt = &queueRow.StartedAt.Time
|
||||||
}
|
}
|
||||||
@@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
|||||||
if queueRow.ProjectID.Valid {
|
if queueRow.ProjectID.Valid {
|
||||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||||
}
|
}
|
||||||
|
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||||
if queueRow.StartedAt.Valid {
|
if queueRow.StartedAt.Valid {
|
||||||
queue.StartedAt = &queueRow.StartedAt.Time
|
queue.StartedAt = &queueRow.StartedAt.Time
|
||||||
}
|
}
|
||||||
@@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用)
|
// batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。
|
||||||
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error {
|
func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int {
|
||||||
|
if row == nil || !row.Concurrency.Valid {
|
||||||
|
return DefaultBatchQueueConcurrency
|
||||||
|
}
|
||||||
|
return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用)
|
||||||
|
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error {
|
||||||
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
|
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
|
||||||
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
|
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
|
||||||
}
|
}
|
||||||
@@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
|
|||||||
queue.Title = title
|
queue.Title = title
|
||||||
queue.Role = role
|
queue.Role = role
|
||||||
queue.AgentMode = agentMode
|
queue.AgentMode = agentMode
|
||||||
|
if concurrency != nil {
|
||||||
|
queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency)
|
||||||
|
}
|
||||||
|
|
||||||
if m.db != nil {
|
if m.db != nil {
|
||||||
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil {
|
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil {
|
||||||
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
|
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
|
|||||||
|
|
||||||
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
||||||
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||||
var cancelFunc context.CancelFunc
|
|
||||||
var siblingRunningIDs []string
|
var siblingRunningIDs []string
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
||||||
|
var cancelFuncs []context.CancelFunc
|
||||||
if queue.Status == BatchQueueStatusPaused {
|
if queue.Status == BatchQueueStatusPaused {
|
||||||
if c, ok := m.taskCancels[queueID]; ok {
|
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||||
cancelFunc = c
|
|
||||||
delete(m.taskCancels, queueID)
|
|
||||||
}
|
|
||||||
for _, t := range queue.Tasks {
|
for _, t := range queue.Tasks {
|
||||||
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
||||||
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
||||||
@@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
|||||||
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
if cancelFunc != nil {
|
for _, c := range cancelFuncs {
|
||||||
cancelFunc()
|
if c != nil {
|
||||||
|
c()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const staleRunMsg = "为单条执行其它任务,已中止"
|
const staleRunMsg = "为单条执行其它任务,已中止"
|
||||||
for _, sid := range siblingRunningIDs {
|
for _, sid := range siblingRunningIDs {
|
||||||
@@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNextTask 获取下一个待执行的任务
|
// ClaimNextPendingTask 原子领取下一个待执行子任务(并发 worker 安全)。
|
||||||
|
func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if queue.Status == BatchQueueStatusCancelled || queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusPaused {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
onlyTaskID := ""
|
||||||
|
if m.singleRunTasks != nil {
|
||||||
|
onlyTaskID = m.singleRunTasks[queueID]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, task := range queue.Tasks {
|
||||||
|
if task == nil || task.Status != BatchTaskStatusPending {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if onlyTaskID != "" && task.ID != onlyTaskID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
task.Status = BatchTaskStatusRunning
|
||||||
|
queue.CurrentIndex = i
|
||||||
|
return task, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRunningTasks 队列是否仍有 running 状态的子任务。
|
||||||
|
func (m *BatchTaskManager) HasRunningTasks(queueID string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, task := range queue.Tasks {
|
||||||
|
if task != nil && task.Status == BatchTaskStatusRunning {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasPendingOrRunningTasks 队列是否仍有未完成的子任务。
|
||||||
|
func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, task := range queue.Tasks {
|
||||||
|
if task == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if task.Status == BatchTaskStatusPending || task.Status == BatchTaskStatusRunning {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainTaskCancelsLocked 取出并清空队列下所有子任务取消函数(调用方须已持 m.mu)。
|
||||||
|
func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc {
|
||||||
|
taskMap, ok := m.taskCancels[queueID]
|
||||||
|
if !ok || len(taskMap) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cancels := make([]context.CancelFunc, 0, len(taskMap))
|
||||||
|
for _, c := range taskMap {
|
||||||
|
if c != nil {
|
||||||
|
cancels = append(cancels, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(m.taskCancels, queueID)
|
||||||
|
return cancels
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask)
|
||||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTaskCancel 设置当前任务的取消函数
|
// SetTaskCancel 设置子任务的取消函数
|
||||||
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) {
|
func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
if cancel != nil {
|
if cancel == nil {
|
||||||
m.taskCancels[queueID] = cancel
|
if taskMap, ok := m.taskCancels[queueID]; ok {
|
||||||
} else {
|
delete(taskMap, taskID)
|
||||||
delete(m.taskCancels, queueID)
|
if len(taskMap) == 0 {
|
||||||
|
delete(m.taskCancels, queueID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if m.taskCancels[queueID] == nil {
|
||||||
|
m.taskCancels[queueID] = make(map[string]context.CancelFunc)
|
||||||
|
}
|
||||||
|
m.taskCancels[queueID][taskID] = cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
// PauseQueue 暂停队列
|
// PauseQueue 暂停队列
|
||||||
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||||
var cancelFunc context.CancelFunc
|
var cancelFuncs []context.CancelFunc
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
queue, exists := m.queues[queueID]
|
queue, exists := m.queues[queueID]
|
||||||
@@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
queue.Status = BatchQueueStatusPaused
|
queue.Status = BatchQueueStatusPaused
|
||||||
|
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||||
// 取消当前正在执行的任务(通过取消context)
|
|
||||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
|
||||||
cancelFunc = cancel
|
|
||||||
delete(m.taskCancels, queueID)
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
for _, c := range cancelFuncs {
|
||||||
if cancelFunc != nil {
|
c()
|
||||||
cancelFunc()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
|||||||
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
||||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
var cancelFunc context.CancelFunc
|
var cancelFuncs []context.CancelFunc
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
queue, exists := m.queues[queueID]
|
queue, exists := m.queues[queueID]
|
||||||
@@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 取消当前正在执行的任务
|
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
|
||||||
cancelFunc = cancel
|
|
||||||
delete(m.taskCancels, queueID)
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
for _, c := range cancelFuncs {
|
||||||
if cancelFunc != nil {
|
c()
|
||||||
cancelFunc()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteQueue 删除队列(运行中的队列不允许删除)
|
// DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。
|
||||||
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
func (m *BatchTaskManager) DeleteQueue(queueID string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
queue, exists := m.queues[queueID]
|
queue, exists := m.queues[queueID]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return ErrBatchQueueNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exec := m.queueExecutors[queueID]; exec {
|
||||||
|
return ErrBatchQueueExecutorActive
|
||||||
}
|
}
|
||||||
|
|
||||||
// 运行中的队列不允许删除,防止孤儿协程和数据丢失
|
// 运行中的队列不允许删除,防止孤儿协程和数据丢失
|
||||||
if queue.Status == BatchQueueStatusRunning {
|
if queue.Status == BatchQueueStatusRunning {
|
||||||
return false
|
return ErrBatchQueueStillRunning
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理取消函数
|
// 清理取消函数
|
||||||
@@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
delete(m.queues, queueID)
|
delete(m.queues, queueID)
|
||||||
return true
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateShortID 生成短ID
|
// generateShortID 生成短ID
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeBatchQueueConcurrency(t *testing.T) {
|
||||||
|
if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency {
|
||||||
|
t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got)
|
||||||
|
}
|
||||||
|
if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency {
|
||||||
|
t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaimNextPendingTaskParallel(t *testing.T) {
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateBatchQueue: %v", err)
|
||||||
|
}
|
||||||
|
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||||
|
|
||||||
|
t1, ok1 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
t2, ok2 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
if !ok1 || !ok2 || t1.ID == t2.ID {
|
||||||
|
t.Fatalf("expected two distinct claims, got ok1=%v ok2=%v t1=%v t2=%v", ok1, ok2, t1, t2)
|
||||||
|
}
|
||||||
|
if t1.Status != BatchTaskStatusRunning || t2.Status != BatchTaskStatusRunning {
|
||||||
|
t.Fatalf("claimed tasks should be running")
|
||||||
|
}
|
||||||
|
t3, ok3 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
if !ok3 {
|
||||||
|
t.Fatal("expected third claim")
|
||||||
|
}
|
||||||
|
_, ok4 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
if ok4 {
|
||||||
|
t.Fatal("expected no fourth pending task")
|
||||||
|
}
|
||||||
|
_ = t3
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchQueueExecutionShouldStop(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !batchQueueExecutionShouldStop(nil, false) {
|
||||||
|
t.Fatal("expected stop when queue missing")
|
||||||
|
}
|
||||||
|
if !batchQueueExecutionShouldStop(nil, true) {
|
||||||
|
t.Fatal("expected stop when queue is nil but exists=true")
|
||||||
|
}
|
||||||
|
q := &BatchTaskQueue{Status: BatchQueueStatusRunning}
|
||||||
|
if batchQueueExecutionShouldStop(q, true) {
|
||||||
|
t.Fatal("expected continue when running")
|
||||||
|
}
|
||||||
|
q.Status = BatchQueueStatusCancelled
|
||||||
|
if !batchQueueExecutionShouldStop(q, true) {
|
||||||
|
t.Fatal("expected stop when cancelled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateBatchQueue: %v", err)
|
||||||
|
}
|
||||||
|
if !m.TryMarkQueueExecutor(queue.ID) {
|
||||||
|
t.Fatal("expected to mark executor")
|
||||||
|
}
|
||||||
|
m.UpdateQueueStatus(queue.ID, BatchQueueStatusCancelled)
|
||||||
|
|
||||||
|
err = m.DeleteQueue(queue.ID)
|
||||||
|
if !errors.Is(err, ErrBatchQueueExecutorActive) {
|
||||||
|
t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := m.GetBatchQueue(queue.ID); !ok {
|
||||||
|
t.Fatal("queue should still exist while executor active")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.UnmarkQueueExecutor(queue.ID)
|
||||||
|
if err := m.DeleteQueue(queue.ID); err != nil {
|
||||||
|
t.Fatalf("expected delete after executor unmarked, got %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := m.GetBatchQueue(queue.ID); ok {
|
||||||
|
t.Fatal("queue should be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteQueueBlockedWhileRunning(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateBatchQueue: %v", err)
|
||||||
|
}
|
||||||
|
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||||
|
|
||||||
|
err = m.DeleteQueue(queue.ID)
|
||||||
|
if !errors.Is(err, ErrBatchQueueStillRunning) {
|
||||||
|
t.Fatalf("expected ErrBatchQueueStillRunning, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryMarkQueueExecutorDedupes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
if !m.TryMarkQueueExecutor("q-1") {
|
||||||
|
t.Fatal("first mark should succeed")
|
||||||
|
}
|
||||||
|
if m.TryMarkQueueExecutor("q-1") {
|
||||||
|
t.Fatal("second mark should fail")
|
||||||
|
}
|
||||||
|
m.UnmarkQueueExecutor("q-1")
|
||||||
|
if !m.TryMarkQueueExecutor("q-1") {
|
||||||
|
t.Fatal("mark after unmark should succeed")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||||
},
|
},
|
||||||
|
"concurrency": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
@@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
executeNow = false
|
executeNow = false
|
||||||
}
|
}
|
||||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
concurrency := int(mcpArgFloat(args, "concurrency"))
|
||||||
|
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks)
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||||
}
|
}
|
||||||
@@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
if qid == "" {
|
if qid == "" {
|
||||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||||
}
|
}
|
||||||
if !h.batchTaskManager.DeleteQueue(qid) {
|
if err := h.batchTaskManager.DeleteQueue(qid); err != nil {
|
||||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
switch {
|
||||||
|
case errors.Is(err, ErrBatchQueueNotFound):
|
||||||
|
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||||
|
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||||
|
return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil
|
||||||
|
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||||
|
return batchMCPTextResult("删除失败:队列正在运行中", true), nil
|
||||||
|
default:
|
||||||
|
return batchMCPTextResult("删除失败:"+err.Error(), true), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
||||||
return batchMCPTextResult("队列已删除。", false), nil
|
return batchMCPTextResult("队列已删除。", false), nil
|
||||||
@@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
||||||
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||||
},
|
},
|
||||||
|
"concurrency": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "同时执行的子任务数,默认 1,最大 8",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": []string{"queue_id"},
|
"required": []string{"queue_id"},
|
||||||
},
|
},
|
||||||
@@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
title := mcpArgString(args, "title")
|
title := mcpArgString(args, "title")
|
||||||
role := mcpArgString(args, "role")
|
role := mcpArgString(args, "role")
|
||||||
agentMode := mcpArgString(args, "agent_mode")
|
agentMode := mcpArgString(args, "agent_mode")
|
||||||
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil {
|
var concurrency *int
|
||||||
|
if raw, ok := args["concurrency"]; ok && raw != nil {
|
||||||
|
v := int(mcpArgFloat(args, "concurrency"))
|
||||||
|
concurrency = &v
|
||||||
|
}
|
||||||
|
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil {
|
||||||
return batchMCPTextResult(err.Error(), true), nil
|
return batchMCPTextResult(err.Error(), true), nil
|
||||||
}
|
}
|
||||||
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||||
@@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct {
|
|||||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||||
CurrentIndex int `json:"currentIndex"`
|
CurrentIndex int `json:"currentIndex"`
|
||||||
|
Concurrency int `json:"concurrency"`
|
||||||
TaskTotal int `json:"task_total"`
|
TaskTotal int `json:"task_total"`
|
||||||
TaskCounts map[string]int `json:"task_counts"`
|
TaskCounts map[string]int `json:"task_counts"`
|
||||||
Tasks []batchTaskMCPListSummary `json:"tasks"`
|
Tasks []batchTaskMCPListSummary `json:"tasks"`
|
||||||
@@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
|
|||||||
StartedAt: q.StartedAt,
|
StartedAt: q.StartedAt,
|
||||||
CompletedAt: q.CompletedAt,
|
CompletedAt: q.CompletedAt,
|
||||||
CurrentIndex: q.CurrentIndex,
|
CurrentIndex: q.CurrentIndex,
|
||||||
|
Concurrency: q.Concurrency,
|
||||||
TaskTotal: len(tasks),
|
TaskTotal: len(tasks),
|
||||||
TaskCounts: counts,
|
TaskCounts: counts,
|
||||||
Tasks: tasks,
|
Tasks: tasks,
|
||||||
|
|||||||
@@ -798,6 +798,10 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
|
|
||||||
// 更新机器人配置
|
// 更新机器人配置
|
||||||
if req.Robots != nil {
|
if req.Robots != nil {
|
||||||
|
if err := config.ValidateWecomConfig(req.Robots.Wecom); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
h.config.Robots = *req.Robots
|
h.config.Robots = *req.Robots
|
||||||
h.logger.Info("更新机器人配置",
|
h.logger.Info("更新机器人配置",
|
||||||
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
|
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
|
||||||
@@ -1329,6 +1333,17 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
|||||||
h.logger.Info("已更新嵌入模型配置记录")
|
h.logger.Info("已更新嵌入模型配置记录")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 从 tools 目录重新加载工具配置(新增/修改/删除 yaml 后无需重启)
|
||||||
|
if err := config.ReloadSecurityToolsFromDir(h.config, h.configPath); err != nil {
|
||||||
|
h.logger.Error("重新加载工具配置失败", zap.Error(err))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新加载工具", map[string]interface{}{"error": err.Error()})
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新加载工具配置失败: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("已从 tools 目录重新加载工具配置", zap.Int("tools_count", len(h.config.Security.Tools)))
|
||||||
|
|
||||||
// 重新注册工具(根据新的启用状态)
|
// 重新注册工具(根据新的启用状态)
|
||||||
h.logger.Info("重新注册工具")
|
h.logger.Info("重新注册工具")
|
||||||
|
|
||||||
@@ -1751,6 +1766,20 @@ func mergeHitlToolWhitelistSlice(existing, add []string) []string {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHitlToolWhitelist 将全局免审批工具白名单整表写入 config.yaml(替换,非合并)。
|
||||||
|
func (h *ConfigHandler) SetHitlToolWhitelist(tools []string) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.config.Hitl.ToolWhitelist = mergeHitlToolWhitelistSlice(nil, tools)
|
||||||
|
if err := h.saveConfig(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.logger.Info("HITL 全局工具白名单已写入配置文件",
|
||||||
|
zap.Int("count", len(h.config.Hitl.ToolWhitelist)),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。
|
// MergeHitlToolWhitelistIntoConfig 将会话侧栏提交的免审批工具名合并进内存配置并写入 config.yaml(与全局白名单去重规则一致:小写键、保留首次出现的原始大小写)。
|
||||||
func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error {
|
func (h *ConfigHandler) MergeHitlToolWhitelistIntoConfig(add []string) error {
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
@@ -1771,6 +1800,21 @@ func updateHitlConfig(doc *yaml.Node, cfg config.HitlConfig) {
|
|||||||
hitlNode := ensureMap(root, "hitl")
|
hitlNode := ensureMap(root, "hitl")
|
||||||
// flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数
|
// flow 样式 [a, b, c] 单行展示,工具多时比块序列省行数
|
||||||
setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist)
|
setFlowStringSliceInMap(hitlNode, "tool_whitelist", cfg.ToolWhitelist)
|
||||||
|
setStringInMap(hitlNode, "audit_agent_prompt", cfg.AuditAgentPrompt)
|
||||||
|
setStringInMap(hitlNode, "audit_agent_prompt_review_edit", cfg.AuditAgentPromptReviewEdit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateHitlAuditAgentStrategy 更新审批/审查编辑两套审计 Agent 提示词并写入 config.yaml。
|
||||||
|
func (h *ConfigHandler) UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.config.Hitl.AuditAgentPrompt = strings.TrimSpace(approvalPrompt)
|
||||||
|
h.config.Hitl.AuditAgentPromptReviewEdit = strings.TrimSpace(reviewEditPrompt)
|
||||||
|
if err := h.saveConfig(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.logger.Info("HITL 审计 Agent 提示词已写入配置文件")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
|||||||
limitStr := c.DefaultQuery("limit", "50")
|
limitStr := c.DefaultQuery("limit", "50")
|
||||||
offsetStr := c.DefaultQuery("offset", "0")
|
offsetStr := c.DefaultQuery("offset", "0")
|
||||||
search := c.Query("search") // 获取搜索参数
|
search := c.Query("search") // 获取搜索参数
|
||||||
|
projectID := strings.TrimSpace(c.Query("project_id"))
|
||||||
|
|
||||||
limit, _ := strconv.Atoi(limitStr)
|
limit, _ := strconv.Atoi(limitStr)
|
||||||
offset, _ := strconv.Atoi(offsetStr)
|
offset, _ := strconv.Atoi(offsetStr)
|
||||||
@@ -114,7 +115,7 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
|||||||
limit = 1000
|
limit = 1000
|
||||||
}
|
}
|
||||||
|
|
||||||
excludeGrouped := strings.TrimSpace(search) == "" &&
|
excludeGrouped := strings.TrimSpace(search) == "" && projectID == "" &&
|
||||||
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
|
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
|
||||||
sortBy := strings.TrimSpace(c.Query("sort_by"))
|
sortBy := strings.TrimSpace(c.Query("sort_by"))
|
||||||
|
|
||||||
@@ -122,14 +123,14 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
|||||||
var total int
|
var total int
|
||||||
var err error
|
var err error
|
||||||
if excludeGrouped {
|
if excludeGrouped {
|
||||||
conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy)
|
conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy, projectID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total, err = h.db.CountUngroupedConversations()
|
total, err = h.db.CountUngroupedConversations(projectID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
conversations, err = h.db.ListConversations(limit, offset, search, sortBy)
|
conversations, err = h.db.ListConversations(limit, offset, search, sortBy, projectID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total, err = h.db.CountConversations(search)
|
total, err = h.db.CountConversations(search, projectID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// rebindEinoRunningTask 中断并继续 / 空正文续跑:重建 cancel 链与超时 ctx,保持任务 running。
|
||||||
|
func (h *AgentHandler) rebindEinoRunningTask(conversationID string, timeoutCancel context.CancelFunc) (context.Context, context.CancelCauseFunc, context.Context, context.CancelFunc) {
|
||||||
|
if timeoutCancel != nil {
|
||||||
|
timeoutCancel()
|
||||||
|
}
|
||||||
|
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||||
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
|
taskCtx, newTimeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
|
return baseCtx, cancelWithCause, taskCtx, newTimeoutCancel
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryContinueOnEinoEmptyResponse Run 成功但 Response 为 emptyHint 时退避续跑;true 表示已准备下一段 Run。
|
||||||
|
func (h *AgentHandler) tryContinueOnEinoEmptyResponse(
|
||||||
|
taskCtx context.Context,
|
||||||
|
mw *config.MultiAgentEinoMiddlewareConfig,
|
||||||
|
conversationID string,
|
||||||
|
result *multiagent.RunResult,
|
||||||
|
attempt *int,
|
||||||
|
curHistory *[]agent.ChatMessage,
|
||||||
|
curFinalMessage *string,
|
||||||
|
progressCallback func(eventType, message string, data interface{}),
|
||||||
|
) bool {
|
||||||
|
if result == nil || !multiagent.IsEinoEmptyResponseResult(result) || !multiagent.HasEinoResumeTrace(result) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
maxAttempts := multiagent.EmptyResponseContinueMaxAttemptsFromConfig(mw)
|
||||||
|
if *attempt >= maxAttempts {
|
||||||
|
if h.logger != nil {
|
||||||
|
h.logger.Warn("eino empty response continue exhausted",
|
||||||
|
zap.String("conversationId", conversationID),
|
||||||
|
zap.Int("maxAttempts", maxAttempts))
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
*attempt++
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
|
||||||
|
backoff := multiagent.EmptyResponseContinueBackoff(*attempt-1, mw)
|
||||||
|
waitMsg := fmt.Sprintf("会话已结束但未捕获到助手正文,%d 秒后第 %d/%d 次自动续跑…",
|
||||||
|
int(backoff.Seconds()), *attempt, maxAttempts)
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback("eino_empty_response_continue", waitMsg, map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"attempt": *attempt,
|
||||||
|
"maxAttempts": maxAttempts,
|
||||||
|
"backoffSec": int(backoff.Seconds()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-taskCtx.Done():
|
||||||
|
return false
|
||||||
|
case <-time.After(backoff):
|
||||||
|
}
|
||||||
|
|
||||||
|
inject := multiagent.FormatEmptyResponseContinueUserMessage()
|
||||||
|
h.applyEinoTraceResumeSegment(conversationID, result, curHistory, curFinalMessage, inject)
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback("eino_empty_response_continue", "已恢复上下文,正在续跑…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"attempt": *attempt,
|
||||||
|
"maxAttempts": maxAttempts,
|
||||||
|
"contextSource": "empty_response_continue",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -178,6 +178,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
var mainIterationOffset int
|
var mainIterationOffset int
|
||||||
|
var emptyResponseContinueAttempt int
|
||||||
|
|
||||||
for {
|
for {
|
||||||
segmentMainIterationMax := 0
|
segmentMainIterationMax := 0
|
||||||
@@ -231,7 +232,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
roleTools,
|
roleTools,
|
||||||
progressCallback,
|
progressCallback,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
h.projectBlackboardBlock(conversationID),
|
h.agentSessionContextBlock(conversationID),
|
||||||
)
|
)
|
||||||
|
|
||||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
@@ -239,6 +240,13 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
|
mw := &h.config.MultiAgent.EinoMiddleware
|
||||||
|
if h.tryContinueOnEinoEmptyResponse(taskCtx, mw, conversationID, result, &emptyResponseContinueAttempt, &curHistory, &curFinalMessage, progressCallback) {
|
||||||
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
timeoutCancel()
|
||||||
|
baseCtx, cancelWithCause, taskCtx, timeoutCancel = h.rebindEinoRunningTask(conversationID, timeoutCancel)
|
||||||
|
continue
|
||||||
|
}
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -416,7 +424,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
prep.RoleTools,
|
prep.RoleTools,
|
||||||
progressCallback,
|
progressCallback,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
h.projectBlackboardBlock(prep.ConversationID),
|
h.agentSessionContextBlock(prep.ConversationID),
|
||||||
)
|
)
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
break
|
break
|
||||||
|
|||||||
+136
-89
@@ -23,6 +23,7 @@ import (
|
|||||||
type hitlRuntimeConfig struct {
|
type hitlRuntimeConfig struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
Mode string
|
Mode string
|
||||||
|
Reviewer string
|
||||||
SensitiveTools map[string]struct{}
|
SensitiveTools map[string]struct{}
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
@@ -49,6 +50,8 @@ type HITLManager struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
runtime map[string]hitlRuntimeConfig
|
runtime map[string]hitlRuntimeConfig
|
||||||
pending map[string]*pendingInterrupt
|
pending map[string]*pendingInterrupt
|
||||||
|
// approvedExec 审批通过、待回写 tool_result 的队列(按会话 FIFO)
|
||||||
|
approvedExec map[string][]hitlApprovedExecTrack
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
|
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
|
||||||
@@ -90,6 +93,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
m.migrateHitlSchemaColumns()
|
||||||
|
|
||||||
// On startup, cancel all orphaned pending interrupts from previous process.
|
// On startup, cancel all orphaned pending interrupts from previous process.
|
||||||
// Their in-memory channels are gone, so they can never be resolved.
|
// Their in-memory channels are gone, so they can never be resolved.
|
||||||
@@ -141,6 +145,7 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque
|
|||||||
m.runtime[conversationID] = hitlRuntimeConfig{
|
m.runtime[conversationID] = hitlRuntimeConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Mode: normalizeHitlMode(req.Mode),
|
Mode: normalizeHitlMode(req.Mode),
|
||||||
|
Reviewer: normalizeHitlReviewer(req.Reviewer),
|
||||||
SensitiveTools: tools,
|
SensitiveTools: tools,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
}
|
}
|
||||||
@@ -153,17 +158,14 @@ func (m *HITLManager) DeactivateConversation(conversationID string) {
|
|||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。
|
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空),并合并内置元工具免审批项。
|
||||||
func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
|
func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
|
||||||
if h == nil || h.config == nil {
|
if h == nil || h.config == nil {
|
||||||
return nil
|
return multiagent.MergeHitlExemptMetaTools(nil)
|
||||||
}
|
}
|
||||||
raw := h.config.Hitl.ToolWhitelist
|
raw := h.config.Hitl.ToolWhitelist
|
||||||
if len(raw) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
seen := make(map[string]struct{})
|
seen := make(map[string]struct{})
|
||||||
out := make([]string, 0, len(raw))
|
out := make([]string, 0, len(raw)+len(multiagent.HitlExemptMetaTools))
|
||||||
for _, t := range raw {
|
for _, t := range raw {
|
||||||
n := strings.ToLower(strings.TrimSpace(t))
|
n := strings.ToLower(strings.TrimSpace(t))
|
||||||
if n == "" {
|
if n == "" {
|
||||||
@@ -175,44 +177,35 @@ func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
|
|||||||
seen[n] = struct{}{}
|
seen[n] = struct{}{}
|
||||||
out = append(out, strings.TrimSpace(t))
|
out = append(out, strings.TrimSpace(t))
|
||||||
}
|
}
|
||||||
return out
|
return multiagent.MergeHitlExemptMetaTools(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。
|
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单及内置元工具免审批项合并(并集),仅用于运行时 Activate;不写入数据库。
|
||||||
func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest {
|
func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest {
|
||||||
gw := h.hitlConfigGlobalToolWhitelist()
|
|
||||||
if len(gw) == 0 {
|
|
||||||
return req
|
|
||||||
}
|
|
||||||
if req == nil {
|
if req == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
seen := make(map[string]struct{})
|
seen := make(map[string]struct{})
|
||||||
union := make([]string, 0, len(gw)+len(req.SensitiveTools))
|
union := make([]string, 0, len(req.SensitiveTools)+16)
|
||||||
for _, t := range gw {
|
add := func(t string) {
|
||||||
n := strings.ToLower(strings.TrimSpace(t))
|
n := strings.ToLower(strings.TrimSpace(t))
|
||||||
if n == "" {
|
if n == "" {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
if _, ok := seen[n]; ok {
|
if _, ok := seen[n]; ok {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
seen[n] = struct{}{}
|
seen[n] = struct{}{}
|
||||||
union = append(union, strings.TrimSpace(t))
|
union = append(union, strings.TrimSpace(t))
|
||||||
}
|
}
|
||||||
|
for _, t := range h.hitlConfigGlobalToolWhitelist() {
|
||||||
|
add(t)
|
||||||
|
}
|
||||||
for _, t := range req.SensitiveTools {
|
for _, t := range req.SensitiveTools {
|
||||||
n := strings.ToLower(strings.TrimSpace(t))
|
add(t)
|
||||||
if n == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok := seen[n]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[n] = struct{}{}
|
|
||||||
union = append(union, strings.TrimSpace(t))
|
|
||||||
}
|
}
|
||||||
out := *req
|
out := *req
|
||||||
out.SensitiveTools = union
|
out.SensitiveTools = multiagent.MergeHitlExemptMetaTools(union)
|
||||||
return &out
|
return &out
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,22 +355,22 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
|
|||||||
timeout = 0
|
timeout = 0
|
||||||
}
|
}
|
||||||
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
|
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
|
||||||
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
|
(conversation_id, enabled, mode, reviewer, sensitive_tools, timeout_seconds, updated_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT(conversation_id) DO UPDATE SET
|
ON CONFLICT(conversation_id) DO UPDATE SET
|
||||||
enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
enabled=excluded.enabled, mode=excluded.mode, reviewer=excluded.reviewer, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
||||||
conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now())
|
conversationID, boolToInt(req.Enabled), mode, normalizeHitlReviewer(req.Reviewer), string(tools), timeout, time.Now())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
|
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
|
||||||
var enabledInt int
|
var enabledInt int
|
||||||
var mode, toolsJSON string
|
var mode, reviewer, toolsJSON string
|
||||||
var timeout int
|
var timeout int
|
||||||
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
err := m.db.QueryRow(`SELECT enabled, mode, COALESCE(reviewer,'human'), sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||||
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
|
Scan(&enabledInt, &mode, &reviewer, &toolsJSON, &timeout)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
return &HITLRequest{Enabled: false, Mode: "off", Reviewer: "human", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -390,6 +383,7 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques
|
|||||||
return &HITLRequest{
|
return &HITLRequest{
|
||||||
Enabled: enabledInt == 1,
|
Enabled: enabledInt == 1,
|
||||||
Mode: mode,
|
Mode: mode,
|
||||||
|
Reviewer: normalizeHitlReviewer(reviewer),
|
||||||
SensitiveTools: tools,
|
SensitiveTools: tools,
|
||||||
TimeoutSeconds: timeout,
|
TimeoutSeconds: timeout,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -413,15 +407,16 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
|
|||||||
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
|
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
|
||||||
d.EditedArguments = nil
|
d.EditedArguments = nil
|
||||||
}
|
}
|
||||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
|
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='human' WHERE id=?`,
|
||||||
d.Decision, d.Comment, time.Now(), p.InterruptID)
|
d.Decision, d.Comment, time.Now(), p.InterruptID)
|
||||||
return d, nil
|
return d, nil
|
||||||
case <-timeoutCh:
|
case <-timeoutCh:
|
||||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
|
comment := "HITL timeout auto-reject for safety"
|
||||||
time.Now(), p.InterruptID)
|
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='reject', decision_comment=?, decided_at=?, decided_by='system' WHERE id=?`,
|
||||||
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
|
comment, time.Now(), p.InterruptID)
|
||||||
|
return hitlDecision{Decision: "reject", Comment: comment}, nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`,
|
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=?, decided_by='system' WHERE id=?`,
|
||||||
time.Now(), p.InterruptID)
|
time.Now(), p.InterruptID)
|
||||||
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
|
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
|
||||||
}
|
}
|
||||||
@@ -445,12 +440,57 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
|
|||||||
if !need {
|
if !need {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
h.enrichHitlApprovalPayload(conversationID, assistantMessageID, payload)
|
||||||
payloadRaw, _ := json.Marshal(payload)
|
payloadRaw, _ := json.Marshal(payload)
|
||||||
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
|
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
|
h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.Reviewer == "audit_agent" {
|
||||||
|
ad := h.auditAgentReview(runCtx, cfg.Mode, toolName, payload)
|
||||||
|
now := time.Now()
|
||||||
|
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=?, decided_by='audit_agent' WHERE id=?`,
|
||||||
|
ad.Decision, ad.Comment, now, p.InterruptID)
|
||||||
|
if sendEventFunc != nil {
|
||||||
|
sendEventFunc("hitl_audit_agent", "审计 Agent 已裁决", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"interruptId": p.InterruptID,
|
||||||
|
"toolName": toolName,
|
||||||
|
"mode": cfg.Mode,
|
||||||
|
"decision": ad.Decision,
|
||||||
|
"comment": ad.Comment,
|
||||||
|
"editedArgs": ad.EditedArguments,
|
||||||
|
"decidedBy": "audit_agent",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if ad.Decision == "reject" {
|
||||||
|
if sendEventFunc != nil {
|
||||||
|
sendEventFunc("hitl_rejected", "审计 Agent 拒绝本次工具调用", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"interruptId": p.InterruptID,
|
||||||
|
"toolName": toolName,
|
||||||
|
"comment": ad.Comment,
|
||||||
|
"decidedBy": "audit_agent",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &ad, nil
|
||||||
|
}
|
||||||
|
if sendEventFunc != nil {
|
||||||
|
sendEventFunc("hitl_resumed", "审计 Agent 已通过,继续执行", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"interruptId": p.InterruptID,
|
||||||
|
"toolName": toolName,
|
||||||
|
"comment": ad.Comment,
|
||||||
|
"editedArgs": ad.EditedArguments,
|
||||||
|
"decidedBy": "audit_agent",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
|
||||||
|
return &ad, nil
|
||||||
|
}
|
||||||
|
|
||||||
if sendEventFunc != nil {
|
if sendEventFunc != nil {
|
||||||
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
|
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
@@ -479,8 +519,12 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
|
|||||||
return nil, waitErr
|
return nil, waitErr
|
||||||
}
|
}
|
||||||
if d.Decision == "reject" {
|
if d.Decision == "reject" {
|
||||||
|
rejectMsg := "人工拒绝本次工具调用,模型将基于反馈继续迭代"
|
||||||
|
if strings.Contains(strings.ToLower(strings.TrimSpace(d.Comment)), "timeout") {
|
||||||
|
rejectMsg = "审批超时,安全起见已自动拒绝,模型将基于反馈继续迭代"
|
||||||
|
}
|
||||||
if sendEventFunc != nil {
|
if sendEventFunc != nil {
|
||||||
sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{
|
sendEventFunc("hitl_rejected", rejectMsg, map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"interruptId": p.InterruptID,
|
"interruptId": p.InterruptID,
|
||||||
"toolName": toolName,
|
"toolName": toolName,
|
||||||
@@ -498,6 +542,7 @@ func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun contex
|
|||||||
"editedArgs": d.EditedArguments,
|
"editedArgs": d.EditedArguments,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
h.hitlManager.TrackApprovedHitlExecution(p.InterruptID, conversationID, toolName, toolCallID)
|
||||||
return &d, nil
|
return &d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -527,11 +572,6 @@ func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun cont
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
|
||||||
status := strings.TrimSpace(c.Query("status"))
|
|
||||||
if status == "" {
|
|
||||||
status = "pending"
|
|
||||||
}
|
|
||||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
if page < 1 {
|
if page < 1 {
|
||||||
page = 1
|
page = 1
|
||||||
@@ -539,15 +579,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
|||||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||||
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1`
|
q, args := h.buildHitlListQuery(false)
|
||||||
args := []interface{}{}
|
q, args = h.appendHitlListFilters(q, args, c)
|
||||||
if conversationID != "" {
|
total, err := h.countHitlQuery(q, args)
|
||||||
q += " AND conversation_id = ?"
|
if err != nil {
|
||||||
args = append(args, conversationID)
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
}
|
return
|
||||||
if status != "all" {
|
|
||||||
q += " AND status = ?"
|
|
||||||
args = append(args, status)
|
|
||||||
}
|
}
|
||||||
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||||
args = append(args, pageSize, offset)
|
args = append(args, pageSize, offset)
|
||||||
@@ -557,41 +594,12 @@ func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
items := make([]map[string]interface{}, 0)
|
items, err := h.scanHitlInterruptRows(rows)
|
||||||
for rows.Next() {
|
if err != nil {
|
||||||
var id, cid, mode, toolName, toolCallID, payload, rowStatus string
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
var messageID sql.NullString
|
return
|
||||||
var decision, comment sql.NullString
|
|
||||||
var createdAt time.Time
|
|
||||||
var decidedAt sql.NullTime
|
|
||||||
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
msgID := ""
|
|
||||||
if messageID.Valid {
|
|
||||||
msgID = messageID.String
|
|
||||||
}
|
|
||||||
items = append(items, map[string]interface{}{
|
|
||||||
"id": id,
|
|
||||||
"conversationId": cid,
|
|
||||||
"messageId": msgID,
|
|
||||||
"mode": mode,
|
|
||||||
"toolName": toolName,
|
|
||||||
"toolCallId": toolCallID,
|
|
||||||
"payload": payload,
|
|
||||||
"status": rowStatus,
|
|
||||||
"decision": decision.String,
|
|
||||||
"comment": comment.String,
|
|
||||||
"createdAt": createdAt,
|
|
||||||
"decidedAt": func() interface{} {
|
|
||||||
if decidedAt.Valid {
|
|
||||||
return decidedAt.Time
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}(),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize})
|
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total})
|
||||||
}
|
}
|
||||||
|
|
||||||
type hitlDecisionReq struct {
|
type hitlDecisionReq struct {
|
||||||
@@ -636,7 +644,7 @@ func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
|
res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
|
||||||
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP
|
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP, decided_by='human'
|
||||||
WHERE id=? AND status='pending'`, req.InterruptID)
|
WHERE id=? AND status='pending'`, req.InterruptID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{"error": err.Error()})
|
c.JSON(500, gin.H{"error": err.Error()})
|
||||||
@@ -732,6 +740,7 @@ func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Mode = normalizeHitlMode(req.Mode)
|
req.Mode = normalizeHitlMode(req.Mode)
|
||||||
|
req.Reviewer = normalizeHitlReviewer(req.Reviewer)
|
||||||
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
|
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -753,6 +762,44 @@ type mergeHitlGlobalWhitelistReq struct {
|
|||||||
SensitiveTools []string `json:"sensitiveTools"`
|
SensitiveTools []string `json:"sensitiveTools"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type setHitlGlobalWhitelistReq struct {
|
||||||
|
ToolWhitelist []string `json:"toolWhitelist"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHITLGlobalToolWhitelist 返回 config.yaml 中的全局免审批工具白名单。
|
||||||
|
func (h *AgentHandler) GetHITLGlobalToolWhitelist(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHITLGlobalToolWhitelist 整表替换 config.yaml 中的全局免审批工具白名单。
|
||||||
|
func (h *AgentHandler) SetHITLGlobalToolWhitelist(c *gin.Context) {
|
||||||
|
if h.hitlWhitelistSaver == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req setHitlGlobalWhitelistReq
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.hitlWhitelistSaver.SetHitlToolWhitelist(req.ToolWhitelist); err != nil {
|
||||||
|
h.logger.Warn("写入 HITL 工具白名单到 config.yaml 失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "hitl", "tool_whitelist_update", "HITL 全局白名单更新", "hitl_config", "tool_whitelist", nil)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"ok": true,
|
||||||
|
"toolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||||
|
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||||
|
"hitlGlobalWhitelistMerged": false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
|
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
|
||||||
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
|
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
|
||||||
if h.hitlWhitelistSaver == nil {
|
if h.hitlWhitelistSaver == nil {
|
||||||
|
|||||||
@@ -0,0 +1,357 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// auditAgentReview 在 reviewer=audit_agent 时由 LLM 代行审批。
|
||||||
|
// 白名单工具在 shouldInterrupt 阶段已跳过,到达此处的一律需要裁决。
|
||||||
|
func (h *AgentHandler) auditAgentReview(ctx context.Context, hitlMode, toolName string, payload map[string]interface{}) hitlDecision {
|
||||||
|
if h == nil {
|
||||||
|
return hitlDecision{Decision: "reject", Comment: "audit agent: handler unavailable"}
|
||||||
|
}
|
||||||
|
mode := normalizeHitlMode(hitlMode)
|
||||||
|
prompt := config.DefaultHitlAuditAgentPrompt()
|
||||||
|
if h.config != nil {
|
||||||
|
prompt = h.config.Hitl.EffectiveAuditAgentPromptForMode(mode)
|
||||||
|
}
|
||||||
|
if h.auditLLM == nil {
|
||||||
|
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 未配置"}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
callCtx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
userContent := buildAuditAgentReviewInput(mode, toolName, payload)
|
||||||
|
requestBody := map[string]interface{}{
|
||||||
|
"model": h.auditLLMModel(),
|
||||||
|
"messages": []map[string]interface{}{
|
||||||
|
{"role": "system", "content": prompt},
|
||||||
|
{"role": "user", "content": userContent},
|
||||||
|
},
|
||||||
|
"temperature": 0.1,
|
||||||
|
"max_completion_tokens": 1024,
|
||||||
|
// 审计裁决需要结构化 JSON;关闭 thinking 避免 Qwen 等把正文放进 reasoning_content 导致解析失败。
|
||||||
|
"thinking": map[string]interface{}{"type": "disabled"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiResponse struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
ReasoningContent string `json:"reasoning_content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
if err := h.auditLLM.ChatCompletion(callCtx, requestBody, &apiResponse); err != nil {
|
||||||
|
h.logger.Warn("审计 Agent LLM 调用失败", zap.Error(err), zap.String("tool", toolName))
|
||||||
|
return hitlDecision{
|
||||||
|
Decision: "reject",
|
||||||
|
Comment: "audit agent: LLM 调用失败,保守拒绝",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(apiResponse.Choices) == 0 {
|
||||||
|
return hitlDecision{Decision: "reject", Comment: "audit agent: LLM 无有效响应,保守拒绝"}
|
||||||
|
}
|
||||||
|
msg := apiResponse.Choices[0].Message
|
||||||
|
raw := strings.TrimSpace(msg.Content)
|
||||||
|
if raw == "" {
|
||||||
|
raw = strings.TrimSpace(msg.ReasoningContent)
|
||||||
|
}
|
||||||
|
dec, err := parseAuditAgentLLMContent(raw)
|
||||||
|
if err != nil {
|
||||||
|
snippet := raw
|
||||||
|
if len(snippet) > 240 {
|
||||||
|
snippet = snippet[:240] + "..."
|
||||||
|
}
|
||||||
|
h.logger.Warn("审计 Agent 响应解析失败",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("tool", toolName),
|
||||||
|
zap.String("mode", mode),
|
||||||
|
zap.String("snippet", snippet),
|
||||||
|
)
|
||||||
|
return hitlDecision{Decision: "reject", Comment: "audit agent: 响应无法解析,保守拒绝"}
|
||||||
|
}
|
||||||
|
if mode != "review_edit" && len(dec.EditedArguments) > 0 {
|
||||||
|
h.logger.Warn("审计 Agent 在审批模式下返回 editedArguments,已忽略",
|
||||||
|
zap.String("tool", toolName),
|
||||||
|
)
|
||||||
|
dec.EditedArguments = nil
|
||||||
|
}
|
||||||
|
if dec.Comment == "" {
|
||||||
|
dec.Comment = "audit agent: " + dec.Decision
|
||||||
|
} else if !strings.HasPrefix(strings.ToLower(dec.Comment), "audit agent") {
|
||||||
|
dec.Comment = "audit agent: " + dec.Comment
|
||||||
|
}
|
||||||
|
return dec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) auditLLMModel() string {
|
||||||
|
if h.config != nil && strings.TrimSpace(h.config.OpenAI.Model) != "" {
|
||||||
|
return strings.TrimSpace(h.config.OpenAI.Model)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAuditAgentReviewInput(hitlMode, toolName string, payload map[string]interface{}) string {
|
||||||
|
review := map[string]interface{}{
|
||||||
|
"hitlMode": normalizeHitlMode(hitlMode),
|
||||||
|
"toolName": strings.TrimSpace(toolName),
|
||||||
|
}
|
||||||
|
if payload != nil {
|
||||||
|
for _, k := range []string{"arguments", "argumentsObj", "command", hitlPayloadUserMessage, hitlPayloadThinking, hitlPayloadReasoningChain, hitlPayloadPlanning} {
|
||||||
|
if v, ok := payload[k]; ok && v != nil && fmt.Sprint(v) != "" {
|
||||||
|
review[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b, err := json.MarshalIndent(review, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf(`{"hitlMode":%q,"toolName":%q}`, normalizeHitlMode(hitlMode), toolName)
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAuditAgentLLMContent(content string) (hitlDecision, error) {
|
||||||
|
s := strings.TrimSpace(content)
|
||||||
|
if s == "" {
|
||||||
|
return hitlDecision{}, errors.New("empty content")
|
||||||
|
}
|
||||||
|
for _, candidate := range auditAgentJSONCandidates(s) {
|
||||||
|
dec, comment, editedArgs, err := parseAuditAgentDecisionObject(candidate)
|
||||||
|
if err == nil {
|
||||||
|
return hitlDecision{
|
||||||
|
Decision: dec,
|
||||||
|
Comment: comment,
|
||||||
|
EditedArguments: editedArgs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hitlDecision{}, fmt.Errorf("no valid decision json in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func auditAgentJSONCandidates(s string) []string {
|
||||||
|
out := make([]string, 0, 4)
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
add := func(c string) {
|
||||||
|
c = strings.TrimSpace(c)
|
||||||
|
if c == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[c]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[c] = struct{}{}
|
||||||
|
out = append(out, c)
|
||||||
|
}
|
||||||
|
add(s)
|
||||||
|
add(stripMarkdownCodeFence(s))
|
||||||
|
if obj := extractFirstJSONObject(s); obj != "" {
|
||||||
|
add(obj)
|
||||||
|
}
|
||||||
|
if obj := extractFirstJSONObject(stripMarkdownCodeFence(s)); obj != "" {
|
||||||
|
add(obj)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripMarkdownCodeFence(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
for _, fence := range []string{"```json", "```JSON", "```"} {
|
||||||
|
if strings.HasPrefix(s, fence) {
|
||||||
|
s = strings.TrimPrefix(s, fence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s = strings.TrimSuffix(s, "```")
|
||||||
|
return strings.TrimSpace(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractFirstJSONObject(s string) string {
|
||||||
|
start := strings.Index(s, "{")
|
||||||
|
if start < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
depth := 0
|
||||||
|
inStr := false
|
||||||
|
esc := false
|
||||||
|
for i := start; i < len(s); i++ {
|
||||||
|
ch := s[i]
|
||||||
|
if inStr {
|
||||||
|
if esc {
|
||||||
|
esc = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '\\' {
|
||||||
|
esc = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ch == '"' {
|
||||||
|
inStr = false
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch ch {
|
||||||
|
case '"':
|
||||||
|
inStr = true
|
||||||
|
case '{':
|
||||||
|
depth++
|
||||||
|
case '}':
|
||||||
|
depth--
|
||||||
|
if depth == 0 {
|
||||||
|
return s[start : i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAuditAgentDecisionObject(jsonText string) (decision, comment string, editedArgs map[string]interface{}, err error) {
|
||||||
|
var parsed map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(jsonText), &parsed); err != nil {
|
||||||
|
return "", "", nil, err
|
||||||
|
}
|
||||||
|
rawDecision := auditAgentPickString(parsed, "decision", "Decision", "result", "action", "verdict", "决策", "决定")
|
||||||
|
decision = normalizeAuditAgentDecision(rawDecision)
|
||||||
|
if decision == "" {
|
||||||
|
return "", "", nil, fmt.Errorf("missing decision")
|
||||||
|
}
|
||||||
|
comment = auditAgentPickString(parsed, "comment", "Comment", "reason", "message", "rationale", "备注", "理由", "说明")
|
||||||
|
editedArgs = auditAgentPickObject(parsed, "editedArguments", "edited_arguments", "editedArgs")
|
||||||
|
return decision, strings.TrimSpace(comment), editedArgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func auditAgentPickString(m map[string]interface{}, keys ...string) string {
|
||||||
|
for _, k := range keys {
|
||||||
|
if v, ok := m[k]; ok && v != nil {
|
||||||
|
s := strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
if s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func auditAgentPickObject(m map[string]interface{}, keys ...string) map[string]interface{} {
|
||||||
|
for _, k := range keys {
|
||||||
|
v, ok := m[k]
|
||||||
|
if !ok || v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch t := v.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
if len(t) > 0 {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
s := strings.TrimSpace(t)
|
||||||
|
if s == "" || s == "{}" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var obj map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(s), &obj); err == nil && len(obj) > 0 {
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAuditAgentDecision(v string) string {
|
||||||
|
d := strings.ToLower(strings.TrimSpace(v))
|
||||||
|
switch d {
|
||||||
|
case "approve", "approved", "pass", "passed", "allow", "allowed", "yes", "ok", "accept", "accepted":
|
||||||
|
return "approve"
|
||||||
|
case "reject", "rejected", "deny", "denied", "no", "block", "blocked", "refuse", "refused":
|
||||||
|
return "reject"
|
||||||
|
}
|
||||||
|
switch strings.TrimSpace(v) {
|
||||||
|
case "通过", "批准", "允许", "同意", "放行":
|
||||||
|
return "approve"
|
||||||
|
case "拒绝", "驳回", "禁止", "否决":
|
||||||
|
return "reject"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type hitlAuditStrategyReq struct {
|
||||||
|
AuditAgentPrompt string `json:"auditAgentPrompt"`
|
||||||
|
AuditAgentPromptReviewEdit string `json:"auditAgentPromptReviewEdit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) GetHITLAuditStrategy(c *gin.Context) {
|
||||||
|
approvalPrompt := config.DefaultHitlAuditAgentPrompt()
|
||||||
|
reviewEditPrompt := config.DefaultHitlAuditAgentPromptReviewEdit()
|
||||||
|
approvalCustom := false
|
||||||
|
reviewEditCustom := false
|
||||||
|
if h.config != nil {
|
||||||
|
approvalPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("approval")
|
||||||
|
reviewEditPrompt = h.config.Hitl.EffectiveAuditAgentPromptForMode("review_edit")
|
||||||
|
approvalCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPrompt) != ""
|
||||||
|
reviewEditCustom = strings.TrimSpace(h.config.Hitl.AuditAgentPromptReviewEdit) != ""
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"auditAgentPrompt": approvalPrompt,
|
||||||
|
"auditAgentPromptCustom": approvalCustom,
|
||||||
|
"auditAgentPromptReviewEdit": reviewEditPrompt,
|
||||||
|
"auditAgentPromptReviewEditCustom": reviewEditCustom,
|
||||||
|
"defaultAuditAgentPrompt": config.DefaultHitlAuditAgentPrompt(),
|
||||||
|
"defaultAuditAgentPromptReviewEdit": config.DefaultHitlAuditAgentPromptReviewEdit(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) UpdateHITLAuditStrategy(c *gin.Context) {
|
||||||
|
if h.hitlStrategySaver == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 策略持久化不可用"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var req hitlAuditStrategyReq
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
approvalPrompt := strings.TrimSpace(req.AuditAgentPrompt)
|
||||||
|
reviewEditPrompt := strings.TrimSpace(req.AuditAgentPromptReviewEdit)
|
||||||
|
if err := h.hitlStrategySaver.UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt); err != nil {
|
||||||
|
h.logger.Warn("保存审计 Agent 提示词失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "hitl", "audit_strategy_update", "HITL 审计策略更新", "hitl_config", "audit_agent_prompt", nil)
|
||||||
|
}
|
||||||
|
if h.config != nil {
|
||||||
|
h.config.Hitl.AuditAgentPrompt = approvalPrompt
|
||||||
|
h.config.Hitl.AuditAgentPromptReviewEdit = reviewEditPrompt
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"ok": true,
|
||||||
|
"auditAgentPrompt": config.HitlConfig{AuditAgentPrompt: approvalPrompt}.EffectiveAuditAgentPromptForMode("approval"),
|
||||||
|
"auditAgentPromptCustom": approvalPrompt != "",
|
||||||
|
"auditAgentPromptReviewEdit": config.HitlConfig{AuditAgentPromptReviewEdit: reviewEditPrompt}.EffectiveAuditAgentPromptForMode("review_edit"),
|
||||||
|
"auditAgentPromptReviewEditCustom": reviewEditPrompt != "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HitlAuditStrategySaver 持久化审计 Agent 提示词到 config.yaml。
|
||||||
|
type HitlAuditStrategySaver interface {
|
||||||
|
UpdateHitlAuditAgentStrategy(approvalPrompt, reviewEditPrompt string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHitlAuditStrategySaver 设置审计策略落盘。
|
||||||
|
func (h *AgentHandler) SetHitlAuditStrategySaver(s HitlAuditStrategySaver) {
|
||||||
|
h.hitlStrategySaver = s
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseAuditAgentLLMContentApprove(t *testing.T) {
|
||||||
|
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"与任务一致"}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if d.Decision != "approve" || d.Comment != "与任务一致" {
|
||||||
|
t.Fatalf("unexpected %+v", d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAuditAgentLLMContentReject(t *testing.T) {
|
||||||
|
d, err := parseAuditAgentLLMContent("```json\n{\"decision\":\"reject\",\"comment\":\"风险过高\"}\n```")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if d.Decision != "reject" {
|
||||||
|
t.Fatalf("expected reject, got %s", d.Decision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAuditAgentLLMContentInvalid(t *testing.T) {
|
||||||
|
_, err := parseAuditAgentLLMContent(`{"decision":"maybe"}`)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid decision")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAuditAgentLLMContentProseWrapped(t *testing.T) {
|
||||||
|
d, err := parseAuditAgentLLMContent("好的,裁决如下:\n```json\n{\"decision\":\"approve\",\"comment\":\"只读 ls\"}\n```\n以上。")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if d.Decision != "approve" {
|
||||||
|
t.Fatalf("expected approve, got %s", d.Decision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAuditAgentLLMContentChineseDecision(t *testing.T) {
|
||||||
|
d, err := parseAuditAgentLLMContent(`{"decision":"通过","comment":"风险低"}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if d.Decision != "approve" {
|
||||||
|
t.Fatalf("expected approve, got %s", d.Decision)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseAuditAgentLLMContentWithEditedArguments(t *testing.T) {
|
||||||
|
d, err := parseAuditAgentLLMContent(`{"decision":"approve","comment":"收窄路径","editedArguments":{"path":"/safe"}}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if d.Decision != "approve" {
|
||||||
|
t.Fatalf("expected approve, got %s", d.Decision)
|
||||||
|
}
|
||||||
|
if d.EditedArguments == nil || d.EditedArguments["path"] != "/safe" {
|
||||||
|
t.Fatalf("unexpected edited args: %+v", d.EditedArguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuditAgentReviewInputIncludesMode(t *testing.T) {
|
||||||
|
s := buildAuditAgentReviewInput("review_edit", "execute", map[string]interface{}{
|
||||||
|
"arguments": `{"command":"pwd"}`,
|
||||||
|
})
|
||||||
|
if !strings.Contains(s, "review_edit") || !strings.Contains(s, "execute") {
|
||||||
|
t.Fatalf("unexpected input: %s", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAuditAgentReviewInput(t *testing.T) {
|
||||||
|
s := buildAuditAgentReviewInput("approval", "nmap", map[string]interface{}{
|
||||||
|
"arguments": `{"target":"10.0.0.1"}`,
|
||||||
|
"userMessage": "扫描内网",
|
||||||
|
})
|
||||||
|
if s == "" {
|
||||||
|
t.Fatal("expected non-empty input")
|
||||||
|
}
|
||||||
|
if !strings.Contains(s, "nmap") || !strings.Contains(s, "10.0.0.1") || !strings.Contains(s, "扫描内网") {
|
||||||
|
t.Fatalf("unexpected input: %s", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type hitlCognitionState struct {
|
||||||
|
AssistantMessageID string
|
||||||
|
UserMessage string
|
||||||
|
Thinking string
|
||||||
|
ReasoningChain string
|
||||||
|
Planning string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHitlCognition 返回当前运行任务上缓存的本轮 HITL 上下文(不含会话历史)。
|
||||||
|
func (m *AgentTaskManager) GetHitlCognition(conversationID string) hitlCognitionFields {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if m == nil || conversationID == "" {
|
||||||
|
return hitlCognitionFields{}
|
||||||
|
}
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
t, ok := m.tasks[conversationID]
|
||||||
|
if !ok || t == nil || t.hitlCognition == nil {
|
||||||
|
return hitlCognitionFields{}
|
||||||
|
}
|
||||||
|
c := t.hitlCognition
|
||||||
|
return hitlCognitionFields{
|
||||||
|
UserMessage: c.UserMessage,
|
||||||
|
Thinking: c.Thinking,
|
||||||
|
ReasoningChain: c.ReasoningChain,
|
||||||
|
Planning: c.Planning,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetHitlCognition 新任务开始时重置本轮 HITL 上下文。
|
||||||
|
func (m *AgentTaskManager) ResetHitlCognition(conversationID, userMessage string) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if m == nil || conversationID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
t, ok := m.tasks[conversationID]
|
||||||
|
if !ok || t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(userMessage)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHitlAssistantMessageID 记录当前助手消息 ID,供 HITL 与 DB 回退对齐。
|
||||||
|
func (m *AgentTaskManager) SetHitlAssistantMessageID(conversationID, assistantMessageID string) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
assistantMessageID = strings.TrimSpace(assistantMessageID)
|
||||||
|
if m == nil || conversationID == "" || assistantMessageID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
t, ok := m.tasks[conversationID]
|
||||||
|
if !ok || t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if t.hitlCognition == nil {
|
||||||
|
t.hitlCognition = &hitlCognitionState{}
|
||||||
|
}
|
||||||
|
t.hitlCognition.AssistantMessageID = assistantMessageID
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateHitlCognitionSnapshot 从进行中的进度流快照更新 thinking / reasoning / planning。
|
||||||
|
func (m *AgentTaskManager) UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoningChain, planning string) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if m == nil || conversationID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
t, ok := m.tasks[conversationID]
|
||||||
|
if !ok || t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if t.hitlCognition == nil {
|
||||||
|
t.hitlCognition = &hitlCognitionState{}
|
||||||
|
}
|
||||||
|
if id := strings.TrimSpace(assistantMessageID); id != "" {
|
||||||
|
t.hitlCognition.AssistantMessageID = id
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(thinking); s != "" {
|
||||||
|
t.hitlCognition.Thinking = s
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(reasoningChain); s != "" {
|
||||||
|
t.hitlCognition.ReasoningChain = s
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(planning); s != "" {
|
||||||
|
t.hitlCognition.Planning = s
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
hitlPayloadUserMessage = "userMessage"
|
||||||
|
hitlPayloadThinking = "thinking"
|
||||||
|
hitlPayloadReasoningChain = "reasoningChain"
|
||||||
|
hitlPayloadPlanning = "planning"
|
||||||
|
)
|
||||||
|
|
||||||
|
type hitlCognitionFields struct {
|
||||||
|
UserMessage string
|
||||||
|
Thinking string
|
||||||
|
ReasoningChain string
|
||||||
|
Planning string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) enrichHitlApprovalPayload(conversationID, assistantMessageID string, payload map[string]interface{}) {
|
||||||
|
if h == nil || payload == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cog := h.collectHitlCognition(conversationID, assistantMessageID)
|
||||||
|
if s := strings.TrimSpace(cog.UserMessage); s != "" {
|
||||||
|
payload[hitlPayloadUserMessage] = s
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(cog.Thinking); s != "" {
|
||||||
|
payload[hitlPayloadThinking] = s
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(cog.ReasoningChain); s != "" {
|
||||||
|
payload[hitlPayloadReasoningChain] = s
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(cog.Planning); s != "" {
|
||||||
|
payload[hitlPayloadPlanning] = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) collectHitlCognition(conversationID, assistantMessageID string) hitlCognitionFields {
|
||||||
|
var out hitlCognitionFields
|
||||||
|
if h.tasks != nil {
|
||||||
|
out = h.tasks.GetHitlCognition(conversationID)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(out.UserMessage) == "" && h.db != nil {
|
||||||
|
if msg, err := h.db.GetTurnUserMessage(conversationID, assistantMessageID); err == nil {
|
||||||
|
out.UserMessage = msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.db != nil && assistantMessageID != "" {
|
||||||
|
dbCog, err := h.db.GetAssistantCognitionTexts(assistantMessageID)
|
||||||
|
if err == nil {
|
||||||
|
if strings.TrimSpace(out.Thinking) == "" {
|
||||||
|
out.Thinking = dbCog.Thinking
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(out.ReasoningChain) == "" {
|
||||||
|
out.ReasoningChain = dbCog.ReasoningChain
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(out.Planning) == "" {
|
||||||
|
out.Planning = dbCog.Planning
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotHitlCognitionFromStreams(thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) (thinking, reasoningChain, planning string) {
|
||||||
|
if len(thinkingStreams) > 0 {
|
||||||
|
var thinkingParts, reasoningParts []string
|
||||||
|
for _, tb := range thinkingStreams {
|
||||||
|
if tb == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(tb.b.String())
|
||||||
|
if content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tb.persistAs == "reasoning_chain" {
|
||||||
|
reasoningParts = append(reasoningParts, content)
|
||||||
|
} else {
|
||||||
|
thinkingParts = append(thinkingParts, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
thinking = strings.Join(thinkingParts, "\n\n")
|
||||||
|
reasoningChain = strings.Join(reasoningParts, "\n\n")
|
||||||
|
}
|
||||||
|
if respPlan != nil {
|
||||||
|
planning = strings.TrimSpace(respPlan.b.String())
|
||||||
|
}
|
||||||
|
return thinking, reasoningChain, planning
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) syncHitlCognitionFromProgress(conversationID, assistantMessageID string, thinkingStreams map[string]*thinkingBuf, respPlan *responsePlanAgg) {
|
||||||
|
if h == nil || h.tasks == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
thinking, reasoning, planning := snapshotHitlCognitionFromStreams(thinkingStreams, respPlan)
|
||||||
|
if thinking == "" && reasoning == "" && planning == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.tasks.UpdateHitlCognitionSnapshot(conversationID, assistantMessageID, thinking, reasoning, planning)
|
||||||
|
}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnrichHitlApprovalPayload(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmp)
|
||||||
|
|
||||||
|
conv, err := db.CreateConversation("hitl ctx", database.ConversationCreateMeta{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conv: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.AddMessage(conv.ID, "user", "scan 10.0.0.1 please", nil); err != nil {
|
||||||
|
t.Fatalf("user msg: %v", err)
|
||||||
|
}
|
||||||
|
asst, err := db.AddMessage(conv.ID, "assistant", "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("asst msg: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.AddProcessDetail(asst.ID, conv.ID, "thinking", "need port scan first", nil); err != nil {
|
||||||
|
t.Fatalf("detail: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &AgentHandler{db: db, tasks: NewAgentTaskManager()}
|
||||||
|
payload := map[string]interface{}{"toolName": "nmap", "arguments": "{}"}
|
||||||
|
h.enrichHitlApprovalPayload(conv.ID, asst.ID, payload)
|
||||||
|
|
||||||
|
if got := payload["userMessage"]; got != "scan 10.0.0.1 please" {
|
||||||
|
t.Fatalf("userMessage=%v", got)
|
||||||
|
}
|
||||||
|
if got := payload["thinking"]; got != "need port scan first" {
|
||||||
|
t.Fatalf("thinking=%v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,132 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const hitlPayloadExecutionResult = "executionResult"
|
||||||
|
|
||||||
|
type hitlExecutionResult struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Result string `json:"result,omitempty"`
|
||||||
|
ToolName string `json:"toolName,omitempty"`
|
||||||
|
ToolCallID string `json:"toolCallId,omitempty"`
|
||||||
|
RecordedAt time.Time `json:"recordedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type hitlApprovedExecTrack struct {
|
||||||
|
InterruptID string
|
||||||
|
ConversationID string
|
||||||
|
ToolName string
|
||||||
|
ToolCallID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackApprovedHitlExecution 审批通过后登记,待 tool_result 回写执行结果。
|
||||||
|
func (m *HITLManager) TrackApprovedHitlExecution(interruptID, conversationID, toolName, toolCallID string) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
interruptID = strings.TrimSpace(interruptID)
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if interruptID == "" || conversationID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if m.approvedExec == nil {
|
||||||
|
m.approvedExec = make(map[string][]hitlApprovedExecTrack)
|
||||||
|
}
|
||||||
|
m.approvedExec[conversationID] = append(m.approvedExec[conversationID], hitlApprovedExecTrack{
|
||||||
|
InterruptID: interruptID,
|
||||||
|
ConversationID: conversationID,
|
||||||
|
ToolName: strings.TrimSpace(toolName),
|
||||||
|
ToolCallID: strings.TrimSpace(toolCallID),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *HITLManager) popApprovedInterruptForTool(conversationID, toolCallID, toolName string) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
toolCallID = strings.TrimSpace(toolCallID)
|
||||||
|
toolName = strings.TrimSpace(toolName)
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
queue := m.approvedExec[conversationID]
|
||||||
|
if len(queue) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
idx := -1
|
||||||
|
if toolCallID != "" {
|
||||||
|
for i, t := range queue {
|
||||||
|
if t.ToolCallID == toolCallID {
|
||||||
|
idx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx < 0 && toolName != "" {
|
||||||
|
for i, t := range queue {
|
||||||
|
if strings.EqualFold(t.ToolName, toolName) {
|
||||||
|
idx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
id := queue[idx].InterruptID
|
||||||
|
queue = append(queue[:idx], queue[idx+1:]...)
|
||||||
|
if len(queue) == 0 {
|
||||||
|
delete(m.approvedExec, conversationID)
|
||||||
|
} else {
|
||||||
|
m.approvedExec[conversationID] = queue
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeHitlPayloadExecutionResult(payloadJSON string, exec hitlExecutionResult) (string, error) {
|
||||||
|
root := make(map[string]interface{})
|
||||||
|
if strings.TrimSpace(payloadJSON) != "" {
|
||||||
|
_ = json.Unmarshal([]byte(payloadJSON), &root)
|
||||||
|
}
|
||||||
|
if root == nil {
|
||||||
|
root = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
root[hitlPayloadExecutionResult] = exec
|
||||||
|
out, err := json.Marshal(root)
|
||||||
|
if err != nil {
|
||||||
|
return payloadJSON, err
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) recordHitlToolExecutionResult(conversationID, toolCallID, toolName string, success bool, result string) {
|
||||||
|
if h == nil || h.hitlManager == nil || h.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
interruptID := h.hitlManager.popApprovedInterruptForTool(conversationID, toolCallID, toolName)
|
||||||
|
if interruptID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var payloadJSON string
|
||||||
|
err := h.db.QueryRow(`SELECT payload FROM hitl_interrupts WHERE id = ?`, interruptID).Scan(&payloadJSON)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
merged, err := mergeHitlPayloadExecutionResult(payloadJSON, hitlExecutionResult{
|
||||||
|
Success: success,
|
||||||
|
Result: strings.TrimSpace(result),
|
||||||
|
ToolName: strings.TrimSpace(toolName),
|
||||||
|
ToolCallID: strings.TrimSpace(toolCallID),
|
||||||
|
RecordedAt: time.Now(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = h.db.Exec(`UPDATE hitl_interrupts SET payload = ? WHERE id = ?`, merged, interruptID)
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeHitlPayloadExecutionResult(t *testing.T) {
|
||||||
|
merged, err := mergeHitlPayloadExecutionResult(`{"userMessage":"hi","toolName":"nmap"}`, hitlExecutionResult{
|
||||||
|
Success: true,
|
||||||
|
Result: "open ports: 80",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var root map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(merged), &root); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if root["userMessage"] != "hi" {
|
||||||
|
t.Fatalf("userMessage lost: %v", root["userMessage"])
|
||||||
|
}
|
||||||
|
exec, ok := root["executionResult"].(map[string]interface{})
|
||||||
|
if !ok || exec["success"] != true {
|
||||||
|
t.Fatalf("executionResult missing: %v", root["executionResult"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPopApprovedInterruptForTool(t *testing.T) {
|
||||||
|
m := NewHITLManager(nil, nil)
|
||||||
|
m.TrackApprovedHitlExecution("hitl_a", "conv1", "nmap", "tc1")
|
||||||
|
m.TrackApprovedHitlExecution("hitl_b", "conv1", "exec", "")
|
||||||
|
if id := m.popApprovedInterruptForTool("conv1", "tc1", "nmap"); id != "hitl_a" {
|
||||||
|
t.Fatalf("tc1 match=%q", id)
|
||||||
|
}
|
||||||
|
if id := m.popApprovedInterruptForTool("conv1", "", "exec"); id != "hitl_b" {
|
||||||
|
t.Fatalf("tool name match=%q", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,263 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeHitlReviewer(v string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||||
|
case "audit_agent", "agent", "ai":
|
||||||
|
return "audit_agent"
|
||||||
|
default:
|
||||||
|
return "human"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeHitlDecidedBy(v string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||||
|
case "audit_agent", "agent", "ai":
|
||||||
|
return "audit_agent"
|
||||||
|
case "system", "timeout":
|
||||||
|
return "system"
|
||||||
|
case "manual":
|
||||||
|
return "manual"
|
||||||
|
default:
|
||||||
|
return "human"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *HITLManager) migrateHitlSchemaColumns() {
|
||||||
|
_, _ = m.db.Exec(`ALTER TABLE hitl_interrupts ADD COLUMN decided_by TEXT NOT NULL DEFAULT 'human'`)
|
||||||
|
_, _ = m.db.Exec(`ALTER TABLE hitl_conversation_configs ADD COLUMN reviewer TEXT NOT NULL DEFAULT 'human'`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hitlInterruptRowToMap(
|
||||||
|
id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string,
|
||||||
|
messageID sql.NullString,
|
||||||
|
decision, comment sql.NullString,
|
||||||
|
createdAt time.Time,
|
||||||
|
decidedAt sql.NullTime,
|
||||||
|
) map[string]interface{} {
|
||||||
|
msgID := ""
|
||||||
|
if messageID.Valid {
|
||||||
|
msgID = messageID.String
|
||||||
|
}
|
||||||
|
return map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"conversationId": cid,
|
||||||
|
"messageId": msgID,
|
||||||
|
"mode": mode,
|
||||||
|
"toolName": toolName,
|
||||||
|
"toolCallId": toolCallID,
|
||||||
|
"payload": payload,
|
||||||
|
"status": rowStatus,
|
||||||
|
"decision": decision.String,
|
||||||
|
"comment": comment.String,
|
||||||
|
"decidedBy": decidedBy,
|
||||||
|
"createdAt": createdAt,
|
||||||
|
"decidedAt": func() interface{} {
|
||||||
|
if decidedAt.Valid {
|
||||||
|
return decidedAt.Time
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) buildHitlListQuery(logs bool) (string, []interface{}) {
|
||||||
|
where, args := h.buildHitlLogsWhere(logs)
|
||||||
|
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts` + where
|
||||||
|
return q, args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) buildHitlLogsWhere(logs bool) (string, []interface{}) {
|
||||||
|
q := " WHERE 1=1"
|
||||||
|
args := []interface{}{}
|
||||||
|
if logs {
|
||||||
|
q += " AND status != 'pending'"
|
||||||
|
} else {
|
||||||
|
q += " AND status = 'pending'"
|
||||||
|
}
|
||||||
|
return q, args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) appendHitlListFilters(q string, args []interface{}, c *gin.Context) (string, []interface{}) {
|
||||||
|
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||||
|
toolName := strings.TrimSpace(c.Query("toolName"))
|
||||||
|
decision := strings.TrimSpace(c.Query("decision"))
|
||||||
|
decidedBy := strings.TrimSpace(c.Query("decidedBy"))
|
||||||
|
status := strings.TrimSpace(c.Query("status"))
|
||||||
|
search := strings.TrimSpace(c.Query("q"))
|
||||||
|
|
||||||
|
if conversationID != "" {
|
||||||
|
q += " AND conversation_id = ?"
|
||||||
|
args = append(args, conversationID)
|
||||||
|
}
|
||||||
|
if toolName != "" {
|
||||||
|
q += " AND tool_name LIKE ?"
|
||||||
|
args = append(args, "%"+toolName+"%")
|
||||||
|
}
|
||||||
|
if decision != "" && decision != "all" {
|
||||||
|
q += " AND decision = ?"
|
||||||
|
args = append(args, decision)
|
||||||
|
}
|
||||||
|
if decidedBy != "" && decidedBy != "all" {
|
||||||
|
q += " AND COALESCE(decided_by,'human') = ?"
|
||||||
|
args = append(args, normalizeHitlDecidedBy(decidedBy))
|
||||||
|
}
|
||||||
|
if status != "" && status != "all" {
|
||||||
|
q += " AND status = ?"
|
||||||
|
args = append(args, status)
|
||||||
|
}
|
||||||
|
if search != "" {
|
||||||
|
like := "%" + search + "%"
|
||||||
|
q += " AND (id LIKE ? OR conversation_id LIKE ? OR tool_name LIKE ? OR payload LIKE ? OR COALESCE(decision_comment,'') LIKE ?)"
|
||||||
|
args = append(args, like, like, like, like, like)
|
||||||
|
}
|
||||||
|
return q, args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) scanHitlInterruptRows(rows *sql.Rows) ([]map[string]interface{}, error) {
|
||||||
|
items := make([]map[string]interface{}, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
||||||
|
var messageID sql.NullString
|
||||||
|
var decision, comment sql.NullString
|
||||||
|
var createdAt time.Time
|
||||||
|
var decidedAt sql.NullTime
|
||||||
|
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
items = append(items, hitlInterruptRowToMap(id, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) countHitlQuery(baseQ string, args []interface{}) (int, error) {
|
||||||
|
countQ := "SELECT COUNT(*) FROM (" + baseQ + ") AS hitl_cnt"
|
||||||
|
var total int
|
||||||
|
if err := h.db.QueryRow(countQ, args...).Scan(&total); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) ListHITLLogs(c *gin.Context) {
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
if page < 1 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||||
|
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
|
||||||
|
q, args := h.buildHitlListQuery(true)
|
||||||
|
q, args = h.appendHitlListFilters(q, args, c)
|
||||||
|
total, err := h.countHitlQuery(q, args)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q += " ORDER BY COALESCE(decided_at, created_at) DESC LIMIT ? OFFSET ?"
|
||||||
|
args = append(args, pageSize, offset)
|
||||||
|
rows, err := h.db.Query(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
items, err := h.scanHitlInterruptRows(rows)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize, "total": total, "retentionDays": h.hitlRetentionDays()})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) hitlRetentionDays() int {
|
||||||
|
if h.config != nil {
|
||||||
|
return h.config.Hitl.RetentionDaysEffective()
|
||||||
|
}
|
||||||
|
return config.HitlConfig{}.RetentionDaysEffective()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteHITLLogs 批量删除或按筛选清空已决策的人机协同审计日志(不删除 pending)。
|
||||||
|
func (h *AgentHandler) DeleteHITLLogs(c *gin.Context) {
|
||||||
|
var request struct {
|
||||||
|
IDs []string `json:"ids"`
|
||||||
|
All bool `json:"all"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&request); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var deleted int64
|
||||||
|
var err error
|
||||||
|
if request.All {
|
||||||
|
where, args := h.buildHitlLogsWhere(true)
|
||||||
|
where, args = h.appendHitlListFilters(where, args, c)
|
||||||
|
deleted, err = h.db.DeleteHitlInterruptLogsMatching(where, args)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "hitl", "logs_clear", "清空人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
|
||||||
|
"deleted": deleted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(request.IDs) == 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "审计日志 ID 列表不能为空"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
deleted, err = h.db.DeleteHitlInterruptLogsByIDs(request.IDs)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "hitl", "logs_delete_batch", "批量删除人机协同审计日志", "hitl_interrupt", "", map[string]interface{}{
|
||||||
|
"count": len(request.IDs),
|
||||||
|
"deleted": deleted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "删除成功", "deleted": deleted})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) GetHITLLog(c *gin.Context) {
|
||||||
|
id := strings.TrimSpace(c.Param("id"))
|
||||||
|
if id == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, COALESCE(decided_by,'human'), created_at, decided_at FROM hitl_interrupts WHERE id = ?`
|
||||||
|
var rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy string
|
||||||
|
var messageID sql.NullString
|
||||||
|
var decision, comment sql.NullString
|
||||||
|
var createdAt time.Time
|
||||||
|
var decidedAt sql.NullTime
|
||||||
|
err := h.db.QueryRow(q, id).Scan(&rowID, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &decidedBy, &createdAt, &decidedAt)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, hitlInterruptRowToMap(rowID, cid, mode, toolName, toolCallID, payload, rowStatus, decidedBy, messageID, decision, comment, createdAt, decidedAt))
|
||||||
|
}
|
||||||
+295
-14
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,6 +24,8 @@ import (
|
|||||||
type MonitorHandler struct {
|
type MonitorHandler struct {
|
||||||
mcpServer *mcp.Server
|
mcpServer *mcp.Server
|
||||||
externalMCPMgr *mcp.ExternalMCPManager
|
externalMCPMgr *mcp.ExternalMCPManager
|
||||||
|
taskManager *AgentTaskManager
|
||||||
|
agentHandler *AgentHandler
|
||||||
executor *security.Executor
|
executor *security.Executor
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
@@ -56,16 +59,44 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
|||||||
h.externalMCPMgr = mgr
|
h.externalMCPMgr = mgr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTaskManager 设置 Agent 任务管理器(用于 Eino execute 等按 executionId 终止)。
|
||||||
|
func (h *MonitorHandler) SetTaskManager(mgr *AgentTaskManager) {
|
||||||
|
h.taskManager = mgr
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAgentHandler 设置 Agent 处理器(MCP 监控终止与对话页「中断并继续」共用逻辑)。
|
||||||
|
func (h *MonitorHandler) SetAgentHandler(ah *AgentHandler) {
|
||||||
|
h.agentHandler = ah
|
||||||
|
}
|
||||||
|
|
||||||
|
const monitorPageTopTools = 6
|
||||||
|
|
||||||
|
// MonitorStatsSummary 工具调用汇总
|
||||||
|
type MonitorStatsSummary struct {
|
||||||
|
TotalCalls int `json:"totalCalls"`
|
||||||
|
SuccessCalls int `json:"successCalls"`
|
||||||
|
FailedCalls int `json:"failedCalls"`
|
||||||
|
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
|
||||||
|
ToolCount int `json:"toolCount"`
|
||||||
|
}
|
||||||
|
|
||||||
// MonitorResponse 监控响应
|
// MonitorResponse 监控响应
|
||||||
type MonitorResponse struct {
|
type MonitorResponse struct {
|
||||||
Executions []*mcp.ToolExecution `json:"executions"`
|
Executions []*mcp.ToolExecution `json:"executions"`
|
||||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
Summary *MonitorStatsSummary `json:"summary"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
TopTools []*mcp.ToolStats `json:"topTools"`
|
||||||
Total int `json:"total,omitempty"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
Page int `json:"page,omitempty"`
|
Total int `json:"total"`
|
||||||
PageSize int `json:"page_size,omitempty"`
|
Page int `json:"page"`
|
||||||
TotalPages int `json:"total_pages,omitempty"`
|
PageSize int `json:"pageSize"`
|
||||||
RetentionDays int `json:"retention_days,omitempty"`
|
TotalPages int `json:"totalPages"`
|
||||||
|
RetentionDays int `json:"retentionDays"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatsResponse 统计信息响应(Dashboard 等)
|
||||||
|
type StatsResponse struct {
|
||||||
|
Summary *MonitorStatsSummary `json:"summary"`
|
||||||
|
TopTools []*mcp.ToolStats `json:"topTools"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Monitor 获取监控信息
|
// Monitor 获取监控信息
|
||||||
@@ -89,8 +120,9 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
|||||||
// 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool)
|
// 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool)
|
||||||
toolName := normalizeToolNameFilter(c.Query("tool"))
|
toolName := normalizeToolNameFilter(c.Query("tool"))
|
||||||
|
|
||||||
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
executions, total := h.loadExecutionListWithPagination(page, pageSize, status, toolName)
|
||||||
stats := h.loadStats()
|
h.enrichExecutionsConversationID(executions)
|
||||||
|
summary, topTools := h.loadStatsSummary(monitorPageTopTools)
|
||||||
|
|
||||||
totalPages := (total + pageSize - 1) / pageSize
|
totalPages := (total + pageSize - 1) / pageSize
|
||||||
if totalPages == 0 {
|
if totalPages == 0 {
|
||||||
@@ -99,7 +131,8 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(http.StatusOK, MonitorResponse{
|
c.JSON(http.StatusOK, MonitorResponse{
|
||||||
Executions: executions,
|
Executions: executions,
|
||||||
Stats: stats,
|
Summary: summary,
|
||||||
|
TopTools: topTools,
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
Total: total,
|
Total: total,
|
||||||
Page: page,
|
Page: page,
|
||||||
@@ -121,6 +154,112 @@ func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
|||||||
return executions
|
return executions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) loadExecutionListWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||||
|
if h.db == nil {
|
||||||
|
allExecutions := h.mcpServer.GetAllExecutions()
|
||||||
|
if status != "" || toolName != "" {
|
||||||
|
filtered := make([]*mcp.ToolExecution, 0)
|
||||||
|
for _, exec := range allExecutions {
|
||||||
|
matchStatus := status == "" || exec.Status == status
|
||||||
|
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||||
|
if matchStatus && matchTool {
|
||||||
|
filtered = append(filtered, exec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allExecutions = filtered
|
||||||
|
}
|
||||||
|
total := len(allExecutions)
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
end := offset + pageSize
|
||||||
|
if end > total {
|
||||||
|
end = total
|
||||||
|
}
|
||||||
|
if offset >= total {
|
||||||
|
return []*mcp.ToolExecution{}, total
|
||||||
|
}
|
||||||
|
pageSlice := allExecutions[offset:end]
|
||||||
|
out := make([]*mcp.ToolExecution, 0, len(pageSlice))
|
||||||
|
for _, exec := range pageSlice {
|
||||||
|
if exec == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, slimToolExecution(exec))
|
||||||
|
}
|
||||||
|
return out, total
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
executions, err := h.db.LoadToolExecutionListPage(offset, pageSize, status, toolName)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("从数据库加载执行记录列表失败,回退到内存数据", zap.Error(err))
|
||||||
|
return h.loadExecutionListWithPaginationFromMemory(page, pageSize, status, toolName)
|
||||||
|
}
|
||||||
|
|
||||||
|
total, err := h.db.CountToolExecutions(status, toolName)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("获取执行记录总数失败", zap.Error(err))
|
||||||
|
total = offset + len(executions)
|
||||||
|
if len(executions) == pageSize {
|
||||||
|
total = offset + len(executions) + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return executions, total
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) loadExecutionListWithPaginationFromMemory(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||||
|
allExecutions := h.mcpServer.GetAllExecutions()
|
||||||
|
if status != "" || toolName != "" {
|
||||||
|
filtered := make([]*mcp.ToolExecution, 0)
|
||||||
|
for _, exec := range allExecutions {
|
||||||
|
matchStatus := status == "" || exec.Status == status
|
||||||
|
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||||
|
if matchStatus && matchTool {
|
||||||
|
filtered = append(filtered, exec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allExecutions = filtered
|
||||||
|
}
|
||||||
|
total := len(allExecutions)
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
end := offset + pageSize
|
||||||
|
if end > total {
|
||||||
|
end = total
|
||||||
|
}
|
||||||
|
if offset >= total {
|
||||||
|
return []*mcp.ToolExecution{}, total
|
||||||
|
}
|
||||||
|
pageSlice := allExecutions[offset:end]
|
||||||
|
out := make([]*mcp.ToolExecution, 0, len(pageSlice))
|
||||||
|
for _, exec := range pageSlice {
|
||||||
|
if exec == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, slimToolExecution(exec))
|
||||||
|
}
|
||||||
|
return out, total
|
||||||
|
}
|
||||||
|
|
||||||
|
func slimToolExecution(exec *mcp.ToolExecution) *mcp.ToolExecution {
|
||||||
|
if exec == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
slim := &mcp.ToolExecution{
|
||||||
|
ID: exec.ID,
|
||||||
|
ToolName: exec.ToolName,
|
||||||
|
Status: exec.Status,
|
||||||
|
StartTime: exec.StartTime,
|
||||||
|
}
|
||||||
|
if exec.EndTime != nil {
|
||||||
|
end := *exec.EndTime
|
||||||
|
slim.EndTime = &end
|
||||||
|
}
|
||||||
|
if exec.Duration > 0 {
|
||||||
|
slim.Duration = exec.Duration
|
||||||
|
}
|
||||||
|
return slim
|
||||||
|
}
|
||||||
|
|
||||||
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||||
if h.db == nil {
|
if h.db == nil {
|
||||||
allExecutions := h.mcpServer.GetAllExecutions()
|
allExecutions := h.mcpServer.GetAllExecutions()
|
||||||
@@ -193,7 +332,78 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
|
|||||||
return executions, total
|
return executions, total
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
|
func (h *MonitorHandler) loadStatsSummary(topN int) (*MonitorStatsSummary, []*mcp.ToolStats) {
|
||||||
|
if topN <= 0 {
|
||||||
|
topN = monitorPageTopTools
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.db != nil {
|
||||||
|
result, err := h.db.LoadToolStatsSummary(topN)
|
||||||
|
if err == nil {
|
||||||
|
return dbStatsSummaryToMonitor(result), result.TopTools
|
||||||
|
}
|
||||||
|
h.logger.Warn("从数据库加载统计汇总失败,回退到内存数据", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := h.loadStatsMap()
|
||||||
|
return summarizeToolStats(stats, topN)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbStatsSummaryToMonitor(result *database.ToolStatsSummaryResult) *MonitorStatsSummary {
|
||||||
|
if result == nil {
|
||||||
|
return &MonitorStatsSummary{}
|
||||||
|
}
|
||||||
|
summary := &MonitorStatsSummary{
|
||||||
|
TotalCalls: result.Summary.TotalCalls,
|
||||||
|
SuccessCalls: result.Summary.SuccessCalls,
|
||||||
|
FailedCalls: result.Summary.FailedCalls,
|
||||||
|
ToolCount: result.Summary.ToolCount,
|
||||||
|
}
|
||||||
|
if result.Summary.LastCallTime != nil {
|
||||||
|
t := *result.Summary.LastCallTime
|
||||||
|
summary.LastCallTime = &t
|
||||||
|
}
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeToolStats(stats map[string]*mcp.ToolStats, topN int) (*MonitorStatsSummary, []*mcp.ToolStats) {
|
||||||
|
summary := &MonitorStatsSummary{}
|
||||||
|
if len(stats) == 0 {
|
||||||
|
return summary, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
all := make([]*mcp.ToolStats, 0, len(stats))
|
||||||
|
for _, stat := range stats {
|
||||||
|
if stat == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
summary.ToolCount++
|
||||||
|
summary.TotalCalls += stat.TotalCalls
|
||||||
|
summary.SuccessCalls += stat.SuccessCalls
|
||||||
|
summary.FailedCalls += stat.FailedCalls
|
||||||
|
if stat.LastCallTime != nil && (summary.LastCallTime == nil || stat.LastCallTime.After(*summary.LastCallTime)) {
|
||||||
|
t := *stat.LastCallTime
|
||||||
|
summary.LastCallTime = &t
|
||||||
|
}
|
||||||
|
if stat.TotalCalls > 0 {
|
||||||
|
statCopy := *stat
|
||||||
|
all = append(all, &statCopy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(all, func(i, j int) bool {
|
||||||
|
if all[i].TotalCalls == all[j].TotalCalls {
|
||||||
|
return all[i].ToolName < all[j].ToolName
|
||||||
|
}
|
||||||
|
return all[i].TotalCalls > all[j].TotalCalls
|
||||||
|
})
|
||||||
|
if len(all) > topN {
|
||||||
|
all = all[:topN]
|
||||||
|
}
|
||||||
|
return summary, all
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) loadStatsMap() map[string]*mcp.ToolStats {
|
||||||
// 合并内部MCP服务器和外部MCP管理器的统计信息
|
// 合并内部MCP服务器和外部MCP管理器的统计信息
|
||||||
stats := make(map[string]*mcp.ToolStats)
|
stats := make(map[string]*mcp.ToolStats)
|
||||||
|
|
||||||
@@ -247,6 +457,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
|||||||
// 先从内部MCP服务器查找
|
// 先从内部MCP服务器查找
|
||||||
exec, exists := h.mcpServer.GetExecution(id)
|
exec, exists := h.mcpServer.GetExecution(id)
|
||||||
if exists {
|
if exists {
|
||||||
|
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||||
c.JSON(http.StatusOK, exec)
|
c.JSON(http.StatusOK, exec)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -255,6 +466,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
|||||||
if h.externalMCPMgr != nil {
|
if h.externalMCPMgr != nil {
|
||||||
exec, exists = h.externalMCPMgr.GetExecution(id)
|
exec, exists = h.externalMCPMgr.GetExecution(id)
|
||||||
if exists {
|
if exists {
|
||||||
|
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||||
c.JSON(http.StatusOK, exec)
|
c.JSON(http.StatusOK, exec)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -264,6 +476,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
|||||||
if h.db != nil {
|
if h.db != nil {
|
||||||
exec, err := h.db.GetToolExecution(id)
|
exec, err := h.db.GetToolExecution(id)
|
||||||
if err == nil && exec != nil {
|
if err == nil && exec != nil {
|
||||||
|
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||||
c.JSON(http.StatusOK, exec)
|
c.JSON(http.StatusOK, exec)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -290,6 +503,19 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
note = strings.TrimSpace(body.Note)
|
note = strings.TrimSpace(body.Note)
|
||||||
|
|
||||||
|
convID := h.conversationIDForRunningExecution(id)
|
||||||
|
if convID != "" && h.agentHandler != nil {
|
||||||
|
if ok, payload := h.agentHandler.cancelToolContinueAfter(convID, id, note); ok {
|
||||||
|
h.logger.Info("MCP 监控页终止工具(与对话中断并继续一致)",
|
||||||
|
zap.String("executionId", id),
|
||||||
|
zap.String("conversationId", convID),
|
||||||
|
zap.Bool("hasNote", note != ""),
|
||||||
|
)
|
||||||
|
c.JSON(http.StatusOK, payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
|
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
|
||||||
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
|
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
||||||
@@ -303,6 +529,52 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) enrichExecutionsConversationID(executions []*mcp.ToolExecution) {
|
||||||
|
for _, exec := range executions {
|
||||||
|
if exec == nil || exec.Status != "running" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
exec.ConversationID = h.conversationIDForRunningExecution(exec.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) conversationIDForRunningExecution(executionID string) string {
|
||||||
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
if executionID == "" || h.taskManager == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if conv := h.taskManager.ConversationIDForActiveMCPExecution(executionID); conv != "" {
|
||||||
|
return conv
|
||||||
|
}
|
||||||
|
exec := h.lookupExecution(executionID)
|
||||||
|
if exec == nil || exec.Status != "running" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(exec.ToolName) == "execute" {
|
||||||
|
if onlyConv, ok := h.taskManager.ConversationIDForActiveEinoExecute(); ok {
|
||||||
|
return onlyConv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) lookupExecution(id string) *mcp.ToolExecution {
|
||||||
|
if exec, ok := h.mcpServer.GetExecution(id); ok {
|
||||||
|
return exec
|
||||||
|
}
|
||||||
|
if h.externalMCPMgr != nil {
|
||||||
|
if exec, ok := h.externalMCPMgr.GetExecution(id); ok {
|
||||||
|
return exec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.db != nil {
|
||||||
|
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
|
||||||
|
return exec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
||||||
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
@@ -340,8 +612,17 @@ func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
|||||||
|
|
||||||
// GetStats 获取统计信息
|
// GetStats 获取统计信息
|
||||||
func (h *MonitorHandler) GetStats(c *gin.Context) {
|
func (h *MonitorHandler) GetStats(c *gin.Context) {
|
||||||
stats := h.loadStats()
|
topN := 30
|
||||||
c.JSON(http.StatusOK, stats)
|
if topStr := c.Query("top"); topStr != "" {
|
||||||
|
if t, err := strconv.Atoi(topStr); err == nil && t > 0 && t <= 100 {
|
||||||
|
topN = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
summary, topTools := h.loadStatsSummary(topN)
|
||||||
|
c.JSON(http.StatusOK, StatsResponse{
|
||||||
|
Summary: summary,
|
||||||
|
TopTools: topTools,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallsTimelinePoint 调用趋势数据点
|
// CallsTimelinePoint 调用趋势数据点
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
var mainIterationOffset int
|
var mainIterationOffset int
|
||||||
|
var emptyResponseContinueAttempt int
|
||||||
|
|
||||||
for {
|
for {
|
||||||
segmentMainIterationMax := 0
|
segmentMainIterationMax := 0
|
||||||
@@ -243,7 +244,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
h.agentsMarkdownDir,
|
h.agentsMarkdownDir,
|
||||||
orch,
|
orch,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
h.projectBlackboardBlock(conversationID),
|
h.agentSessionContextBlock(conversationID),
|
||||||
)
|
)
|
||||||
|
|
||||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
@@ -251,6 +252,13 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
|
mw := &h.config.MultiAgent.EinoMiddleware
|
||||||
|
if h.tryContinueOnEinoEmptyResponse(taskCtx, mw, conversationID, result, &emptyResponseContinueAttempt, &curHistory, &curFinalMessage, progressCallback) {
|
||||||
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
timeoutCancel()
|
||||||
|
baseCtx, cancelWithCause, taskCtx, timeoutCancel = h.rebindEinoRunningTask(conversationID, timeoutCancel)
|
||||||
|
continue
|
||||||
|
}
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -430,7 +438,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
h.agentsMarkdownDir,
|
h.agentsMarkdownDir,
|
||||||
strings.TrimSpace(req.Orchestration),
|
strings.TrimSpace(req.Orchestration),
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
h.projectBlackboardBlock(prep.ConversationID),
|
h.agentSessionContextBlock(prep.ConversationID),
|
||||||
)
|
)
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -740,14 +740,21 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"executions": map[string]interface{}{
|
"executions": map[string]interface{}{
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"description": "执行记录列表",
|
"description": "执行记录列表(轻量字段,不含 arguments/result)",
|
||||||
"items": map[string]interface{}{
|
"items": map[string]interface{}{
|
||||||
"$ref": "#/components/schemas/ToolExecution",
|
"$ref": "#/components/schemas/ToolExecution",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"stats": map[string]interface{}{
|
"summary": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "统计信息",
|
"description": "工具调用汇总",
|
||||||
|
},
|
||||||
|
"topTools": map[string]interface{}{
|
||||||
|
"type": "array",
|
||||||
|
"description": "调用量 Top N 工具",
|
||||||
|
"items": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"timestamp": map[string]interface{}{
|
"timestamp": map[string]interface{}{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -756,20 +763,24 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
"total": map[string]interface{}{
|
"total": map[string]interface{}{
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "总数",
|
"description": "执行记录总数",
|
||||||
},
|
},
|
||||||
"page": map[string]interface{}{
|
"page": map[string]interface{}{
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "当前页",
|
"description": "当前页",
|
||||||
},
|
},
|
||||||
"page_size": map[string]interface{}{
|
"pageSize": map[string]interface{}{
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "每页数量",
|
"description": "每页数量",
|
||||||
},
|
},
|
||||||
"total_pages": map[string]interface{}{
|
"totalPages": map[string]interface{}{
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "总页数",
|
"description": "总页数",
|
||||||
},
|
},
|
||||||
|
"retentionDays": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "执行记录保留天数",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ConfigResponse": map[string]interface{}{
|
"ConfigResponse": map[string]interface{}{
|
||||||
@@ -1232,6 +1243,34 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "project_id",
|
||||||
|
"in": "query",
|
||||||
|
"required": false,
|
||||||
|
"description": "按项目筛选;传 __none__ 表示仅未绑定项目的对话",
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "exclude_grouped",
|
||||||
|
"in": "query",
|
||||||
|
"required": false,
|
||||||
|
"description": "为 true 时排除已加入分组的对话(默认在未搜索且未按项目筛选时启用)",
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "boolean",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "sort_by",
|
||||||
|
"in": "query",
|
||||||
|
"required": false,
|
||||||
|
"description": "排序字段:updated_at(默认)或 created_at",
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"updated_at", "created_at"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"responses": map[string]interface{}{
|
"responses": map[string]interface{}{
|
||||||
"200": map[string]interface{}{
|
"200": map[string]interface{}{
|
||||||
|
|||||||
@@ -7,6 +7,45 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// agentSessionContextBlock 注入会话工作目录、项目黑板与用户原文锚点(用于 system prompt 追加块)。
|
||||||
|
func (h *AgentHandler) agentSessionContextBlock(conversationID string) string {
|
||||||
|
var parts []string
|
||||||
|
if ws := h.buildWorkspaceBlock(conversationID); ws != "" {
|
||||||
|
parts = append(parts, ws)
|
||||||
|
}
|
||||||
|
if bb := h.projectBlackboardBlock(conversationID); bb != "" {
|
||||||
|
parts = append(parts, bb)
|
||||||
|
}
|
||||||
|
if uv := h.userVerbatimAnchorBlock(conversationID); uv != "" {
|
||||||
|
parts = append(parts, uv)
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) buildWorkspaceBlock(conversationID string) string {
|
||||||
|
if h == nil || h.config == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
projectID := h.conversationProjectID(conversationID)
|
||||||
|
rel := project.WorkspaceRootDir(h.config.Agent.WorkspaceRootDir, projectID, conversationID)
|
||||||
|
abs, err := project.EnsureWorkspace(rel)
|
||||||
|
if err != nil {
|
||||||
|
if h.logger != nil {
|
||||||
|
h.logger.Warn("创建会话工作目录失败",
|
||||||
|
zap.String("conversationId", conversationID),
|
||||||
|
zap.String("projectId", projectID),
|
||||||
|
zap.String("path", rel),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return project.BuildWorkspaceBlock(abs)
|
||||||
|
}
|
||||||
|
|
||||||
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
||||||
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||||
if h == nil || h.db == nil || h.config == nil {
|
if h == nil || h.db == nil || h.config == nil {
|
||||||
@@ -31,6 +70,29 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
|||||||
return strings.TrimSpace(block)
|
return strings.TrimSpace(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userVerbatimAnchorBlock 从 messages 表构建用户各轮原文锚点(压缩后仍由 summarization Finalize 刷新)。
|
||||||
|
func (h *AgentHandler) userVerbatimAnchorBlock(conversationID string) string {
|
||||||
|
if h == nil || h.db == nil || h.config == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
maxRunes := h.config.MultiAgent.UserVerbatimAnchorMaxRunesEffective()
|
||||||
|
if maxRunes < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
msgs, err := h.db.GetMessages(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
if h.logger != nil {
|
||||||
|
h.logger.Warn("构建用户原文锚点失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return project.BuildUserVerbatimAnchorBlockFromMessages(msgs, maxRunes)
|
||||||
|
}
|
||||||
|
|
||||||
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
|
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
|
||||||
func (h *AgentHandler) conversationProjectID(conversationID string) string {
|
func (h *AgentHandler) conversationProjectID(conversationID string) string {
|
||||||
if h == nil || h.db == nil {
|
if h == nil || h.db == nil {
|
||||||
|
|||||||
+46
-23
@@ -447,7 +447,7 @@ func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *RobotHandler) cmdList() string {
|
func (h *RobotHandler) cmdList() string {
|
||||||
convs, err := h.db.ListConversations(50, 0, "", "")
|
convs, err := h.db.ListConversations(50, 0, "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "获取对话列表失败: " + err.Error()
|
return "获取对话列表失败: " + err.Error()
|
||||||
}
|
}
|
||||||
@@ -711,12 +711,27 @@ type wecomReplyXML struct {
|
|||||||
Content string `xml:"Content"`
|
Content string `xml:"Content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wecomRequireToken 企业微信回调必须配置 Token;未配置时拒绝请求,防止未授权触发 Agent。
|
||||||
|
func (h *RobotHandler) wecomRequireToken(c *gin.Context) (string, bool) {
|
||||||
|
token := strings.TrimSpace(h.config.Robots.Wecom.Token)
|
||||||
|
if token == "" {
|
||||||
|
h.logger.Warn("企业微信已启用但未配置 token,已拒绝回调(请在配置中设置 robots.wecom.token)")
|
||||||
|
c.String(http.StatusForbidden, "")
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return token, true
|
||||||
|
}
|
||||||
|
|
||||||
// HandleWecomGET 企业微信 URL 校验(GET)
|
// HandleWecomGET 企业微信 URL 校验(GET)
|
||||||
func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
|
func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
|
||||||
if !h.config.Robots.Wecom.Enabled {
|
if !h.config.Robots.Wecom.Enabled {
|
||||||
c.String(http.StatusNotFound, "")
|
c.String(http.StatusNotFound, "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
token, ok := h.wecomRequireToken(c)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
|
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
|
||||||
echostr := c.Query("echostr")
|
echostr := c.Query("echostr")
|
||||||
msgSignature := c.Query("msg_signature")
|
msgSignature := c.Query("msg_signature")
|
||||||
@@ -724,7 +739,7 @@ func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
|
|||||||
nonce := c.Query("nonce")
|
nonce := c.Query("nonce")
|
||||||
|
|
||||||
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
|
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
|
||||||
signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr)
|
signature := h.signWecomRequest(token, timestamp, nonce, echostr)
|
||||||
if signature != msgSignature {
|
if signature != msgSignature {
|
||||||
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
|
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
|
||||||
c.String(http.StatusBadRequest, "invalid signature")
|
c.String(http.StatusBadRequest, "invalid signature")
|
||||||
@@ -865,27 +880,28 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
|
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
|
||||||
|
|
||||||
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段
|
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段。
|
||||||
// 若配置了 Token 则必须校验签名,避免未授权请求触发 Agent(防止平台被接管)
|
// 启用企业微信时必须配置 token 并校验签名,避免未授权请求触发 Agent。
|
||||||
token := h.config.Robots.Wecom.Token
|
token, ok := h.wecomRequireToken(c)
|
||||||
if token != "" {
|
if !ok {
|
||||||
if msgSignature == "" {
|
return
|
||||||
h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需配置 token 并确保回调携带 msg_signature)")
|
}
|
||||||
c.String(http.StatusOK, "")
|
if msgSignature == "" {
|
||||||
return
|
h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需确保回调携带 msg_signature)")
|
||||||
}
|
c.String(http.StatusOK, "")
|
||||||
var tmp wecomXML
|
return
|
||||||
if err := xml.Unmarshal(bodyRaw, &tmp); err != nil {
|
}
|
||||||
h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err))
|
var tmp wecomXML
|
||||||
c.String(http.StatusOK, "")
|
if err := xml.Unmarshal(bodyRaw, &tmp); err != nil {
|
||||||
return
|
h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err))
|
||||||
}
|
c.String(http.StatusOK, "")
|
||||||
expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt)
|
return
|
||||||
if expected != msgSignature {
|
}
|
||||||
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
|
expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt)
|
||||||
c.String(http.StatusOK, "")
|
if expected != msgSignature {
|
||||||
return
|
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
|
||||||
}
|
c.String(http.StatusOK, "")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var body wecomXML
|
var body wecomXML
|
||||||
@@ -899,6 +915,13 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
|||||||
// 保存企业 ID(用于明文模式回复)
|
// 保存企业 ID(用于明文模式回复)
|
||||||
enterpriseID := body.ToUserName
|
enterpriseID := body.ToUserName
|
||||||
|
|
||||||
|
// 配置了 EncodingAESKey 时必须走加密消息,拒绝明文 XML 绕过
|
||||||
|
if strings.TrimSpace(h.config.Robots.Wecom.EncodingAESKey) != "" && strings.TrimSpace(body.Encrypt) == "" {
|
||||||
|
h.logger.Warn("企业微信已配置加密模式但收到明文消息,已拒绝")
|
||||||
|
c.String(http.StatusOK, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 加密模式:先解密再解析内层 XML
|
// 加密模式:先解密再解析内层 XML
|
||||||
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
|
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
|
||||||
h.logger.Debug("企业微信进入加密模式解密流程")
|
h.logger.Debug("企业微信进入加密模式解密流程")
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newWecomTestHandler(token string, aesKey string) *RobotHandler {
|
||||||
|
return &RobotHandler{
|
||||||
|
config: &config.Config{
|
||||||
|
Robots: config.RobotsConfig{
|
||||||
|
Wecom: config.RobotWecomConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Token: token,
|
||||||
|
EncodingAESKey: aesKey,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWecomPOST_rejectsWhenTokenEmpty(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := newWecomTestHandler("", "")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
body := `<?xml version="1.0"?><xml><FromUserName>attacker</FromUserName><MsgType>text</MsgType><Content>hi</Content></xml>`
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/api/robot/wecom", strings.NewReader(body))
|
||||||
|
|
||||||
|
h.HandleWecomPOST(c)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
if w.Body.String() == "success" {
|
||||||
|
t.Fatal("expected rejection, got success")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWecomPOST_rejectsPlaintextWhenEncryptionConfigured(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := newWecomTestHandler("secret-token", "abcdefghijklmnopqrstuvwxyz0123456789ABCD")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
body := `<?xml version="1.0"?><xml><FromUserName>attacker</FromUserName><MsgType>text</MsgType><Content>hi</Content></xml>`
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/api/robot/wecom?timestamp=1&nonce=2&msg_signature=fake", strings.NewReader(body))
|
||||||
|
|
||||||
|
h.HandleWecomPOST(c)
|
||||||
|
|
||||||
|
if w.Body.String() == "success" {
|
||||||
|
t.Fatal("expected rejection for plaintext in encryption mode, got success")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWecomGET_rejectsWhenTokenEmpty(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
h := newWecomTestHandler("", "")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/robot/wecom?msg_signature=x×tamp=1&nonce=2&echostr=abc", nil)
|
||||||
|
|
||||||
|
h.HandleWecomGET(c)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,6 +26,7 @@ func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool {
|
|||||||
// AgentTask 描述正在运行的Agent任务
|
// AgentTask 描述正在运行的Agent任务
|
||||||
type AgentTask struct {
|
type AgentTask struct {
|
||||||
ConversationID string `json:"conversationId"`
|
ConversationID string `json:"conversationId"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
Message string `json:"message,omitempty"`
|
Message string `json:"message,omitempty"`
|
||||||
StartedAt time.Time `json:"startedAt"`
|
StartedAt time.Time `json:"startedAt"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
@@ -42,6 +43,9 @@ type AgentTask struct {
|
|||||||
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
||||||
activeEinoExecuteAbortNote string
|
activeEinoExecuteAbortNote string
|
||||||
|
|
||||||
|
// hitlCognition 本轮运行中供 HITL/审计 Agent 读取的上下文(用户原话 + 思考,不含会话历史)
|
||||||
|
hitlCognition *hitlCognitionState
|
||||||
|
|
||||||
cancel func(error)
|
cancel func(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +107,40 @@ func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConversationIDForActiveMCPExecution 根据当前登记的工具 executionId 反查会话 ID(供 MCP 监控页按 executionId 终止)。
|
||||||
|
func (m *AgentTaskManager) ConversationIDForActiveMCPExecution(executionID string) string {
|
||||||
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
if executionID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
for convID, t := range m.tasks {
|
||||||
|
if t != nil && t.ActiveMCPExecutionID == executionID {
|
||||||
|
return convID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationIDForActiveEinoExecute 返回当前唯一进行 Eino execute 的会话 ID;多会话并行时返回空。
|
||||||
|
func (m *AgentTaskManager) ConversationIDForActiveEinoExecute() (string, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
var found string
|
||||||
|
count := 0
|
||||||
|
for convID, t := range m.tasks {
|
||||||
|
if t != nil && t.activeEinoExecuteCancel != nil {
|
||||||
|
found = convID
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count == 1 {
|
||||||
|
return found, true
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
||||||
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
||||||
conversationID = strings.TrimSpace(conversationID)
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
@@ -199,6 +237,7 @@ func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
|||||||
// CompletedTask 已完成的任务(用于历史记录)
|
// CompletedTask 已完成的任务(用于历史记录)
|
||||||
type CompletedTask struct {
|
type CompletedTask struct {
|
||||||
ConversationID string `json:"conversationId"`
|
ConversationID string `json:"conversationId"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
Message string `json:"message,omitempty"`
|
Message string `json:"message,omitempty"`
|
||||||
StartedAt time.Time `json:"startedAt"`
|
StartedAt time.Time `json:"startedAt"`
|
||||||
CompletedAt time.Time `json:"completedAt"`
|
CompletedAt time.Time `json:"completedAt"`
|
||||||
@@ -213,6 +252,8 @@ type AgentTaskManager struct {
|
|||||||
maxHistorySize int // 最大历史记录数
|
maxHistorySize int // 最大历史记录数
|
||||||
historyRetention time.Duration // 历史记录保留时间
|
historyRetention time.Duration // 历史记录保留时间
|
||||||
eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅
|
eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅
|
||||||
|
// toolCanceler 在用户整轮停止任务时终止当前 MCP 工具(非「中断并继续」)。
|
||||||
|
toolCanceler func(conversationID string)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -243,6 +284,13 @@ func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) {
|
|||||||
m.eventBus = b
|
m.eventBus = b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetToolCanceler 设置整轮停止任务时终止当前 MCP 工具的回调(由 AgentHandler 注入)。
|
||||||
|
func (m *AgentTaskManager) SetToolCanceler(fn func(conversationID string)) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.toolCanceler = fn
|
||||||
|
}
|
||||||
|
|
||||||
// GetTask 返回运行中任务(无则 nil)。
|
// GetTask 返回运行中任务(无则 nil)。
|
||||||
func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask {
|
func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask {
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
@@ -309,6 +357,7 @@ func (m *AgentTaskManager) StartTask(conversationID, message string, cancel cont
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.tasks[conversationID] = task
|
m.tasks[conversationID] = task
|
||||||
|
task.hitlCognition = &hitlCognitionState{UserMessage: strings.TrimSpace(message)}
|
||||||
return task, nil
|
return task, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -338,14 +387,21 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
|
|||||||
task.InterruptContinueNote = ""
|
task.InterruptContinueNote = ""
|
||||||
}
|
}
|
||||||
cancel := task.cancel
|
cancel := task.cancel
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
if cause == nil {
|
if cause == nil {
|
||||||
cause = ErrTaskCancelled
|
cause = ErrTaskCancelled
|
||||||
}
|
}
|
||||||
|
var toolCanceler func(string)
|
||||||
|
if errors.Is(cause, ErrTaskCancelled) {
|
||||||
|
toolCanceler = m.toolCanceler
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
if cancel != nil {
|
if cancel != nil {
|
||||||
cancel(cause)
|
cancel(cause)
|
||||||
}
|
}
|
||||||
|
if toolCanceler != nil {
|
||||||
|
toolCanceler(conversationID)
|
||||||
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,3 +38,19 @@ func TestAbortActiveEinoExecute(t *testing.T) {
|
|||||||
t.Fatal("second abort should fail when no active execute")
|
t.Fatal("second abort should fail when no active execute")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConversationIDForActiveMCPExecution(t *testing.T) {
|
||||||
|
m := NewAgentTaskManager()
|
||||||
|
conv := "conv-mcp-exec"
|
||||||
|
_, err := m.StartTask(conv, "test", func(error) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
m.RegisterRunningTool(conv, "exec-123")
|
||||||
|
if got := m.ConversationIDForActiveMCPExecution("exec-123"); got != conv {
|
||||||
|
t.Fatalf("got %q, want %q", got, conv)
|
||||||
|
}
|
||||||
|
if got := m.ConversationIDForActiveMCPExecution("missing"); got != "" {
|
||||||
|
t.Fatalf("missing should be empty, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCancelTaskInvokesToolCancelerOnFullStop(t *testing.T) {
|
||||||
|
tm := NewAgentTaskManager()
|
||||||
|
called := false
|
||||||
|
tm.SetToolCanceler(func(conversationID string) {
|
||||||
|
if conversationID == "conv-1" {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
_, cancel := context.WithCancelCause(context.Background())
|
||||||
|
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := tm.CancelTask("conv-1", ErrTaskCancelled)
|
||||||
|
if err != nil || !ok {
|
||||||
|
t.Fatalf("CancelTask: ok=%v err=%v", ok, err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("expected tool canceler to be invoked on full task cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelTaskSkipsToolCancelerOnInterruptContinue(t *testing.T) {
|
||||||
|
tm := NewAgentTaskManager()
|
||||||
|
called := false
|
||||||
|
tm.SetToolCanceler(func(conversationID string) {
|
||||||
|
called = true
|
||||||
|
})
|
||||||
|
|
||||||
|
_, cancel := context.WithCancelCause(context.Background())
|
||||||
|
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := tm.CancelTask("conv-1", multiagent.ErrInterruptContinue)
|
||||||
|
if err != nil || !ok {
|
||||||
|
t.Fatalf("CancelTask: ok=%v err=%v", ok, err)
|
||||||
|
}
|
||||||
|
if called {
|
||||||
|
t.Fatal("tool canceler must not run for interrupt-continue")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCancelTaskDefaultCauseIsTaskCancelled(t *testing.T) {
|
||||||
|
tm := NewAgentTaskManager()
|
||||||
|
var gotCause error
|
||||||
|
tm.SetToolCanceler(func(conversationID string) {
|
||||||
|
if conversationID == "conv-2" {
|
||||||
|
gotCause = ErrTaskCancelled
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
if _, err := tm.StartTask("conv-2", "hello", cancel); err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tm.CancelTask("conv-2", nil); err != nil {
|
||||||
|
t.Fatalf("CancelTask: %v", err)
|
||||||
|
}
|
||||||
|
if !errors.Is(context.Cause(ctx), ErrTaskCancelled) {
|
||||||
|
t.Fatalf("expected ErrTaskCancelled cause, got %v", context.Cause(ctx))
|
||||||
|
}
|
||||||
|
if gotCause != ErrTaskCancelled {
|
||||||
|
t.Fatalf("expected tool canceler path for default cancel cause")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RunCommandWS 交互式 PTY 终端依赖 Unix PTY(见 terminal_ws_unix.go);Windows 暂不支持。
|
||||||
|
func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusNotImplemented, gin.H{
|
||||||
|
"error": "Interactive WebSocket terminal is not supported on Windows; use POST /terminal/run or /terminal/run/stream instead.",
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package hitl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const retentionPurgeInterval = time.Hour
|
||||||
|
|
||||||
|
// Service manages HITL audit log retention (decided hitl_interrupts rows).
|
||||||
|
type Service struct {
|
||||||
|
db *database.DB
|
||||||
|
cfg *config.Config
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewService creates a HITL audit log retention service.
|
||||||
|
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||||
|
return &Service{db: db, cfg: cfg, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetentionDays returns configured retention; 0 means keep forever.
|
||||||
|
func (s *Service) RetentionDays() int {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return config.HitlConfig{}.RetentionDaysEffective()
|
||||||
|
}
|
||||||
|
return s.cfg.Hitl.RetentionDaysEffective()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeExpired deletes decided HITL log rows older than retention_days when configured.
|
||||||
|
func (s *Service) PurgeExpired() {
|
||||||
|
if s == nil || s.db == nil || s.cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
days := s.cfg.Hitl.RetentionDaysEffective()
|
||||||
|
if days <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -days)
|
||||||
|
n, err := s.db.PurgeHitlInterruptLogsBefore(cutoff)
|
||||||
|
if err != nil {
|
||||||
|
if s.logger != nil {
|
||||||
|
s.logger.Warn("清理过期人机协同审计日志失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n > 0 && s.logger != nil {
|
||||||
|
s.logger.Info("已清理过期人机协同审计日志", zap.Int64("deleted", n), zap.Int("retention_days", days))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartRetentionLoop periodically purges expired HITL audit log rows.
|
||||||
|
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(retentionPurgeInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
s.PurgeExpired()
|
||||||
|
if logger != nil {
|
||||||
|
logger.Debug("hitl audit log retention tick completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
package hitl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
appconfig "cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "hitl.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS hitl_interrupts (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
conversation_id TEXT NOT NULL,
|
||||||
|
mode TEXT NOT NULL,
|
||||||
|
tool_name TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
decision TEXT,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
decided_at DATETIME
|
||||||
|
)`); err != nil {
|
||||||
|
t.Fatalf("create table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
old := time.Now().AddDate(0, 0, -100).UTC().Format(time.RFC3339)
|
||||||
|
if _, err := db.Exec(`INSERT INTO hitl_interrupts
|
||||||
|
(id, conversation_id, mode, tool_name, status, decision, created_at, decided_at)
|
||||||
|
VALUES ('old-1', 'c1', 'approval', 'exec', 'decided', 'approve', ?, ?)`, old, old); err != nil {
|
||||||
|
t.Fatalf("insert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
zero := 0
|
||||||
|
svc := NewService(db, &appconfig.Config{
|
||||||
|
Hitl: appconfig.HitlConfig{RetentionDays: &zero},
|
||||||
|
}, zap.NewNop())
|
||||||
|
svc.PurgeExpired()
|
||||||
|
|
||||||
|
if err := db.QueryRow(`SELECT id FROM hitl_interrupts WHERE id = 'old-1'`).Scan(new(string)); err != nil {
|
||||||
|
t.Fatalf("record should remain when retention_days=0: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -814,6 +814,23 @@ func (m *ExternalMCPManager) CancelToolExecution(id string) bool {
|
|||||||
return m.CancelToolExecutionWithNote(id, "")
|
return m.CancelToolExecutionWithNote(id, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ActiveRunningExecutionIDs 返回当前进程内仍登记 cancel 的外部 MCP executionId 快照。
|
||||||
|
func (m *ExternalMCPManager) ActiveRunningExecutionIDs() map[string]struct{} {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if len(m.runningCancels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]struct{}, len(m.runningCancels))
|
||||||
|
for id := range m.runningCancels {
|
||||||
|
out[id] = struct{}{}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// updateStats 更新统计信息
|
// updateStats 更新统计信息
|
||||||
func (m *ExternalMCPManager) updateStats(toolName string, failed bool) {
|
func (m *ExternalMCPManager) updateStats(toolName string, failed bool) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|||||||
+100
-16
@@ -921,9 +921,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
|||||||
return finalResult, executionID, nil
|
return finalResult, executionID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
// BeginToolExecution 创建 running 状态的执行记录,供 Eino 等非 CallTool 路径在工具开始时落库。
|
||||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
func (s *Server) BeginToolExecution(toolName string, args map[string]interface{}) string {
|
||||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -931,21 +930,73 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
|||||||
args = map[string]interface{}{}
|
args = map[string]interface{}{}
|
||||||
}
|
}
|
||||||
executionID := uuid.New().String()
|
executionID := uuid.New().String()
|
||||||
now := time.Now()
|
execution := &ToolExecution{
|
||||||
failed := invokeErr != nil
|
|
||||||
exec := &ToolExecution{
|
|
||||||
ID: executionID,
|
ID: executionID,
|
||||||
ToolName: toolName,
|
ToolName: toolName,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
StartTime: now,
|
Status: "running",
|
||||||
EndTime: &now,
|
StartTime: time.Now(),
|
||||||
Duration: 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.executions[executionID] = execution
|
||||||
|
s.cleanupOldExecutions()
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.storage != nil {
|
||||||
|
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||||||
|
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return executionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinishToolExecution 完成先前 BeginToolExecution 创建的记录;executionID 为空时等同 RecordCompletedToolInvocation。
|
||||||
|
func (s *Server) FinishToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if args == nil {
|
||||||
|
args = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
id := strings.TrimSpace(executionID)
|
||||||
|
if id == "" {
|
||||||
|
return s.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
failed := invokeErr != nil
|
||||||
|
var finalResult *ToolResult
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
exec, inMem := s.executions[id]
|
||||||
|
if !inMem || exec == nil {
|
||||||
|
exec = &ToolExecution{
|
||||||
|
ID: id,
|
||||||
|
ToolName: toolName,
|
||||||
|
Arguments: args,
|
||||||
|
StartTime: now,
|
||||||
|
}
|
||||||
|
s.executions[id] = exec
|
||||||
|
} else if toolName != "" {
|
||||||
|
exec.ToolName = toolName
|
||||||
|
}
|
||||||
|
if len(args) > 0 {
|
||||||
|
exec.Arguments = args
|
||||||
|
}
|
||||||
|
exec.EndTime = &now
|
||||||
|
if exec.StartTime.IsZero() {
|
||||||
|
exec.StartTime = now
|
||||||
|
}
|
||||||
|
exec.Duration = now.Sub(exec.StartTime)
|
||||||
|
|
||||||
if failed {
|
if failed {
|
||||||
exec.Status = "failed"
|
st, msg := executionStatusAndMessage(invokeErr)
|
||||||
exec.Error = invokeErr.Error()
|
exec.Status = st
|
||||||
|
exec.Error = msg
|
||||||
if strings.TrimSpace(resultText) != "" {
|
if strings.TrimSpace(resultText) != "" {
|
||||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||||
|
exec.Result = finalResult
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
exec.Status = "completed"
|
exec.Status = "completed"
|
||||||
@@ -953,15 +1004,31 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
|||||||
if strings.TrimSpace(text) == "" {
|
if strings.TrimSpace(text) == "" {
|
||||||
text = "(无输出)"
|
text = "(无输出)"
|
||||||
}
|
}
|
||||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||||
|
exec.Result = finalResult
|
||||||
}
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
if s.storage != nil {
|
if s.storage != nil {
|
||||||
if err := s.storage.SaveToolExecution(exec); err != nil {
|
if err := s.storage.SaveToolExecution(exec); err != nil {
|
||||||
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err))
|
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.updateStats(toolName, failed)
|
|
||||||
return executionID
|
s.updateStats(exec.ToolName, failed)
|
||||||
|
|
||||||
|
if s.storage != nil {
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.executions, id)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||||
|
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||||
|
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
return s.FinishToolExecution("", toolName, args, resultText, invokeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
||||||
@@ -1103,6 +1170,23 @@ func (s *Server) CancelToolExecution(id string) bool {
|
|||||||
return s.CancelToolExecutionWithNote(id, "")
|
return s.CancelToolExecutionWithNote(id, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ActiveRunningExecutionIDs 返回当前进程内仍登记 cancel 的 executionId 快照。
|
||||||
|
func (s *Server) ActiveRunningExecutionIDs() map[string]struct{} {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.runningCancelsMu.Lock()
|
||||||
|
defer s.runningCancelsMu.Unlock()
|
||||||
|
if len(s.runningCancels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]struct{}, len(s.runningCancels))
|
||||||
|
for id := range s.runningCancels {
|
||||||
|
out[id] = struct{}{}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// initDefaultPrompts 初始化默认提示词模板
|
// initDefaultPrompts 初始化默认提示词模板
|
||||||
func (s *Server) initDefaultPrompts() {
|
func (s *Server) initDefaultPrompts() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
|||||||
@@ -199,6 +199,8 @@ type ToolExecution struct {
|
|||||||
StartTime time.Time `json:"startTime"`
|
StartTime time.Time `json:"startTime"`
|
||||||
EndTime *time.Time `json:"endTime,omitempty"`
|
EndTime *time.Time `json:"endTime,omitempty"`
|
||||||
Duration time.Duration `json:"duration,omitempty"`
|
Duration time.Duration `json:"duration,omitempty"`
|
||||||
|
// ConversationID 仅 API 展示用(进行中的 Agent 任务),不写入 tool_executions 表。
|
||||||
|
ConversationID string `json:"conversationId,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolStats 工具统计信息
|
// ToolStats 工具统计信息
|
||||||
|
|||||||
@@ -0,0 +1,101 @@
|
|||||||
|
package monitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
staleRunningMinAge = 45 * time.Second
|
||||||
|
staleRunningReconcileGap = 2 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExecutionReconciler 在启动或运行期将无对应协程的 running 执行记录收尾为 cancelled。
|
||||||
|
type ExecutionReconciler struct {
|
||||||
|
db *database.DB
|
||||||
|
mcpServer *mcp.Server
|
||||||
|
externalMgr *mcp.ExternalMCPManager
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewExecutionReconciler creates a reconciler for orphaned MCP tool executions.
|
||||||
|
func NewExecutionReconciler(db *database.DB, mcpServer *mcp.Server, externalMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ExecutionReconciler {
|
||||||
|
return &ExecutionReconciler{
|
||||||
|
db: db,
|
||||||
|
mcpServer: mcpServer,
|
||||||
|
externalMgr: externalMgr,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReconcileOnStartup marks every persisted running row as cancelled (safe right after process start).
|
||||||
|
func (r *ExecutionReconciler) ReconcileOnStartup() {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
n, err := r.db.CancelOrphanedRunningToolExecutions(now, "执行已中断(服务重启)")
|
||||||
|
if err != nil {
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Warn("启动时清理孤儿 running 工具执行记录失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n > 0 && r.logger != nil {
|
||||||
|
r.logger.Info("启动时已收尾孤儿 running 工具执行记录", zap.Int64("count", n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ExecutionReconciler) activeExecutionIDs() map[string]struct{} {
|
||||||
|
ids := make(map[string]struct{})
|
||||||
|
if r.mcpServer != nil {
|
||||||
|
for id := range r.mcpServer.ActiveRunningExecutionIDs() {
|
||||||
|
ids[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.externalMgr != nil {
|
||||||
|
for id := range r.externalMgr.ActiveRunningExecutionIDs() {
|
||||||
|
ids[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReconcileStaleRunning finalizes running rows that are not tracked in-memory and older than staleRunningMinAge.
|
||||||
|
func (r *ExecutionReconciler) ReconcileStaleRunning() {
|
||||||
|
if r == nil || r.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
n, err := r.db.FinalizeStaleRunningToolExecutions(now, staleRunningMinAge, r.activeExecutionIDs(), "执行已中断(会话已结束)")
|
||||||
|
if err != nil {
|
||||||
|
if r.logger != nil {
|
||||||
|
r.logger.Warn("定期收尾 stale running 工具执行记录失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n > 0 && r.logger != nil {
|
||||||
|
r.logger.Info("已收尾 stale running 工具执行记录", zap.Int64("count", n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartStaleRunningReconcileLoop periodically reconciles orphaned running tool executions.
|
||||||
|
func StartStaleRunningReconcileLoop(r *ExecutionReconciler, logger *zap.Logger) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(staleRunningReconcileGap)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
r.ReconcileStaleRunning()
|
||||||
|
if logger != nil {
|
||||||
|
logger.Debug("monitor stale running reconcile tick completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
package monitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExecutionReconciler_ReconcileOnStartup(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if err := db.SaveToolExecution(&mcp.ToolExecution{
|
||||||
|
ID: "run-1", ToolName: "hydra", Status: "running", StartTime: time.Now().Add(-time.Hour),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("SaveToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := NewExecutionReconciler(db, mcp.NewServer(zap.NewNop()), nil, zap.NewNop())
|
||||||
|
r.ReconcileOnStartup()
|
||||||
|
|
||||||
|
got, err := db.GetToolExecution("run-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetToolExecution: %v", err)
|
||||||
|
}
|
||||||
|
if got.Status != "cancelled" {
|
||||||
|
t.Fatalf("expected cancelled after startup reconcile, got %s", got.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InitADK configures global Eino ADK settings. Call once at process startup before
|
||||||
|
// any ADK middleware or agents are created.
|
||||||
|
func InitADK() error {
|
||||||
|
if err := adk.SetLanguage(adk.LanguageChinese); err != nil {
|
||||||
|
return fmt.Errorf("adk set language: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/einoobserve"
|
"cyberstrike-ai/internal/einoobserve"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
@@ -90,7 +91,7 @@ type einoADKRunLoopArgs struct {
|
|||||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||||
MCPExecutionBinder *MCPExecutionBinder
|
MCPExecutionBinder *MCPExecutionBinder
|
||||||
|
|
||||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,execute/MCP 桥 Fire 时立即推送 tool_result(ADK 晚到经 toolResultSent 去重)。
|
||||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||||
|
|
||||||
DA adk.Agent
|
DA adk.Agent
|
||||||
@@ -196,6 +197,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
pendingByID[tc.ToolCallID] = tc
|
pendingByID[tc.ToolCallID] = tc
|
||||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||||
}
|
}
|
||||||
|
markPendingWithMonitor := func(tc toolCallPendingInfo) {
|
||||||
|
markPending(tc)
|
||||||
|
beginEinoADKFilesystemToolMonitor(
|
||||||
|
args.FilesystemMonitorAgent,
|
||||||
|
args.FilesystemMonitorRecord,
|
||||||
|
args.MCPExecutionBinder,
|
||||||
|
tc.ToolCallID,
|
||||||
|
tc.ToolName,
|
||||||
|
)
|
||||||
|
}
|
||||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||||
pendingMu.Lock()
|
pendingMu.Lock()
|
||||||
defer pendingMu.Unlock()
|
defer pendingMu.Unlock()
|
||||||
@@ -288,6 +299,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
|
|
||||||
var toolResultSent sync.Map // toolCallID -> struct{};ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
|
var toolResultSent sync.Map // toolCallID -> struct{};ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
|
||||||
tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
|
tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
|
||||||
|
// 仅由 ADK schema.Tool 事件调用;MCP/execute 桥在 reduction 前的 ToolInvokeNotify 不得推送 tool_result,
|
||||||
|
// 否则全量输出会先占位并触发 toolResultSent 去重,导致 UI/监控展示与 agent 实际收到的截断正文不一致。
|
||||||
if progress == nil {
|
if progress == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -305,6 +318,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"isError": isErr,
|
"isError": isErr,
|
||||||
"result": content,
|
"result": content,
|
||||||
"resultPreview": preview,
|
"resultPreview": preview,
|
||||||
|
"agentFacing": true, // 与 reduction 后送入 ChatModel 的正文一致,供前端展示
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"einoAgent": agentName,
|
"einoAgent": agentName,
|
||||||
"einoRole": einoRoleTag(agentName),
|
"einoRole": einoRoleTag(agentName),
|
||||||
@@ -331,7 +345,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
toolCallID = tid
|
toolCallID = tid
|
||||||
}
|
}
|
||||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, args.MCPExecutionBinder, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||||
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
|
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
|
||||||
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
|
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
|
||||||
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
|
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
|
||||||
@@ -339,12 +353,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||||
}
|
}
|
||||||
if args.ToolInvokeNotify != nil {
|
|
||||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
|
||||||
removePendingByID(strings.TrimSpace(toolCallID))
|
|
||||||
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.EinoCallbacks != nil {
|
if args.EinoCallbacks != nil {
|
||||||
ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{
|
ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{
|
||||||
@@ -539,6 +547,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 仅在退避重试后真正收到数据/完成一步时清零,避免重启后首个无错 ADK 事件误把计数打回 0。
|
||||||
|
confirmTransientRetryRecovery := func() {
|
||||||
|
if transientRetrier.attempt() > 0 {
|
||||||
|
transientRetrier.reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
takePartial := func(runErr error) (*RunResult, error) {
|
takePartial := func(runErr error) (*RunResult, error) {
|
||||||
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
||||||
return nil, runErr
|
return nil, runErr
|
||||||
@@ -551,10 +566,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
|
// iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending。
|
||||||
select {
|
ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
|
||||||
case <-ctx.Done():
|
if iterCtxErr != nil {
|
||||||
flushAllPendingAsFailed(ctx.Err())
|
flushAllPendingAsFailed(iterCtxErr)
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
if isInterruptContinue(ctx) {
|
if isInterruptContinue(ctx) {
|
||||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||||
@@ -563,17 +578,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"kind": "interrupt_continue",
|
"kind": "interrupt_continue",
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
progress("error", iterCtxErr.Error(), map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return takePartial(ctx.Err())
|
return takePartial(iterCtxErr)
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ev, ok := iter.Next()
|
|
||||||
if !ok {
|
if !ok {
|
||||||
// iter 结束并不总是“正常完成”:
|
// iter 结束并不总是“正常完成”:
|
||||||
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
||||||
@@ -627,8 +639,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if restarted {
|
if restarted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
transientRetrier.reset()
|
|
||||||
}
|
}
|
||||||
if ev.AgentName != "" && progress != nil {
|
if ev.AgentName != "" && progress != nil {
|
||||||
iterEinoAgent := orchestratorName
|
iterEinoAgent := orchestratorName
|
||||||
@@ -691,34 +701,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
|
|
||||||
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
||||||
toolName := strings.TrimSpace(mv.ToolName)
|
toolName := strings.TrimSpace(mv.ToolName)
|
||||||
var toolBuf strings.Builder
|
content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
|
||||||
streamToolCallID := ""
|
isErr := einoToolResultIsError(toolName, content)
|
||||||
var toolStreamRecvErr error
|
content = einoToolResultBody(content)
|
||||||
for {
|
|
||||||
chunk, rerr := mv.MessageStream.Recv()
|
|
||||||
if errors.Is(rerr, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if rerr != nil {
|
|
||||||
toolStreamRecvErr = rerr
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if chunk == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if chunk.Content != "" {
|
|
||||||
toolBuf.WriteString(chunk.Content)
|
|
||||||
}
|
|
||||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
|
||||||
streamToolCallID = tid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
content := toolBuf.String()
|
|
||||||
isErr := false
|
|
||||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
|
||||||
isErr = true
|
|
||||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
|
||||||
}
|
|
||||||
if streamToolCallID != "" {
|
if streamToolCallID != "" {
|
||||||
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
|
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
|
||||||
@@ -730,6 +715,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
zap.String("agent", ev.AgentName),
|
zap.String("agent", ev.AgentName),
|
||||||
zap.String("tool", toolName))
|
zap.String("tool", toolName))
|
||||||
}
|
}
|
||||||
|
if toolStreamRecvErr == nil {
|
||||||
|
confirmTransientRetryRecovery()
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -977,7 +965,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||||
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
||||||
}
|
}
|
||||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
|
||||||
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
||||||
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
||||||
@@ -1001,6 +989,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if restarted {
|
if restarted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
confirmTransientRetryRecovery()
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1010,7 +1000,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
|
||||||
|
|
||||||
if mv.Role == schema.Assistant {
|
if mv.Role == schema.Assistant {
|
||||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||||
@@ -1085,15 +1075,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
|
|
||||||
content := msg.Content
|
content := msg.Content
|
||||||
isErr := false
|
isErr := einoToolResultIsError(toolName, content)
|
||||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
content = einoToolResultBody(content)
|
||||||
isErr = true
|
|
||||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||||
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
|
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
|
||||||
}
|
}
|
||||||
|
confirmTransientRetryRecovery()
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpIDsMu.Lock()
|
mcpIDsMu.Lock()
|
||||||
@@ -1121,17 +1109,119 @@ func einoPartialRunLastOutputHint() string {
|
|||||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||||
}
|
}
|
||||||
|
|
||||||
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本。
|
// friendlyEinoExecuteInvokeTail 将 Eino execute 超时/中断/流异常转为简短提示。
|
||||||
|
// 命令非零退出(ExecuteExitError)已有 exec 对齐的正文,不再追加「执行未正常结束」。
|
||||||
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||||
if invokeErr == nil {
|
if invokeErr == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
var exitErr *ExecuteExitError
|
||||||
|
if errors.As(invokeErr, &exitErr) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||||
return einoExecuteTimeoutUserHint()
|
return einoExecuteTimeoutUserHint()
|
||||||
}
|
}
|
||||||
|
if errors.Is(invokeErr, context.Canceled) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.Contains(invokeErr.Error(), "shell inactivity timeout") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return "[执行未正常结束] " + invokeErr.Error()
|
return "[执行未正常结束] " + invokeErr.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// einoToolResultIsError 统一判断 Eino 工具结果是否应标记为错误(与 MCP exec 的 IsError 对齐)。
|
||||||
|
func einoToolResultIsError(toolName, content string) bool {
|
||||||
|
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(toolName) == "execute" && security.IsCommandFailureResult(content) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoToolResultBody 去掉工具错误前缀,返回展示/持久化正文。
|
||||||
|
func einoToolResultBody(content string) string {
|
||||||
|
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||||
|
return strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextAgentEventWithContext 在 ctx 取消时不再无限阻塞于 iter.Next()(工具执行/模型推理期间常见)。
|
||||||
|
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
|
||||||
|
if iter == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
type nextRes struct {
|
||||||
|
ev *adk.AgentEvent
|
||||||
|
ok bool
|
||||||
|
}
|
||||||
|
ch := make(chan nextRes, 1)
|
||||||
|
go func() {
|
||||||
|
e, o := iter.Next()
|
||||||
|
ch <- nextRes{e, o}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, false, ctx.Err()
|
||||||
|
case res := <-ch:
|
||||||
|
return res.ev, res.ok, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recvSchemaMessageStream 消费 ADK Tool 流式结果;ctx 取消时立即返回,避免 amass 等无输出时永久阻塞。
|
||||||
|
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
|
||||||
|
if stream == nil {
|
||||||
|
return "", "", nil
|
||||||
|
}
|
||||||
|
type streamMsg struct {
|
||||||
|
chunk *schema.Message
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
recvCh := make(chan streamMsg, 8)
|
||||||
|
go func() {
|
||||||
|
defer close(recvCh)
|
||||||
|
for {
|
||||||
|
ch, rerr := stream.Recv()
|
||||||
|
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var buf strings.Builder
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return buf.String(), toolCallID, ctx.Err()
|
||||||
|
case sm, open := <-recvCh:
|
||||||
|
if !open {
|
||||||
|
return buf.String(), toolCallID, nil
|
||||||
|
}
|
||||||
|
rerr := sm.err
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
return buf.String(), toolCallID, nil
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
return buf.String(), toolCallID, rerr
|
||||||
|
}
|
||||||
|
chunk := sm.chunk
|
||||||
|
if chunk == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if chunk.Content != "" {
|
||||||
|
buf.WriteString(chunk.Content)
|
||||||
|
}
|
||||||
|
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||||
|
toolCallID = tid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func buildEinoRunResultFromAccumulated(
|
func buildEinoRunResultFromAccumulated(
|
||||||
orchMode string,
|
orchMode string,
|
||||||
runAccumulatedMsgs []adk.Message,
|
runAccumulatedMsgs []adk.Message,
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
_ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil)
|
||||||
|
sw.Close()
|
||||||
|
|
||||||
|
content, tid, err := recvSchemaMessageStream(context.Background(), sr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
if content != "hello" {
|
||||||
|
t.Fatalf("content=%q want hello", content)
|
||||||
|
}
|
||||||
|
if tid != "tc-1" {
|
||||||
|
t.Fatalf("toolCallID=%q want tc-1", tid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
t.Cleanup(func() { sw.Close() })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
content, _, err := recvSchemaMessageStream(ctx, sr)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("want context.Canceled, got %v content=%q", err, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
want := errors.New("stream broken")
|
||||||
|
_ = sw.Send(nil, want)
|
||||||
|
sw.Close()
|
||||||
|
|
||||||
|
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||||
|
if !errors.Is(err, want) {
|
||||||
|
t.Fatalf("want %v, got %v", want, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
|
||||||
|
content, tid, err := recvSchemaMessageStream(context.Background(), nil)
|
||||||
|
if err != nil || content != "" || tid != "" {
|
||||||
|
t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
_ = sw.Send(nil, io.EOF)
|
||||||
|
sw.Close()
|
||||||
|
|
||||||
|
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EOF should not surface as error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultEmptyResponseContinueMaxAttempts = 5
|
||||||
|
|
||||||
|
// IsEinoEmptyResponseResult 判断 Run 是否以「未捕获助手正文」占位结束(非真实用户可见回复)。
|
||||||
|
func IsEinoEmptyResponseResult(result *RunResult) bool {
|
||||||
|
if result == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isEinoEmptyResponseText(result.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isEinoEmptyResponseText(s string) bool {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(s, "no assistant text was captured") ||
|
||||||
|
strings.Contains(s, "未捕获到助手文本输出")
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasEinoResumeTrace 轨迹非空,续跑才有上下文可恢复。
|
||||||
|
func HasEinoResumeTrace(result *RunResult) bool {
|
||||||
|
if result == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s := strings.TrimSpace(result.LastAgentTraceInput)
|
||||||
|
return s != "" && s != "[]" && s != "null"
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmptyResponseContinueMaxAttemptsFromConfig 无助手正文时 Handler 层退避续跑上限;0=默认 5。
|
||||||
|
func EmptyResponseContinueMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||||
|
if mw != nil && mw.EmptyResponseContinueMaxAttempts > 0 {
|
||||||
|
return mw.EmptyResponseContinueMaxAttempts
|
||||||
|
}
|
||||||
|
return defaultEmptyResponseContinueMaxAttempts
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmptyResponseContinueBackoff 与 run_retry 相同指数退避(2s, 4s, 8s… capped)。
|
||||||
|
func EmptyResponseContinueBackoff(attempt int, mw *config.MultiAgentEinoMiddlewareConfig) time.Duration {
|
||||||
|
maxBackoff := defaultEinoRunRetryMaxBackoff
|
||||||
|
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
|
||||||
|
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
|
||||||
|
}
|
||||||
|
return einoTransientRetryBackoff(attempt, maxBackoff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatEmptyResponseContinueUserMessage 系统自动续跑时注入的 user 轮次(不写入 messages 表气泡)。
|
||||||
|
func FormatEmptyResponseContinueUserMessage() string {
|
||||||
|
return strings.TrimSpace(`【系统自动续跑 / Auto resume】
|
||||||
|
上一轮 Eino 会话未产出可见助手正文(可能流式中断或仅完成工具调用)。请基于已有轨迹与工具结果继续推进,并给出阶段性总结;勿重复已完成步骤。`)
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsEinoEmptyResponseResult(t *testing.T) {
|
||||||
|
empty := &RunResult{
|
||||||
|
Response: "(Eino ADK single-agent session completed but no assistant text was captured. Check process details or logs.) " +
|
||||||
|
"(Eino ADK 单代理会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||||
|
}
|
||||||
|
if !IsEinoEmptyResponseResult(empty) {
|
||||||
|
t.Fatal("expected empty placeholder response")
|
||||||
|
}
|
||||||
|
ok := &RunResult{Response: "扫描完成,发现 2 个开放端口。"}
|
||||||
|
if IsEinoEmptyResponseResult(ok) {
|
||||||
|
t.Fatalf("expected real response, got placeholder match")
|
||||||
|
}
|
||||||
|
if IsEinoEmptyResponseResult(nil) {
|
||||||
|
t.Fatal("nil result should be false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasEinoResumeTrace(t *testing.T) {
|
||||||
|
if HasEinoResumeTrace(nil) {
|
||||||
|
t.Fatal("nil")
|
||||||
|
}
|
||||||
|
if HasEinoResumeTrace(&RunResult{LastAgentTraceInput: "[]"}) {
|
||||||
|
t.Fatal("enable resume on empty trace")
|
||||||
|
}
|
||||||
|
if !HasEinoResumeTrace(&RunResult{LastAgentTraceInput: `[{"role":"user","content":"hi"}]`}) {
|
||||||
|
t.Fatal("expected resume trace")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmptyResponseContinueMaxAttemptsFromConfig(t *testing.T) {
|
||||||
|
if got := EmptyResponseContinueMaxAttemptsFromConfig(nil); got != defaultEmptyResponseContinueMaxAttempts {
|
||||||
|
t.Fatalf("default: got %d want %d", got, defaultEmptyResponseContinueMaxAttempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStreamingShellExitFail struct {
|
||||||
|
output string
|
||||||
|
code int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShellExitFail) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
defer outW.Close()
|
||||||
|
if m.output != "" {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
|
||||||
|
}
|
||||||
|
code := m.code
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{ExitCode: &code}, nil)
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_CommandFailureFormat(t *testing.T) {
|
||||||
|
inner := &mockStreamingShellExitFail{
|
||||||
|
output: "sudo: a password is required\n",
|
||||||
|
code: 1,
|
||||||
|
}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
var firedBody string
|
||||||
|
var firedSuccess bool
|
||||||
|
var firedErr error
|
||||||
|
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
firedBody = content
|
||||||
|
firedSuccess = success
|
||||||
|
firedErr = invokeErr
|
||||||
|
})
|
||||||
|
wrap := &einoStreamingShellWrap{inner: inner, invokeNotify: notify}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var stream strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
stream.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if firedSuccess {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
var exitErr *ExecuteExitError
|
||||||
|
if !errors.As(firedErr, &exitErr) || exitErr.Code != 1 {
|
||||||
|
t.Fatalf("expected ExecuteExitError code 1, got %v", firedErr)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(firedBody, einomcp.ToolErrorPrefix) {
|
||||||
|
t.Fatalf("missing tool error prefix: %q", firedBody)
|
||||||
|
}
|
||||||
|
body := strings.TrimPrefix(firedBody, einomcp.ToolErrorPrefix)
|
||||||
|
if body != security.FormatCommandFailureResult(1, "sudo: a password is required\n") {
|
||||||
|
t.Fatalf("fire body = %q", body)
|
||||||
|
}
|
||||||
|
if !strings.Contains(stream.String(), "sudo:") {
|
||||||
|
t.Fatalf("stream missing sudo output: %q", stream.String())
|
||||||
|
}
|
||||||
|
if strings.Contains(stream.String(), "command exited with non-zero") {
|
||||||
|
t.Fatalf("stream has legacy noise: %q", stream.String())
|
||||||
|
}
|
||||||
|
if strings.Contains(stream.String(), "执行未正常结束") {
|
||||||
|
t.Fatalf("stream has abnormal tail: %q", stream.String())
|
||||||
|
}
|
||||||
|
if !security.IsCommandFailureResult(stream.String()) {
|
||||||
|
t.Fatalf("stream missing failure status line: %q", stream.String())
|
||||||
|
}
|
||||||
|
if tail := friendlyEinoExecuteInvokeTail(firedErr); tail != "" {
|
||||||
|
t.Fatalf("unexpected invoke tail: %q", tail)
|
||||||
|
}
|
||||||
|
if !einoToolResultIsError("execute", firedBody) {
|
||||||
|
t.Fatal("expected isError for execute failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFriendlyEinoExecuteInvokeTail(t *testing.T) {
|
||||||
|
if friendlyEinoExecuteInvokeTail(&ExecuteExitError{Code: 1}) != "" {
|
||||||
|
t.Fatal("exit error should not get abnormal tail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(friendlyEinoExecuteInvokeTail(context.DeadlineExceeded), "Timed out") {
|
||||||
|
t.Fatal("deadline should get timeout hint")
|
||||||
|
}
|
||||||
|
if friendlyEinoExecuteInvokeTail(errors.New("broken pipe")) == "" {
|
||||||
|
t.Fatal("unexpected error should get tail")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,11 +7,25 @@ import (
|
|||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
// newEinoExecuteMonitorCallbacks 在 Eino filesystem execute 开始/结束时写入 MCP 监控库并 recorder(executionId),
|
||||||
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
// 与 CallTool 路径一致,使监控页能展示「执行中」状态。
|
||||||
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
func newEinoExecuteMonitorCallbacks(ag *agent.Agent, recorder einomcp.ExecutionRecorder) (
|
||||||
return func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
begin func(toolCallID, command string) string,
|
||||||
if ag == nil || recorder == nil {
|
finish func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
|
||||||
|
) {
|
||||||
|
begin = func(toolCallID, command string) string {
|
||||||
|
if ag == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
args := map[string]interface{}{"command": command}
|
||||||
|
id := ag.BeginLocalToolExecution("execute", args)
|
||||||
|
if id != "" && recorder != nil {
|
||||||
|
recorder(id, toolCallID)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
finish = func(executionID, toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||||
|
if ag == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
@@ -23,9 +37,10 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
args := map[string]interface{}{"command": command}
|
args := map[string]interface{}{"command": command}
|
||||||
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
id := ag.FinishLocalToolExecution(executionID, "execute", args, stdout, err)
|
||||||
if id != "" {
|
if id != "" && recorder != nil && executionID == "" {
|
||||||
recorder(id, toolCallID)
|
recorder(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return begin, finish
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
|||||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||||
//
|
//
|
||||||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
// run loop 收到 Fire 后立即推送 tool_result(toolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」。
|
||||||
//
|
//
|
||||||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||||
@@ -63,8 +63,11 @@ type einoStreamingShellWrap struct {
|
|||||||
outputChunk func(toolName, toolCallID, chunk string)
|
outputChunk func(toolName, toolCallID, chunk string)
|
||||||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||||
toolTimeoutMinutes int
|
toolTimeoutMinutes int
|
||||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
// shellNoOutputTimeoutSec:无任何输出时的空闲秒数;0=关闭。
|
||||||
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
|
shellNoOutputTimeoutSec int
|
||||||
|
// beginMonitor 在 execute 开始时写入 running 状态;finishMonitor 在流结束后更新为 completed/failed。
|
||||||
|
beginMonitor func(toolCallID, command string) string
|
||||||
|
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
@@ -76,15 +79,26 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
}
|
}
|
||||||
req := *input
|
req := *input
|
||||||
userCmd := strings.TrimSpace(req.Command)
|
userCmd := strings.TrimSpace(req.Command)
|
||||||
|
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||||
|
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||||||
req.RunInBackendGround = true
|
req.RunInBackendGround = true
|
||||||
}
|
}
|
||||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
|
||||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
|
||||||
convID := mcp.MCPConversationIDFromContext(ctx)
|
convID := mcp.MCPConversationIDFromContext(ctx)
|
||||||
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
||||||
|
|
||||||
|
var monitorExecID string
|
||||||
|
if w.beginMonitor != nil {
|
||||||
|
monitorExecID = w.beginMonitor(tid, userCmd)
|
||||||
|
}
|
||||||
|
if monitorExecID != "" && convID != "" {
|
||||||
|
if toolReg := mcp.ToolRunRegistryFromContext(ctx); toolReg != nil {
|
||||||
|
toolReg.RegisterRunningTool(convID, monitorExecID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolRunReg := mcp.ToolRunRegistryFromContext(ctx)
|
||||||
|
|
||||||
execCtx, execCancel := context.WithCancel(ctx)
|
execCtx, execCancel := context.WithCancel(ctx)
|
||||||
var timeoutCancel context.CancelFunc
|
var timeoutCancel context.CancelFunc
|
||||||
if w.toolTimeoutMinutes > 0 {
|
if w.toolTimeoutMinutes > 0 {
|
||||||
@@ -104,23 +118,23 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
}
|
}
|
||||||
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
||||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||||
if w.recordMonitor != nil {
|
if w.finishMonitor != nil {
|
||||||
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
|
w.finishMonitor(monitorExecID, tid, userCmd, hint, false, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
if w.invokeNotify != nil && tid != "" {
|
if w.invokeNotify != nil && tid != "" {
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
||||||
}
|
}
|
||||||
if w.recordMonitor != nil {
|
if w.finishMonitor != nil {
|
||||||
w.recordMonitor(tid, userCmd, "", false, err)
|
w.finishMonitor(monitorExecID, tid, userCmd, "", false, err)
|
||||||
}
|
}
|
||||||
if w.invokeNotify != nil && tid != "" {
|
if w.invokeNotify != nil && tid != "" {
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if sr == nil || w.invokeNotify == nil {
|
if sr == nil {
|
||||||
if timeoutCancel != nil {
|
if timeoutCancel != nil {
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
}
|
}
|
||||||
@@ -132,7 +146,7 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
|
|
||||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||||
|
|
||||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
|
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry, toolReg mcp.ToolRunRegistry, execID string, toolCallID string, noOutputSec int) {
|
||||||
var innerCloseOnce sync.Once
|
var innerCloseOnce sync.Once
|
||||||
closeInner := func() {
|
closeInner := func() {
|
||||||
innerCloseOnce.Do(func() { inner.Close() })
|
innerCloseOnce.Do(func() { inner.Close() })
|
||||||
@@ -147,6 +161,9 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
if reg != nil && conversationID != "" {
|
if reg != nil && conversationID != "" {
|
||||||
defer reg.UnregisterActiveEinoExecute(conversationID)
|
defer reg.UnregisterActiveEinoExecute(conversationID)
|
||||||
}
|
}
|
||||||
|
if toolReg != nil && conversationID != "" && execID != "" {
|
||||||
|
defer toolReg.UnregisterRunningTool(conversationID, execID)
|
||||||
|
}
|
||||||
|
|
||||||
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
||||||
stopWatch := make(chan struct{})
|
stopWatch := make(chan struct{})
|
||||||
@@ -165,50 +182,103 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
exitCode := 0
|
exitCode := 0
|
||||||
hasExitCode := false
|
hasExitCode := false
|
||||||
|
|
||||||
|
idleWatch := security.NewShellInactivityWatch(noOutputSec)
|
||||||
|
if idleWatch != nil {
|
||||||
|
defer idleWatch.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
type execRecvMsg struct {
|
||||||
|
resp *filesystem.ExecuteResponse
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
recvCh := make(chan execRecvMsg, 1)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
resp, rerr := inner.Recv()
|
||||||
|
recvCh <- execRecvMsg{resp: resp, err: rerr}
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
fireInactivityTimeout := func() {
|
||||||
|
success = false
|
||||||
|
invokeErr = fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||||
|
msg := security.ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: msg}, nil)
|
||||||
|
sb.WriteString(msg)
|
||||||
|
if w.outputChunk != nil && toolCallID != "" {
|
||||||
|
w.outputChunk("execute", toolCallID, msg)
|
||||||
|
}
|
||||||
|
if cancel != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
closeInner()
|
||||||
|
}
|
||||||
|
|
||||||
|
recvLoop:
|
||||||
for {
|
for {
|
||||||
resp, rerr := inner.Recv()
|
var idleCh <-chan struct{}
|
||||||
if errors.Is(rerr, io.EOF) {
|
if idleWatch != nil {
|
||||||
break
|
idleCh = idleWatch.Expired
|
||||||
}
|
}
|
||||||
if rerr != nil {
|
select {
|
||||||
success = false
|
case <-idleCh:
|
||||||
invokeErr = rerr
|
fireInactivityTimeout()
|
||||||
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
|
break recvLoop
|
||||||
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
case msg := <-recvCh:
|
||||||
invokeErr = context.DeadlineExceeded
|
rerr := msg.err
|
||||||
break
|
resp := msg.resp
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break recvLoop
|
||||||
}
|
}
|
||||||
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
if rerr != nil {
|
||||||
invokeErr = context.Canceled
|
|
||||||
break
|
|
||||||
}
|
|
||||||
_ = outW.Send(nil, rerr)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if resp != nil {
|
|
||||||
if resp.ExitCode != nil {
|
|
||||||
hasExitCode = true
|
|
||||||
exitCode = *resp.ExitCode
|
|
||||||
}
|
|
||||||
var appended string
|
|
||||||
if resp.Output != "" {
|
|
||||||
sb.WriteString(resp.Output)
|
|
||||||
appended = resp.Output
|
|
||||||
}
|
|
||||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
|
||||||
w.outputChunk("execute", tid, appended)
|
|
||||||
}
|
|
||||||
if outW.Send(resp, nil) {
|
|
||||||
success = false
|
success = false
|
||||||
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
invokeErr = rerr
|
||||||
break
|
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||||||
|
invokeErr = context.DeadlineExceeded
|
||||||
|
break recvLoop
|
||||||
|
}
|
||||||
|
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
||||||
|
invokeErr = context.Canceled
|
||||||
|
break recvLoop
|
||||||
|
}
|
||||||
|
_ = outW.Send(nil, rerr)
|
||||||
|
break recvLoop
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
if resp.ExitCode != nil {
|
||||||
|
hasExitCode = true
|
||||||
|
exitCode = *resp.ExitCode
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var appended string
|
||||||
|
if resp.Output != "" {
|
||||||
|
if security.IsLegacyShellExitNoise(resp.Output) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if idleWatch != nil {
|
||||||
|
idleWatch.Bump()
|
||||||
|
}
|
||||||
|
sb.WriteString(resp.Output)
|
||||||
|
appended = resp.Output
|
||||||
|
}
|
||||||
|
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||||
|
w.outputChunk("execute", toolCallID, appended)
|
||||||
|
}
|
||||||
|
if outW.Send(resp, nil) {
|
||||||
|
success = false
|
||||||
|
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||||||
|
break recvLoop
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if success && hasExitCode && exitCode != 0 {
|
if success && hasExitCode && exitCode != 0 {
|
||||||
success = false
|
success = false
|
||||||
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
invokeErr = &ExecuteExitError{Code: exitCode}
|
||||||
}
|
}
|
||||||
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||||||
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||||||
@@ -248,12 +318,24 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if w.recordMonitor != nil {
|
rawOutput := sb.String()
|
||||||
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
fireBody := rawOutput
|
||||||
|
if !success && hasExitCode && exitCode != 0 {
|
||||||
|
statusLine := security.ExecuteFailureStatusLine(exitCode)
|
||||||
|
if !strings.Contains(rawOutput, "命令执行失败:") {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: statusLine}, nil)
|
||||||
|
sb.WriteString(statusLine)
|
||||||
|
}
|
||||||
|
fireBody = einomcp.ToolErrorPrefix + security.FormatCommandFailureResult(exitCode, rawOutput)
|
||||||
|
}
|
||||||
|
if w.finishMonitor != nil {
|
||||||
|
w.finishMonitor(execID, toolCallID, command, sb.String(), success, invokeErr)
|
||||||
|
}
|
||||||
|
if w.invokeNotify != nil {
|
||||||
|
w.invokeNotify.Fire(toolCallID, "execute", agentTag, success, fireBody, invokeErr)
|
||||||
}
|
}
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
|
||||||
outW.Close()
|
outW.Close()
|
||||||
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
|
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg, toolRunReg, monitorExecID, tid, w.shellNoOutputTimeoutSec)
|
||||||
|
|
||||||
return outR, nil
|
return outR, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,9 +19,15 @@ type mockStreamingShell struct {
|
|||||||
immediateErr error
|
immediateErr error
|
||||||
recvErr error
|
recvErr error
|
||||||
output string
|
output string
|
||||||
|
called bool
|
||||||
|
lastCommand string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
m.called = true
|
||||||
|
if input != nil {
|
||||||
|
m.lastCommand = input.Command
|
||||||
|
}
|
||||||
if m.immediateErr != nil {
|
if m.immediateErr != nil {
|
||||||
return nil, m.immediateErr
|
return nil, m.immediateErr
|
||||||
}
|
}
|
||||||
@@ -38,6 +44,129 @@ func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesy
|
|||||||
return outR, nil
|
return outR, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_PreparesNonInteractiveCommand(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{output: "ok\n"}
|
||||||
|
wrap := &einoStreamingShellWrap{inner: inner}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo ok"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
for {
|
||||||
|
_, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(inner.lastCommand, "PYTHONUNBUFFERED=1") {
|
||||||
|
t.Fatalf("missing python unbuffer in inner command: %q", inner.lastCommand)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_NoOutputTimeout(t *testing.T) {
|
||||||
|
inner := &mockStreamingShellHanging{}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
var fired string
|
||||||
|
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
fired = content
|
||||||
|
})
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
invokeNotify: notify,
|
||||||
|
shellNoOutputTimeoutSec: 1,
|
||||||
|
}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !inner.called {
|
||||||
|
t.Fatal("inner shell should run (no command blacklist)")
|
||||||
|
}
|
||||||
|
out := got.String()
|
||||||
|
if !strings.Contains(out, "没有新的输出") && !strings.Contains(out, "no new output") {
|
||||||
|
t.Fatalf("expected inactivity timeout message, got: %q notify=%q", out, fired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockStreamingShellPartialThenHang struct {
|
||||||
|
called bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShellPartialThenHang) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
m.called = true
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: "[sudo] password:\n"}, nil)
|
||||||
|
<-ctx.Done()
|
||||||
|
outW.Close()
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_InactivityAfterPartialOutput(t *testing.T) {
|
||||||
|
inner := &mockStreamingShellPartialThenHang{}
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
shellNoOutputTimeoutSec: 1,
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if time.Since(start) > 5*time.Second {
|
||||||
|
t.Fatalf("expected inactivity timeout ~1s, took %v", time.Since(start))
|
||||||
|
}
|
||||||
|
if !strings.Contains(got.String(), "没有新的输出") && !strings.Contains(got.String(), "no new output") {
|
||||||
|
t.Fatalf("expected inactivity message, got: %q", got.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockStreamingShellHanging struct {
|
||||||
|
called bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShellHanging) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
m.called = true
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
outW.Close()
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
||||||
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -63,10 +63,43 @@ func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName
|
|||||||
return map[string]interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// beginEinoADKFilesystemToolMonitor 在 Eino ADK filesystem 工具开始调用时写入 running 状态。
|
||||||
|
func beginEinoADKFilesystemToolMonitor(
|
||||||
|
ag *agent.Agent,
|
||||||
|
rec einomcp.ExecutionRecorder,
|
||||||
|
binder *MCPExecutionBinder,
|
||||||
|
toolCallID, toolName string,
|
||||||
|
) {
|
||||||
|
if ag == nil || rec == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(toolName)
|
||||||
|
if name == "" || strings.EqualFold(name, "execute") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isBuiltinEinoADKFilesystemToolName(name) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
if tid == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
storedName := "eino_fs::" + strings.ToLower(name)
|
||||||
|
id := ag.BeginLocalToolExecution(storedName, map[string]interface{}{})
|
||||||
|
if id == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rec(id, tid)
|
||||||
|
if binder != nil {
|
||||||
|
binder.Bind(tid, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
||||||
func recordEinoADKFilesystemToolMonitor(
|
func recordEinoADKFilesystemToolMonitor(
|
||||||
ag *agent.Agent,
|
ag *agent.Agent,
|
||||||
rec einomcp.ExecutionRecorder,
|
rec einomcp.ExecutionRecorder,
|
||||||
|
binder *MCPExecutionBinder,
|
||||||
toolName string,
|
toolName string,
|
||||||
toolCallID string,
|
toolCallID string,
|
||||||
msgs []adk.Message,
|
msgs []adk.Message,
|
||||||
@@ -94,8 +127,12 @@ func recordEinoADKFilesystemToolMonitor(
|
|||||||
invErr = errors.New(t)
|
invErr = errors.New(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
execID := ""
|
||||||
if id != "" {
|
if binder != nil {
|
||||||
|
execID = binder.ExecutionID(toolCallID)
|
||||||
|
}
|
||||||
|
id := ag.FinishLocalToolExecution(execID, storedName, args, resultText, invErr)
|
||||||
|
if id != "" && execID == "" {
|
||||||
rec(id, toolCallID)
|
rec(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
|
||||||
mainDefs := ag.ToolsForRole(roleTools)
|
mainDefs := ag.ToolsForRole(roleTools)
|
||||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
|
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -136,7 +136,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||||
}
|
}
|
||||||
@@ -184,6 +184,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
Name: einoSingleAgentName,
|
Name: einoSingleAgentName,
|
||||||
Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.",
|
Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.",
|
||||||
Instruction: ins,
|
Instruction: ins,
|
||||||
|
GenModelInput: literalInstructionGenModelInput,
|
||||||
Model: mainModel,
|
Model: mainModel,
|
||||||
ToolsConfig: mainToolsCfg,
|
ToolsConfig: mainToolsCfg,
|
||||||
MaxIterations: maxIter,
|
MaxIterations: maxIter,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -81,8 +82,10 @@ func subAgentFilesystemMiddleware(
|
|||||||
loc *localbk.Local,
|
loc *localbk.Local,
|
||||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||||
einoAgentName string,
|
einoAgentName string,
|
||||||
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error),
|
beginMonitor func(toolCallID, command string) string,
|
||||||
|
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
|
||||||
toolTimeoutMinutes int,
|
toolTimeoutMinutes int,
|
||||||
|
shellNoOutputTimeoutSec int,
|
||||||
outputChunk func(toolName, toolCallID, chunk string),
|
outputChunk func(toolName, toolCallID, chunk string),
|
||||||
) (adk.ChatModelAgentMiddleware, error) {
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if loc == nil {
|
if loc == nil {
|
||||||
@@ -91,12 +94,14 @@ func subAgentFilesystemMiddleware(
|
|||||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||||
Backend: loc,
|
Backend: loc,
|
||||||
StreamingShell: &einoStreamingShellWrap{
|
StreamingShell: &einoStreamingShellWrap{
|
||||||
inner: loc,
|
inner: security.NewEinoStreamingShell(),
|
||||||
invokeNotify: invokeNotify,
|
invokeNotify: invokeNotify,
|
||||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||||
outputChunk: outputChunk,
|
outputChunk: outputChunk,
|
||||||
recordMonitor: recordMonitor,
|
beginMonitor: beginMonitor,
|
||||||
toolTimeoutMinutes: toolTimeoutMinutes,
|
finishMonitor: finishMonitor,
|
||||||
|
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||||
|
shellNoOutputTimeoutSec: shellNoOutputTimeoutSec,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -108,3 +113,18 @@ func agentToolTimeoutMinutes(cfg *config.Config) int {
|
|||||||
}
|
}
|
||||||
return cfg.Agent.ToolTimeoutMinutes
|
return cfg.Agent.ToolTimeoutMinutes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// agentShellNoOutputTimeoutSeconds:0=默认 300s(5 分钟);-1=关闭;>0=自定义秒数。
|
||||||
|
func agentShellNoOutputTimeoutSeconds(cfg *config.Config) int {
|
||||||
|
if cfg == nil {
|
||||||
|
return 300
|
||||||
|
}
|
||||||
|
v := cfg.Agent.ShellNoOutputTimeoutSeconds
|
||||||
|
if v < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if v == 0 {
|
||||||
|
return 300
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ func newEinoSummarizationMiddleware(
|
|||||||
}
|
}
|
||||||
if appCfg != nil {
|
if appCfg != nil {
|
||||||
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
|
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
|
||||||
|
out = refreshUserVerbatimAnchorInMessages(out, db, conversationID, appCfg.MultiAgent.UserVerbatimAnchorMaxRunesEffective(), logger)
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
},
|
},
|
||||||
@@ -413,6 +414,36 @@ func writeSummarizationTranscript(path string, msgs []adk.Message) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// refreshUserVerbatimAnchorInMessages 压缩后从 messages 表刷新 system 中的用户原文锚点。
|
||||||
|
func refreshUserVerbatimAnchorInMessages(msgs []adk.Message, db *database.DB, conversationID string, maxRunes int, logger *zap.Logger) []adk.Message {
|
||||||
|
if maxRunes < 0 || db == nil {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
rows, err := db.GetMessages(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("summarization: 刷新用户原文锚点失败",
|
||||||
|
zap.String("conversationId", conversationID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
block := project.BuildUserVerbatimAnchorBlockFromMessages(rows, maxRunes)
|
||||||
|
if block == "" {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
out := project.RefreshUserVerbatimAnchorInMessages(msgs, block)
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("summarization: 已刷新用户原文锚点", zap.String("conversationId", conversationID))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||||
tc := agent.NewTikTokenCounter()
|
tc := agent.NewTikTokenCounter()
|
||||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||||
|
|||||||
@@ -1,35 +1,12 @@
|
|||||||
package multiagent
|
package multiagent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/bytedance/sonic"
|
copenai "cyberstrike-ai/internal/openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
// stripReasoningFromSummarizationPayload removes thinking / reasoning fields from a
|
// stripReasoningFromSummarizationPayload removes thinking / reasoning fields from a
|
||||||
// chat-completions JSON body. Applied only to summarization Generate calls via
|
// chat-completions JSON body. Applied only to summarization Generate calls via
|
||||||
// model.ModelOptions on the shared ChatModel — main-agent requests are unchanged.
|
// model.ModelOptions on the shared ChatModel — main-agent requests are unchanged.
|
||||||
func stripReasoningFromSummarizationPayload(rawBody []byte) ([]byte, error) {
|
func stripReasoningFromSummarizationPayload(rawBody []byte) ([]byte, error) {
|
||||||
var payload map[string]any
|
return copenai.StripReasoningFromChatCompletionBody(rawBody)
|
||||||
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
|
|
||||||
return rawBody, nil
|
|
||||||
}
|
|
||||||
changed := false
|
|
||||||
for _, key := range []string{
|
|
||||||
"thinking",
|
|
||||||
"reasoning_effort",
|
|
||||||
"output_config",
|
|
||||||
"reasoning",
|
|
||||||
} {
|
|
||||||
if _, ok := payload[key]; ok {
|
|
||||||
delete(payload, key)
|
|
||||||
changed = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !changed {
|
|
||||||
return rawBody, nil
|
|
||||||
}
|
|
||||||
out, err := sonic.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return rawBody, err
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -409,9 +409,9 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
|||||||
"需要写入请使用 upsert_project_fact。",
|
"需要写入请使用 upsert_project_fact。",
|
||||||
project.FactIndexSectionEndMarker,
|
project.FactIndexSectionEndMarker,
|
||||||
"",
|
"",
|
||||||
"# Skills System",
|
transcriptSkillsSystemMarker,
|
||||||
"**How to Use Skills**",
|
"**如何使用 Skill(技能)(渐进式展示):**",
|
||||||
"Remember: Skills make you more capable",
|
"记住:Skill 让你更加强大和稳定",
|
||||||
}, "\n")
|
}, "\n")
|
||||||
|
|
||||||
out := sanitizeSystemContentForTranscript(system)
|
out := sanitizeSystemContentForTranscript(system)
|
||||||
@@ -421,7 +421,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
|||||||
if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") {
|
if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") {
|
||||||
t.Fatalf("static persona should be stripped: %q", out)
|
t.Fatalf("static persona should be stripped: %q", out)
|
||||||
}
|
}
|
||||||
if strings.Contains(out, "# Skills System") || strings.Contains(out, "How to Use Skills") {
|
if strings.Contains(out, transcriptSkillsSystemMarker) || strings.Contains(out, "如何使用 Skill") {
|
||||||
t.Fatalf("skills boilerplate should be stripped: %q", out)
|
t.Fatalf("skills boilerplate should be stripped: %q", out)
|
||||||
}
|
}
|
||||||
if !strings.Contains(out, transcriptStaticSystemOmitNote) {
|
if !strings.Contains(out, transcriptStaticSystemOmitNote) {
|
||||||
@@ -435,7 +435,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
|||||||
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
|
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
msgs := []adk.Message{
|
msgs := []adk.Message{
|
||||||
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n# Skills System\nboiler"),
|
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n" + transcriptSkillsSystemMarker + "\nboiler"),
|
||||||
schema.UserMessage("hello"),
|
schema.UserMessage("hello"),
|
||||||
schema.AssistantMessage("reply", nil),
|
schema.AssistantMessage("reply", nil),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ const (
|
|||||||
transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]"
|
transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]"
|
||||||
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
|
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
|
||||||
transcriptPersonaStartMarker = "你是CyberStrikeAI"
|
transcriptPersonaStartMarker = "你是CyberStrikeAI"
|
||||||
transcriptSkillsSystemMarker = "# Skills System"
|
// ADK LanguageChinese injects skill middleware prompt with this header (see eino adk/middlewares/skill/prompt.go).
|
||||||
|
transcriptSkillsSystemMarker = "# Skill 系统"
|
||||||
|
transcriptSkillsSystemMarkerEnglish = "# Skills System"
|
||||||
)
|
)
|
||||||
|
|
||||||
type transcriptToolCall struct {
|
type transcriptToolCall struct {
|
||||||
@@ -86,13 +88,23 @@ func stripToolNamesIndexFromSystem(s string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func stripSkillsSystemBoilerplate(s string) string {
|
func stripSkillsSystemBoilerplate(s string) string {
|
||||||
idx := strings.Index(s, transcriptSkillsSystemMarker)
|
idx := indexFirstSubstring(s, transcriptSkillsSystemMarker, transcriptSkillsSystemMarkerEnglish)
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return strings.TrimSpace(s)
|
return strings.TrimSpace(s)
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(s[:idx])
|
return strings.TrimSpace(s[:idx])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func indexFirstSubstring(s string, markers ...string) int {
|
||||||
|
first := -1
|
||||||
|
for _, m := range markers {
|
||||||
|
if i := strings.Index(s, m); i >= 0 && (first < 0 || i < first) {
|
||||||
|
first = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return first
|
||||||
|
}
|
||||||
|
|
||||||
func extractProjectBlackboardSection(s string) string {
|
func extractProjectBlackboardSection(s string) string {
|
||||||
start := strings.Index(s, project.FactIndexSectionStartMarker)
|
start := strings.Index(s, project.FactIndexSectionStartMarker)
|
||||||
if start < 0 {
|
if start < 0 {
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, too
|
|||||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
||||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||||
}
|
}
|
||||||
|
if s := strings.TrimSpace(injectShellToolGuidance("", names)); s != "" {
|
||||||
|
sb.WriteString(s)
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
if s := strings.TrimSpace(instruction); s != "" {
|
if s := strings.TrimSpace(instruction); s != "" {
|
||||||
sb.WriteString(s)
|
sb.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
|
|||||||
|
|
||||||
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
|
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
|
||||||
|
|
||||||
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。
|
// reset 在退避重试后成功推进(流/消息完整接收)时清零计数,使后续临时错误从第 1 次退避重新开始。
|
||||||
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
|
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
|
||||||
|
|
||||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||||
@@ -190,29 +190,26 @@ func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated [
|
|||||||
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||||
}
|
}
|
||||||
|
|
||||||
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
|
// adkMessagesHasUserContent reports whether the conversation tail is already a user turn
|
||||||
|
// with the given content. Only the last message counts: matching text in an earlier round
|
||||||
|
// (e.g. user repeats the same prompt after an assistant reply) must not suppress appending
|
||||||
|
// the new user turn — Claude 4.6+ rejects requests whose final message is assistant.
|
||||||
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
|
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
|
||||||
want = strings.TrimSpace(want)
|
want = strings.TrimSpace(want)
|
||||||
if want == "" {
|
if want == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
for i := len(msgs) - 1; i >= 0; i-- {
|
if len(msgs) == 0 {
|
||||||
m := msgs[i]
|
return false
|
||||||
if m == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if m.Role == schema.User {
|
|
||||||
return strings.TrimSpace(m.Content) == want
|
|
||||||
}
|
|
||||||
if m.Role == schema.Assistant || m.Role == schema.Tool {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
return false
|
last := msgs[len(msgs)-1]
|
||||||
|
if last == nil || last.Role != schema.User {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(last.Content) == want
|
||||||
}
|
}
|
||||||
|
|
||||||
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
|
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当尾部已是相同 user 句)。
|
||||||
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
|
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
|
||||||
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
|
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
|
||||||
return msgs
|
return msgs
|
||||||
|
|||||||
@@ -105,6 +105,32 @@ func TestEinoTransientRunRetrierReset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEinoTransientRunRetrierConsecutiveFailures(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
|
||||||
|
ctx := context.Background()
|
||||||
|
runErr := errors.New("internal server error")
|
||||||
|
args := &einoADKRunLoopArgs{}
|
||||||
|
base := []adk.Message{schema.UserMessage("hi")}
|
||||||
|
|
||||||
|
for want := 1; want <= 3; want++ {
|
||||||
|
restarted, _, _, _, err := r.tryRetry(ctx, runErr, args, base, nil, len(base))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tryRetry attempt %d: %v", want, err)
|
||||||
|
}
|
||||||
|
if !restarted {
|
||||||
|
t.Fatalf("tryRetry attempt %d: want restarted", want)
|
||||||
|
}
|
||||||
|
if got := r.attempt(); got != want {
|
||||||
|
t.Fatalf("after failure %d: attempt=%d, want %d", want, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.reset()
|
||||||
|
if r.attempt() != 0 {
|
||||||
|
t.Fatalf("after successful recovery reset: attempt=%d, want 0", r.attempt())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
msgs := []adk.Message{schema.UserMessage("old task")}
|
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||||
@@ -117,3 +143,18 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
|
|||||||
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAppendUserMessageIfNeeded_repeatPromptAfterAssistant(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
msgs := []adk.Message{
|
||||||
|
schema.UserMessage("扫描 example.com"),
|
||||||
|
schema.AssistantMessage("开始扫描...", nil),
|
||||||
|
}
|
||||||
|
out := appendUserMessageIfNeeded(msgs, "扫描 example.com")
|
||||||
|
if len(out) != 3 {
|
||||||
|
t.Fatalf("should append new user turn after assistant reply: len=%d", len(out))
|
||||||
|
}
|
||||||
|
if out[2].Role != schema.User || out[2].Content != "扫描 example.com" {
|
||||||
|
t.Fatalf("tail should be repeated user prompt, got role=%s content=%q", out[2].Role, out[2].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// ExecuteExitError 表示 execute 命令非零退出(预期失败,非超时/中断/流异常)。
|
||||||
|
type ExecuteExitError struct {
|
||||||
|
Code int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ExecuteExitError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return "exit status unknown"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("exit status %d", e.Code)
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// literalInstructionGenModelInput passes Instruction through as a system message without
|
||||||
|
// FString template formatting. Eino defaultGenModelInput formats instruction whenever
|
||||||
|
// SessionValues exist; prompts with literal curly braces (project blackboard "{关系边: ...}",
|
||||||
|
// JSON examples, link syntax) then fail with "could not find key".
|
||||||
|
//
|
||||||
|
// Matches eino/adk/prebuilt/deep genModelInput — the supported fix per Eino docs.
|
||||||
|
func literalInstructionGenModelInput(ctx context.Context, instruction string, input *adk.AgentInput) ([]adk.Message, error) {
|
||||||
|
msgs := make([]adk.Message, 0, len(input.Messages)+1)
|
||||||
|
if instruction != "" {
|
||||||
|
msgs = append(msgs, schema.SystemMessage(instruction))
|
||||||
|
}
|
||||||
|
msgs = append(msgs, input.Messages...)
|
||||||
|
return msgs, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLiteralInstructionGenModelInput_PreservesLiteralCurlyBraces(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
instruction := "- [finding/x] summary {关系边: discovered_on←target/dev}\n" +
|
||||||
|
"如 finding 上 {from:target/*, type:discovered_on}"
|
||||||
|
msgs, err := literalInstructionGenModelInput(context.Background(), instruction, &adk.AgentInput{
|
||||||
|
Messages: []adk.Message{schema.UserMessage("继续")},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("expected 2 messages, got %d", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Role != schema.System {
|
||||||
|
t.Fatalf("first message must be system, got %s", msgs[0].Role)
|
||||||
|
}
|
||||||
|
for _, want := range []string{"{关系边:", "{from:target/*, type:discovered_on}"} {
|
||||||
|
if !strings.Contains(msgs[0].Content, want) {
|
||||||
|
t.Fatalf("system content missing %q: %q", want, msgs[0].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package multiagent
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -75,8 +74,8 @@ func hitlInvokableToolCallMiddleware() compose.InvokableToolMiddleware {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if IsHumanRejectError(err) {
|
if IsHumanRejectError(err) {
|
||||||
// Human rejection should be a soft tool result so the model can continue iterating.
|
// Human rejection should be a soft tool result so the model can continue iterating.
|
||||||
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
// tool_search 须保持 JSON,否则 Eino toolsearch 中间件解析历史时会硬崩 ChatModel。
|
||||||
input.Name, strings.TrimSpace(err.Error()))
|
msg := HitlRejectToolResult(input.Name, err.Error())
|
||||||
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
|
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
|
||||||
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
|
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
|
||||||
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
|
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
|
||||||
@@ -103,8 +102,7 @@ func hitlStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
|||||||
edited, err := fn(ctx, input.Name, input.Arguments)
|
edited, err := fn(ctx, input.Name, input.Arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if IsHumanRejectError(err) {
|
if IsHumanRejectError(err) {
|
||||||
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
msg := HitlRejectToolResult(input.Name, err.Error())
|
||||||
input.Name, strings.TrimSpace(err.Error()))
|
|
||||||
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||||
return &compose.StreamToolOutput{
|
return &compose.StreamToolOutput{
|
||||||
Result: schema.StreamReaderFromArray([]string{msg}),
|
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const toolSearchToolName = "tool_search"
|
||||||
|
|
||||||
|
// HitlExemptMetaTools 为编排/元工具:不直接执行攻击动作,但会阻塞 agent 控制流。
|
||||||
|
// tool_search 必须免审批,否则其 HITL 拒绝结果与 Eino toolsearch 中间件不兼容(会硬崩 ChatModel)。
|
||||||
|
var HitlExemptMetaTools = []string{
|
||||||
|
toolSearchToolName,
|
||||||
|
"skill",
|
||||||
|
"task",
|
||||||
|
"write_todos",
|
||||||
|
"transfer_to_agent",
|
||||||
|
"exit",
|
||||||
|
"TaskCreate",
|
||||||
|
"TaskGet",
|
||||||
|
"TaskUpdate",
|
||||||
|
"TaskList",
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsToolSearchTool reports whether name is the Eino dynamictool tool_search meta-tool.
|
||||||
|
func IsToolSearchTool(name string) bool {
|
||||||
|
return strings.EqualFold(strings.TrimSpace(name), toolSearchToolName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeHitlExemptMetaTools unions configured whitelist with built-in meta-tool exemptions.
|
||||||
|
func MergeHitlExemptMetaTools(configured []string) []string {
|
||||||
|
merged := make([]string, 0, len(configured)+len(HitlExemptMetaTools))
|
||||||
|
seen := make(map[string]struct{}, len(configured)+len(HitlExemptMetaTools))
|
||||||
|
add := func(name string) {
|
||||||
|
n := strings.ToLower(strings.TrimSpace(name))
|
||||||
|
if n == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[n]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[n] = struct{}{}
|
||||||
|
merged = append(merged, strings.TrimSpace(name))
|
||||||
|
}
|
||||||
|
for _, t := range configured {
|
||||||
|
add(t)
|
||||||
|
}
|
||||||
|
for _, t := range HitlExemptMetaTools {
|
||||||
|
add(t)
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
type toolSearchHitlRejectPayload struct {
|
||||||
|
SelectedTools []string `json:"selectedTools"`
|
||||||
|
HitlRejected bool `json:"_hitlRejected"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HitlRejectToolResult returns a tool result body safe for downstream consumers.
|
||||||
|
// tool_search must stay JSON-shaped so toolsearch.extractSelectedTools does not terminate the graph.
|
||||||
|
func HitlRejectToolResult(toolName, reason string) string {
|
||||||
|
reason = strings.TrimSpace(reason)
|
||||||
|
if !IsToolSearchTool(toolName) {
|
||||||
|
if reason == "" {
|
||||||
|
reason = "rejected by reviewer"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
||||||
|
strings.TrimSpace(toolName), reason)
|
||||||
|
}
|
||||||
|
payload := toolSearchHitlRejectPayload{
|
||||||
|
SelectedTools: []string{},
|
||||||
|
HitlRejected: true,
|
||||||
|
Reason: reason,
|
||||||
|
}
|
||||||
|
if payload.Reason == "" {
|
||||||
|
payload.Reason = "tool_search rejected by reviewer; no dynamic tools unlocked"
|
||||||
|
}
|
||||||
|
out, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return `{"selectedTools":[],"_hitlRejected":true,"reason":"tool_search rejected by reviewer"}`
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHitlRejectToolResult_toolSearchIsJSON(t *testing.T) {
|
||||||
|
raw := HitlRejectToolResult("tool_search", "rejected by user: timeout")
|
||||||
|
var payload toolSearchHitlRejectPayload
|
||||||
|
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if len(payload.SelectedTools) != 0 {
|
||||||
|
t.Fatalf("expected empty selectedTools, got %v", payload.SelectedTools)
|
||||||
|
}
|
||||||
|
if !payload.HitlRejected {
|
||||||
|
t.Fatal("expected _hitlRejected true")
|
||||||
|
}
|
||||||
|
if !strings.Contains(payload.Reason, "timeout") {
|
||||||
|
t.Fatalf("reason=%q", payload.Reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHitlRejectToolResult_otherToolKeepsLegacyText(t *testing.T) {
|
||||||
|
raw := HitlRejectToolResult("nmap", "too risky")
|
||||||
|
if strings.HasPrefix(raw, "{") {
|
||||||
|
t.Fatalf("expected legacy text, got %q", raw)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(raw, "[HITL Reject]") {
|
||||||
|
t.Fatalf("expected [HITL Reject] prefix, got %q", raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeHitlExemptMetaTools_includesToolSearch(t *testing.T) {
|
||||||
|
merged := MergeHitlExemptMetaTools([]string{"read_file"})
|
||||||
|
found := false
|
||||||
|
for _, name := range merged {
|
||||||
|
if IsToolSearchTool(name) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("tool_search missing from %v", merged)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/agents"
|
"cyberstrike-ai/internal/agents"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
||||||
@@ -122,7 +123,9 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
|
|||||||
|
|
||||||
## 表达
|
## 表达
|
||||||
|
|
||||||
在调用工具或给出计划变更前,用 2~5 句中文说明当前决策依据与期望证据形态;最终对用户交付结构化结论(发现摘要、证据、风险、下一步)。`
|
在调用工具或给出计划变更前,用 2~5 句中文说明当前决策依据与期望证据形态;最终对用户交付结构化结论(发现摘要、证据、风险、下一步)。
|
||||||
|
|
||||||
|
` + projectprompt.ShellExecExecuteGuidanceSection()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。
|
// DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
"cyberstrike-ai/internal/reasoning"
|
"cyberstrike-ai/internal/reasoning"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -120,7 +121,7 @@ func RunDeepAgent(
|
|||||||
mcpIDs = append(mcpIDs, id)
|
mcpIDs = append(mcpIDs, id)
|
||||||
mcpIDsMu.Unlock()
|
mcpIDsMu.Unlock()
|
||||||
}
|
}
|
||||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
|
||||||
|
|
||||||
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
||||||
snapshotMCPIDs := func() []string {
|
snapshotMCPIDs := func() []string {
|
||||||
@@ -223,7 +224,7 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||||
}
|
}
|
||||||
@@ -253,10 +254,11 @@ func RunDeepAgent(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||||
Name: id,
|
Name: id,
|
||||||
Description: desc,
|
Description: desc,
|
||||||
Instruction: subInstrFinal,
|
Instruction: subInstrFinal,
|
||||||
Model: subModel,
|
GenModelInput: literalInstructionGenModelInput,
|
||||||
|
Model: subModel,
|
||||||
ToolsConfig: adk.ToolsConfig{
|
ToolsConfig: adk.ToolsConfig{
|
||||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||||
Tools: subToolsForCfg,
|
Tools: subToolsForCfg,
|
||||||
@@ -358,19 +360,28 @@ func RunDeepAgent(
|
|||||||
if einoLoc != nil && einoFSTools {
|
if einoLoc != nil && einoFSTools {
|
||||||
deepBackend = einoLoc
|
deepBackend = einoLoc
|
||||||
deepShell = &einoStreamingShellWrap{
|
deepShell = &einoStreamingShellWrap{
|
||||||
inner: einoLoc,
|
inner: security.NewEinoStreamingShell(),
|
||||||
invokeNotify: toolInvokeNotify,
|
invokeNotify: toolInvokeNotify,
|
||||||
einoAgentName: orchestratorName,
|
einoAgentName: orchestratorName,
|
||||||
outputChunk: nil,
|
outputChunk: nil,
|
||||||
recordMonitor: einoExecMonitor,
|
beginMonitor: einoExecBegin,
|
||||||
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
finishMonitor: einoExecFinish,
|
||||||
|
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||||
|
shellNoOutputTimeoutSec: agentShellNoOutputTimeoutSeconds(appCfg),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||||
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
|
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
|
||||||
taskEnrichExtra := systemPromptExtra
|
var taskBlackboardSupplement string
|
||||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil {
|
if appCfg.Project.Enabled && db != nil {
|
||||||
|
if pid := strings.TrimSpace(projectID); pid != "" {
|
||||||
|
if block, err := project.BuildFactIndexBlock(db, pid, appCfg.Project); err == nil {
|
||||||
|
taskBlackboardSupplement = strings.TrimSpace(block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunesEffective(), taskBlackboardSupplement); mw != nil {
|
||||||
deepHandlers = append(deepHandlers, mw)
|
deepHandlers = append(deepHandlers, mw)
|
||||||
}
|
}
|
||||||
if len(mainOrchestratorPre) > 0 {
|
if len(mainOrchestratorPre) > 0 {
|
||||||
@@ -421,6 +432,22 @@ func RunDeepAgent(
|
|||||||
var da adk.Agent
|
var da adk.Agent
|
||||||
switch orchMode {
|
switch orchMode {
|
||||||
case "plan_execute":
|
case "plan_execute":
|
||||||
|
plannerModelCfg := &einoopenai.ChatModelConfig{
|
||||||
|
APIKey: appCfg.OpenAI.APIKey,
|
||||||
|
BaseURL: strings.TrimSuffix(appCfg.OpenAI.BaseURL, "/"),
|
||||||
|
Model: appCfg.OpenAI.Model,
|
||||||
|
HTTPClient: httpClient,
|
||||||
|
}
|
||||||
|
reasoning.ApplyPlanExecutePlannerModelConfig(plannerModelCfg, &appCfg.OpenAI)
|
||||||
|
peMainModel, perr := einoopenai.NewChatModel(ctx, plannerModelCfg)
|
||||||
|
if perr != nil {
|
||||||
|
return nil, fmt.Errorf("plan_execute 规划模型: %w", perr)
|
||||||
|
}
|
||||||
|
if logger != nil {
|
||||||
|
logger.Info("plan_execute: planner/replanner 使用无 reasoning 的独立 ChatModel(ToolChoiceForced 兼容)",
|
||||||
|
zap.String("model", appCfg.OpenAI.Model),
|
||||||
|
)
|
||||||
|
}
|
||||||
execModel, perr := einoopenai.NewChatModel(ctx, baseModelCfg)
|
execModel, perr := einoopenai.NewChatModel(ctx, baseModelCfg)
|
||||||
if perr != nil {
|
if perr != nil {
|
||||||
return nil, fmt.Errorf("plan_execute 执行器模型: %w", perr)
|
return nil, fmt.Errorf("plan_execute 执行器模型: %w", perr)
|
||||||
@@ -428,13 +455,13 @@ func RunDeepAgent(
|
|||||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||||
var peFsMw adk.ChatModelAgentMiddleware
|
var peFsMw adk.ChatModelAgentMiddleware
|
||||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
peRoot, perr := NewPlanExecuteRoot(ctx, &PlanExecuteRootArgs{
|
peRoot, perr := NewPlanExecuteRoot(ctx, &PlanExecuteRootArgs{
|
||||||
MainToolCallingModel: mainModel,
|
MainToolCallingModel: peMainModel,
|
||||||
ExecModel: execModel,
|
ExecModel: execModel,
|
||||||
OrchInstruction: orchInstruction,
|
OrchInstruction: orchInstruction,
|
||||||
ToolsCfg: mainToolsCfg,
|
ToolsCfg: mainToolsCfg,
|
||||||
@@ -469,6 +496,7 @@ func RunDeepAgent(
|
|||||||
Name: orchestratorName,
|
Name: orchestratorName,
|
||||||
Description: orchDescription,
|
Description: orchDescription,
|
||||||
Instruction: supInstr,
|
Instruction: supInstr,
|
||||||
|
GenModelInput: literalInstructionGenModelInput,
|
||||||
Model: mainModel,
|
Model: mainModel,
|
||||||
ToolsConfig: mainToolsCfg,
|
ToolsConfig: mainToolsCfg,
|
||||||
MaxIterations: deepMaxIter,
|
MaxIterations: deepMaxIter,
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func shellToolsPresent(toolNames []string) bool {
|
||||||
|
for _, n := range toolNames {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(n)) {
|
||||||
|
case "exec", "execute":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectShellToolGuidance 在系统提示末尾追加 exec/execute 分工(仅当工具列表含 exec 或 execute)。
|
||||||
|
func injectShellToolGuidance(instruction string, toolNames []string) string {
|
||||||
|
if !shellToolsPresent(toolNames) {
|
||||||
|
return instruction
|
||||||
|
}
|
||||||
|
block := strings.TrimSpace(projectprompt.ShellExecExecuteGuidanceSection())
|
||||||
|
if block == "" {
|
||||||
|
return instruction
|
||||||
|
}
|
||||||
|
s := strings.TrimSpace(instruction)
|
||||||
|
if s == "" {
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
return s + "\n\n" + block
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInjectShellToolGuidance(t *testing.T) {
|
||||||
|
got := injectShellToolGuidance("base", []string{"nmap"})
|
||||||
|
if got != "base" {
|
||||||
|
t.Fatalf("expected unchanged, got %q", got)
|
||||||
|
}
|
||||||
|
got = injectShellToolGuidance("base", []string{"exec", "nmap"})
|
||||||
|
if !strings.Contains(got, "exec/execute") || !strings.Contains(got, "base") {
|
||||||
|
t.Fatalf("expected shell guidance appended, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package multiagent
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
@@ -11,7 +12,7 @@ import (
|
|||||||
"github.com/cloudwego/eino/components/tool"
|
"github.com/cloudwego/eino/components/tool"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultSubAgentUserContextMaxRunes = 2000
|
const userContextSupplementHeader = "\n\n## 用户历史输入(原文,子代理必读)\n"
|
||||||
|
|
||||||
// taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator
|
// taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator
|
||||||
// and appends the user's original conversation messages to the task description.
|
// and appends the user's original conversation messages to the task description.
|
||||||
@@ -30,13 +31,14 @@ type taskContextEnrichMiddleware struct {
|
|||||||
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
||||||
// descriptions with user conversation context. Returns nil if disabled
|
// descriptions with user conversation context. Returns nil if disabled
|
||||||
// (maxRunes < 0) or no user messages exist.
|
// (maxRunes < 0) or no user messages exist.
|
||||||
|
// projectBlackboard 仅传项目黑板索引块(BuildFactIndexBlock);勿传完整 systemPromptExtra。
|
||||||
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
|
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
|
||||||
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
||||||
if bb := strings.TrimSpace(projectBlackboard); bb != "" {
|
if bb := strings.TrimSpace(projectBlackboard); bb != "" {
|
||||||
if supplement != "" {
|
if supplement != "" {
|
||||||
supplement += "\n\n## 项目黑板索引\n" + bb
|
supplement += "\n\n" + bb
|
||||||
} else {
|
} else {
|
||||||
supplement = "\n\n## 项目黑板索引\n" + bb
|
supplement = "\n\n" + bb
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if supplement == "" {
|
if supplement == "" {
|
||||||
@@ -86,9 +88,6 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
|
|||||||
if maxRunes < 0 {
|
if maxRunes < 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if maxRunes == 0 {
|
|
||||||
maxRunes = defaultSubAgentUserContextMaxRunes
|
|
||||||
}
|
|
||||||
|
|
||||||
var userMsgs []string
|
var userMsgs []string
|
||||||
for _, h := range history {
|
for _, h := range history {
|
||||||
@@ -107,12 +106,16 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
joined := strings.Join(userMsgs, "\n---\n")
|
lines := make([]string, 0, len(userMsgs))
|
||||||
if len([]rune(joined)) > maxRunes {
|
for i, msg := range userMsgs {
|
||||||
|
lines = append(lines, fmt.Sprintf("[第%d轮] %s", i+1, msg))
|
||||||
|
}
|
||||||
|
joined := strings.Join(lines, "\n")
|
||||||
|
if maxRunes > 0 && len([]rune(joined)) > maxRunes {
|
||||||
joined = truncateKeepFirstLast(userMsgs, maxRunes)
|
joined = truncateKeepFirstLast(userMsgs, maxRunes)
|
||||||
}
|
}
|
||||||
|
|
||||||
return "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" + joined
|
return userContextSupplementHeader + joined
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncateKeepFirstLast keeps the first and last user messages, giving each
|
// truncateKeepFirstLast keeps the first and last user messages, giving each
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func TestBuildUserContextSupplement_DisabledByNegative(t *testing.T) {
|
|||||||
func TestBuildUserContextSupplement_CustomMaxRunes(t *testing.T) {
|
func TestBuildUserContextSupplement_CustomMaxRunes(t *testing.T) {
|
||||||
msg := strings.Repeat("A", 200)
|
msg := strings.Repeat("A", 200)
|
||||||
result := buildUserContextSupplement(msg, nil, 50)
|
result := buildUserContextSupplement(msg, nil, 50)
|
||||||
header := "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n"
|
header := userContextSupplementHeader
|
||||||
body := strings.TrimPrefix(result, header)
|
body := strings.TrimPrefix(result, header)
|
||||||
if len([]rune(body)) > 50 {
|
if len([]rune(body)) > 50 {
|
||||||
t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body)))
|
t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body)))
|
||||||
@@ -89,7 +89,7 @@ func TestBuildUserContextSupplement_TruncateKeepsFirstAndLast(t *testing.T) {
|
|||||||
history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)})
|
history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)})
|
||||||
}
|
}
|
||||||
last := "最后一条指令"
|
last := "最后一条指令"
|
||||||
result := buildUserContextSupplement(last, history, 0)
|
result := buildUserContextSupplement(last, history, 800)
|
||||||
if !strings.Contains(result, "http://target.com") {
|
if !strings.Contains(result, "http://target.com") {
|
||||||
t.Error("first message (target URL) should survive truncation")
|
t.Error("first message (target URL) should survive truncation")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -806,10 +806,12 @@ func isClaudeProvider(cfg *config.OpenAIConfig) bool {
|
|||||||
// Eino HTTP Client Bridge
|
// Eino HTTP Client Bridge
|
||||||
// ============================================================
|
// ============================================================
|
||||||
|
|
||||||
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含两层 transport 包装:
|
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含多层 transport 包装:
|
||||||
// 1. 当 cfg.Provider 为 claude 时,最内层套 claudeRoundTripper,把 OpenAI /chat/completions 透明
|
// 1. 当 cfg.Provider 为 claude 时,套 claudeRoundTripper,把 OpenAI /chat/completions 透明
|
||||||
// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。
|
// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。
|
||||||
// 2. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
|
// 2. reasoningToolChoiceCompatRoundTripper:tool_choice=required/object 时剥离 thinking 字段,避免
|
||||||
|
// plan_execute replanner 等强制工具调用与推理模式冲突(部分网关返回 400)。
|
||||||
|
// 3. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
|
||||||
// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai
|
// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai
|
||||||
// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。
|
// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。
|
||||||
//
|
//
|
||||||
@@ -825,6 +827,7 @@ func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client
|
|||||||
if transport == nil {
|
if transport == nil {
|
||||||
transport = http.DefaultTransport
|
transport = http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
transport = &reasoningToolChoiceCompatRoundTripper{base: transport}
|
||||||
if isClaudeProvider(cfg) {
|
if isClaudeProvider(cfg) {
|
||||||
transport = &claudeRoundTripper{
|
transport = &claudeRoundTripper{
|
||||||
base: transport,
|
base: transport,
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bytedance/sonic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// reasoningPayloadKeys are OpenAI-compatible root fields that enable "thinking" /
|
||||||
|
// extended-reasoning modes on gateways such as DashScope/Qwen and MiniMax.
|
||||||
|
var reasoningPayloadKeys = []string{
|
||||||
|
"thinking",
|
||||||
|
"reasoning_effort",
|
||||||
|
"output_config",
|
||||||
|
"reasoning",
|
||||||
|
}
|
||||||
|
|
||||||
|
// StripReasoningFromChatCompletionBody removes thinking / reasoning fields from a
|
||||||
|
// chat-completions JSON body.
|
||||||
|
func StripReasoningFromChatCompletionBody(rawBody []byte) ([]byte, error) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
if !stripReasoningFields(payload) {
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
out, err := sonic.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return rawBody, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StripReasoningIfForcedToolChoice removes thinking / reasoning fields when the
|
||||||
|
// request sets tool_choice to "required" or an object. Several providers reject
|
||||||
|
// that combination (e.g. DashScope: "tool_choice does not support being set to
|
||||||
|
// required or object in thinking mode").
|
||||||
|
func StripReasoningIfForcedToolChoice(rawBody []byte) ([]byte, error) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := sonic.Unmarshal(rawBody, &payload); err != nil {
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
if !forcedToolChoiceIncompatibleWithThinking(payload) {
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
if !stripReasoningFields(payload) {
|
||||||
|
return rawBody, nil
|
||||||
|
}
|
||||||
|
out, err := sonic.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return rawBody, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripReasoningFields(payload map[string]any) bool {
|
||||||
|
changed := false
|
||||||
|
for _, key := range reasoningPayloadKeys {
|
||||||
|
if _, ok := payload[key]; ok {
|
||||||
|
delete(payload, key)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed
|
||||||
|
}
|
||||||
|
|
||||||
|
func forcedToolChoiceIncompatibleWithThinking(payload map[string]any) bool {
|
||||||
|
tc, ok := payload["tool_choice"]
|
||||||
|
if !ok || tc == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch v := tc.(type) {
|
||||||
|
case string:
|
||||||
|
return v == "required"
|
||||||
|
case map[string]any:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStripReasoningFromChatCompletionBody(t *testing.T) {
|
||||||
|
in := []byte(`{"model":"deepseek-chat","messages":[],"thinking":{"type":"enabled"},"reasoning_effort":"high"}`)
|
||||||
|
out, err := StripReasoningFromChatCompletionBody(in)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
s := string(out)
|
||||||
|
if strings.Contains(s, "thinking") || strings.Contains(s, "reasoning_effort") {
|
||||||
|
t.Fatalf("expected reasoning fields stripped, got %s", s)
|
||||||
|
}
|
||||||
|
if !strings.Contains(s, `"model":"deepseek-chat"`) {
|
||||||
|
t.Fatalf("expected model preserved, got %s", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
plain := []byte(`{"model":"gpt-4o","messages":[]}`)
|
||||||
|
out2, err := StripReasoningFromChatCompletionBody(plain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if string(out2) != string(plain) {
|
||||||
|
t.Fatalf("expected unchanged payload, got %s", out2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripReasoningIfForcedToolChoice(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
strip bool
|
||||||
|
contain string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "required strips thinking",
|
||||||
|
in: `{"model":"minimax","messages":[],"thinking":{"type":"enabled"},"tool_choice":"required","tools":[]}`,
|
||||||
|
strip: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "object tool_choice strips thinking",
|
||||||
|
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"},"tool_choice":{"type":"function","function":{"name":"respond"}}}`,
|
||||||
|
strip: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auto keeps thinking",
|
||||||
|
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"},"tool_choice":"auto"}`,
|
||||||
|
strip: false,
|
||||||
|
contain: "thinking",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no tool_choice keeps thinking",
|
||||||
|
in: `{"model":"qwen","messages":[],"thinking":{"type":"enabled"}}`,
|
||||||
|
strip: false,
|
||||||
|
contain: "thinking",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
out, err := StripReasoningIfForcedToolChoice([]byte(tc.in))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
s := string(out)
|
||||||
|
hasThinking := strings.Contains(s, "thinking")
|
||||||
|
if tc.strip && hasThinking {
|
||||||
|
t.Fatalf("expected thinking stripped, got %s", s)
|
||||||
|
}
|
||||||
|
if !tc.strip && tc.contain != "" && !strings.Contains(s, tc.contain) {
|
||||||
|
t.Fatalf("expected %q in %s", tc.contain, s)
|
||||||
|
}
|
||||||
|
if !tc.strip && string(out) != tc.in {
|
||||||
|
t.Fatalf("expected unchanged payload, got %s", s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReasoningToolChoiceCompatRoundTripper(t *testing.T) {
|
||||||
|
var gotBody string
|
||||||
|
rt := &reasoningToolChoiceCompatRoundTripper{
|
||||||
|
base: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
b, _ := io.ReadAll(req.Body)
|
||||||
|
gotBody = string(b)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"choices":[{"message":{"content":"ok"}}]}`)),
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "https://example.com/v1/chat/completions", strings.NewReader(
|
||||||
|
`{"model":"m","thinking":{"type":"enabled"},"tool_choice":"required","messages":[]}`,
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(gotBody, "thinking") {
|
||||||
|
t.Fatalf("expected thinking stripped in transit, got %s", gotBody)
|
||||||
|
}
|
||||||
|
if !strings.Contains(gotBody, `"tool_choice":"required"`) {
|
||||||
|
t.Fatalf("expected tool_choice preserved, got %s", gotBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// reasoningToolChoiceCompatRoundTripper strips thinking/reasoning fields from
|
||||||
|
// chat/completions requests that force tool_choice, which some gateways reject
|
||||||
|
// when thinking mode is enabled on the same request.
|
||||||
|
type reasoningToolChoiceCompatRoundTripper struct {
|
||||||
|
base http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *reasoningToolChoiceCompatRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if rt == nil || rt.base == nil || req == nil || req.Body == nil {
|
||||||
|
if rt != nil && rt.base != nil {
|
||||||
|
return rt.base.RoundTrip(req)
|
||||||
|
}
|
||||||
|
return http.DefaultTransport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
if req.Method != http.MethodPost || !strings.HasSuffix(req.URL.Path, "/chat/completions") {
|
||||||
|
return rt.base.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
_ = req.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
patched, perr := StripReasoningIfForcedToolChoice(body)
|
||||||
|
if perr != nil {
|
||||||
|
patched = body
|
||||||
|
}
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(patched))
|
||||||
|
req.ContentLength = int64(len(patched))
|
||||||
|
req.Header.Set("Content-Length", strconv.Itoa(len(patched)))
|
||||||
|
return rt.base.RoundTrip(req)
|
||||||
|
}
|
||||||
@@ -0,0 +1,170 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UserVerbatimSectionHeading 用户原文锚点可读标题(块内保留,供 Agent 阅读)。
|
||||||
|
UserVerbatimSectionHeading = "## 用户历史输入(原文保留,勿省略或改写)"
|
||||||
|
|
||||||
|
// UserVerbatimSectionStartMarker / EndMarker:HTML 注释边界,供程序化替换;对模型无指令语义。
|
||||||
|
UserVerbatimSectionStartMarker = "<!-- user-verbatim-start -->"
|
||||||
|
UserVerbatimSectionEndMarker = "<!-- user-verbatim-end -->"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtractUserContentsFromMessages 按时间顺序提取 user 角色消息的原文(跳过空白)。
|
||||||
|
func ExtractUserContentsFromMessages(msgs []database.Message) []string {
|
||||||
|
out := make([]string, 0, len(msgs))
|
||||||
|
for i := range msgs {
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(msgs[i].Role), "user") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(msgs[i].Content)
|
||||||
|
if content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, content)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildUserVerbatimAnchorBlockFromMessages 从 messages 表行构建用户原文锚点块。
|
||||||
|
// maxRunes: 0 = 不截断;>0 = 总 rune 上限(仍保留每一轮,仅对超长单条做尾部截断提示)。
|
||||||
|
func BuildUserVerbatimAnchorBlockFromMessages(msgs []database.Message, maxRunes int) string {
|
||||||
|
return BuildUserVerbatimAnchorBlock(ExtractUserContentsFromMessages(msgs), maxRunes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildUserVerbatimAnchorBlock 将各轮用户原文格式化为 system prompt 锚点块。
|
||||||
|
func BuildUserVerbatimAnchorBlock(userContents []string, maxRunes int) string {
|
||||||
|
if len(userContents) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lines := make([]string, 0, len(userContents))
|
||||||
|
for _, content := range userContents {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
if content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lines = append(lines, fmt.Sprintf("[第%d轮] %s", len(lines)+1, content))
|
||||||
|
}
|
||||||
|
if len(lines) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
body := strings.Join(lines, "\n")
|
||||||
|
if maxRunes > 0 {
|
||||||
|
body = capUserVerbatimBody(body, maxRunes)
|
||||||
|
}
|
||||||
|
return wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n" + body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func capUserVerbatimBody(body string, maxRunes int) string {
|
||||||
|
rs := []rune(body)
|
||||||
|
if len(rs) <= maxRunes {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
suffix := "\n\n...(用户原文锚点已达配置上限,更早轮次可能被截断;完整原文见 messages 表)..."
|
||||||
|
suffixRunes := []rune(suffix)
|
||||||
|
keep := maxRunes - len(suffixRunes)
|
||||||
|
if keep <= 0 {
|
||||||
|
return string(rs[:maxRunes])
|
||||||
|
}
|
||||||
|
return string(rs[:keep]) + suffix
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapUserVerbatimBlock(content string) string {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
if content == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return UserVerbatimSectionStartMarker + "\n" + content + "\n" + UserVerbatimSectionEndMarker + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceUserVerbatimAnchorSection 用 freshBlock 替换 content 中已有的用户原文锚点段。
|
||||||
|
func ReplaceUserVerbatimAnchorSection(content, freshBlock string) (string, bool) {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
freshBlock = strings.TrimSpace(freshBlock)
|
||||||
|
if freshBlock == "" {
|
||||||
|
return content, false
|
||||||
|
}
|
||||||
|
start, ok := userVerbatimSectionStart(content)
|
||||||
|
if !ok {
|
||||||
|
return content, false
|
||||||
|
}
|
||||||
|
end, ok := userVerbatimSectionEnd(content, start)
|
||||||
|
if !ok {
|
||||||
|
return content, false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(content[:start] + freshBlock + content[end:]), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func userVerbatimSectionStart(content string) (int, bool) {
|
||||||
|
idx := strings.Index(content, UserVerbatimSectionStartMarker)
|
||||||
|
if idx < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return idx, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func userVerbatimSectionEnd(content string, start int) (int, bool) {
|
||||||
|
if start < 0 || start >= len(content) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
tail := content[start:]
|
||||||
|
idx := strings.LastIndex(tail, UserVerbatimSectionEndMarker)
|
||||||
|
if idx < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return start + idx + len(UserVerbatimSectionEndMarker), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshUserVerbatimAnchorInMessages 在 summarization 等压缩后,用 freshBlock 刷新 system 中的用户原文锚点。
|
||||||
|
// 若尚无锚点段,则追加到首条 system 消息;若无 system 消息则在开头插入一条。
|
||||||
|
func RefreshUserVerbatimAnchorInMessages(msgs []adk.Message, freshBlock string) []adk.Message {
|
||||||
|
freshBlock = strings.TrimSpace(freshBlock)
|
||||||
|
if freshBlock == "" || len(msgs) == 0 {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]adk.Message, len(msgs))
|
||||||
|
changed := false
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if msg == nil || msg.Role != schema.System {
|
||||||
|
out[i] = msg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newContent, ok := ReplaceUserVerbatimAnchorSection(msg.Content, freshBlock)
|
||||||
|
if !ok {
|
||||||
|
out[i] = msg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cloned := *msg
|
||||||
|
cloned.Content = newContent
|
||||||
|
out[i] = &cloned
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if changed {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if msg == nil || msg.Role != schema.System {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cloned := *msg
|
||||||
|
cloned.Content = AppendSystemPromptBlock(cloned.Content, freshBlock)
|
||||||
|
out[i] = &cloned
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := make([]adk.Message, 0, len(msgs)+1)
|
||||||
|
prefix = append(prefix, schema.SystemMessage(freshBlock))
|
||||||
|
return append(prefix, msgs...)
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildUserVerbatimAnchorBlock_MultiTurn(t *testing.T) {
|
||||||
|
msgs := []database.Message{
|
||||||
|
{Role: "user", Content: "目标 https://a.com 仅测 /api"},
|
||||||
|
{Role: "assistant", Content: "好的"},
|
||||||
|
{Role: "user", Content: "用 admin:test 登录"},
|
||||||
|
}
|
||||||
|
block := BuildUserVerbatimAnchorBlockFromMessages(msgs, 0)
|
||||||
|
if block == "" {
|
||||||
|
t.Fatal("expected non-empty block")
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, UserVerbatimSectionStartMarker) {
|
||||||
|
t.Error("missing start marker")
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "[第1轮]") || !strings.Contains(block, "https://a.com") {
|
||||||
|
t.Error("missing first user turn")
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "[第2轮]") || !strings.Contains(block, "admin:test") {
|
||||||
|
t.Error("missing second user turn")
|
||||||
|
}
|
||||||
|
if strings.Contains(block, "好的") {
|
||||||
|
t.Error("assistant content should not appear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceUserVerbatimAnchorSection(t *testing.T) {
|
||||||
|
old := "prefix\n\n" + wrapUserVerbatimBlock("## old\n\n[第1轮] a") + "\nsuffix"
|
||||||
|
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] b\n[第2轮] c")
|
||||||
|
out, ok := ReplaceUserVerbatimAnchorSection(old, newBlock)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected replace ok")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "[第2轮] c") {
|
||||||
|
t.Errorf("expected new block, got %q", out)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(out), "prefix") {
|
||||||
|
t.Error("prefix should remain")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "suffix") {
|
||||||
|
t.Error("suffix should remain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshUserVerbatimAnchorInMessages_ReplaceExisting(t *testing.T) {
|
||||||
|
oldBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] old")
|
||||||
|
msgs := []adk.Message{
|
||||||
|
schema.SystemMessage("instr\n\n" + oldBlock),
|
||||||
|
schema.UserMessage("hi"),
|
||||||
|
}
|
||||||
|
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] new")
|
||||||
|
out := RefreshUserVerbatimAnchorInMessages(msgs, newBlock)
|
||||||
|
if len(out) != 2 {
|
||||||
|
t.Fatalf("message count: got %d", len(out))
|
||||||
|
}
|
||||||
|
if !strings.Contains(out[0].Content, "[第1轮] new") {
|
||||||
|
t.Errorf("system content: %q", out[0].Content)
|
||||||
|
}
|
||||||
|
if strings.Contains(out[0].Content, "[第1轮] old") {
|
||||||
|
t.Error("old anchor should be replaced")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshUserVerbatimAnchorInMessages_InsertWhenMissing(t *testing.T) {
|
||||||
|
msgs := []adk.Message{
|
||||||
|
schema.SystemMessage("base instruction"),
|
||||||
|
schema.UserMessage("hi"),
|
||||||
|
}
|
||||||
|
block := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] anchor")
|
||||||
|
out := RefreshUserVerbatimAnchorInMessages(msgs, block)
|
||||||
|
if !strings.Contains(out[0].Content, "[第1轮] anchor") {
|
||||||
|
t.Errorf("expected appended anchor, got %q", out[0].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUserVerbatimAnchorBlock_MaxRunes(t *testing.T) {
|
||||||
|
long := strings.Repeat("字", 200)
|
||||||
|
block := BuildUserVerbatimAnchorBlock([]string{long}, 50)
|
||||||
|
body := block
|
||||||
|
if idx := strings.Index(body, UserVerbatimSectionStartMarker); idx >= 0 {
|
||||||
|
body = strings.TrimPrefix(body[idx+len(UserVerbatimSectionStartMarker):], "\n")
|
||||||
|
}
|
||||||
|
if len([]rune(body)) > 120 {
|
||||||
|
t.Errorf("expected capped body, got %d runes", len([]rune(body)))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func sanitizeWorkspacePathSegment(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return "default"
|
||||||
|
}
|
||||||
|
s = strings.ReplaceAll(s, string(filepath.Separator), "-")
|
||||||
|
s = strings.ReplaceAll(s, "/", "-")
|
||||||
|
s = strings.ReplaceAll(s, "\\", "-")
|
||||||
|
s = strings.ReplaceAll(s, "..", "__")
|
||||||
|
if len(s) > 180 {
|
||||||
|
s = s[:180]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// WorkspaceRootDir returns the relative workspace root for downloads and local analysis.
|
||||||
|
// Project-bound sessions share projects/<id>/; otherwise conversations/<id>/.
|
||||||
|
func WorkspaceRootDir(configuredBase, projectID, conversationID string) string {
|
||||||
|
base := strings.TrimSpace(configuredBase)
|
||||||
|
if base == "" {
|
||||||
|
base = filepath.Join("tmp", "workspace")
|
||||||
|
}
|
||||||
|
if pid := strings.TrimSpace(projectID); pid != "" {
|
||||||
|
return filepath.Join(base, "projects", sanitizeWorkspacePathSegment(pid))
|
||||||
|
}
|
||||||
|
conv := strings.TrimSpace(conversationID)
|
||||||
|
if conv == "" {
|
||||||
|
conv = "default"
|
||||||
|
}
|
||||||
|
return filepath.Join(base, "conversations", sanitizeWorkspacePathSegment(conv))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureWorkspace creates the workspace directory and returns its absolute path.
|
||||||
|
func EnsureWorkspace(root string) (string, error) {
|
||||||
|
abs, err := filepath.Abs(strings.TrimSpace(root))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("workspace abs: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(abs, 0o755); err != nil {
|
||||||
|
return "", fmt.Errorf("workspace mkdir: %w", err)
|
||||||
|
}
|
||||||
|
return abs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildWorkspaceBlock instructs the agent to use the session workspace instead of /tmp.
|
||||||
|
func BuildWorkspaceBlock(absPath string) string {
|
||||||
|
absPath = strings.TrimSpace(absPath)
|
||||||
|
if absPath == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(`## 会话工作目录(下载与本地分析)
|
||||||
|
|
||||||
|
**必须使用以下目录**保存 curl/wget 下载的文件、临时 HTML/JS,以及 read_file/glob/grep 的检索范围:
|
||||||
|
`+"`%s`"+`
|
||||||
|
|
||||||
|
- **禁止**使用系统 `+"`/tmp`"+` 或其它全局临时目录(多项目/多会话会互窜遗留文件)。
|
||||||
|
- 下载示例:`+"`curl -o '%s/page.html' 'https://target/'`"+`;exec 时可将 `+"`workdir`"+` 设为该目录。
|
||||||
|
- 读取前用 glob/grep/read_file **限定在该目录**下搜索,勿在 `+"`/tmp`"+` 盲目检索。`, absPath, absPath)
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWorkspaceRootDirProjectScoped(t *testing.T) {
|
||||||
|
got := WorkspaceRootDir("", "proj-1", "conv-1")
|
||||||
|
want := filepath.Join("tmp", "workspace", "projects", "proj-1")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkspaceRootDirConversationScoped(t *testing.T) {
|
||||||
|
got := WorkspaceRootDir("/data/ws", "", "conv-abc")
|
||||||
|
want := filepath.Join("/data/ws", "conversations", "conv-abc")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureWorkspaceCreatesDir(t *testing.T) {
|
||||||
|
root := filepath.Join(t.TempDir(), "nested", "workspace")
|
||||||
|
abs, err := EnsureWorkspace(root)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EnsureWorkspace: %v", err)
|
||||||
|
}
|
||||||
|
st, err := os.Stat(abs)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stat: %v", err)
|
||||||
|
}
|
||||||
|
if !st.IsDir() {
|
||||||
|
t.Fatal("expected directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildWorkspaceBlockMentionsPath(t *testing.T) {
|
||||||
|
block := BuildWorkspaceBlock("/opt/csai/tmp/workspace/projects/p1")
|
||||||
|
if block == "" {
|
||||||
|
t.Fatal("expected non-empty block")
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "/opt/csai/tmp/workspace/projects/p1") {
|
||||||
|
t.Fatalf("block missing path: %s", block)
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "/tmp") {
|
||||||
|
t.Fatalf("block should warn about /tmp: %s", block)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package projectprompt
|
||||||
|
|
||||||
|
// ShellExecExecuteGuidanceSection 供单代理/多代理系统提示追加:exec 与 execute 分工(尽量短)。
|
||||||
|
func ShellExecExecuteGuidanceSection() string {
|
||||||
|
return `Shell(exec/execute):有专用 MCP 工具时优先专用工具;系统命令(管道、workdir、后台 &)用 exec;skills/ 内脚本(配合 read_file、skill)用 execute;多步扫描分拆调用,禁止一条 shell 串多个扫描器。下载/临时文件须写入系统提示中的「会话工作目录」,禁止用 /tmp。`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShellExecExecuteGuidanceReconSuffix 侦察子代理可选追加(一行)。
|
||||||
|
func ShellExecExecuteGuidanceReconSuffix() string {
|
||||||
|
return `枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。`
|
||||||
|
}
|
||||||
@@ -26,6 +26,35 @@ const (
|
|||||||
wireOutputConfig
|
wireOutputConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ApplyPlanExecutePlannerModelConfig configures the plan_execute planner/replanner
|
||||||
|
// ChatModel. Those Eino agents call WithToolChoice(Forced); several gateways reject
|
||||||
|
// thinking / reasoning fields on the same request (tool_choice required/object).
|
||||||
|
// Executor should keep the normal ApplyToEinoChatModelConfig path.
|
||||||
|
func ApplyPlanExecutePlannerModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig) {
|
||||||
|
if cfg == nil || oa == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offOA := *oa
|
||||||
|
offReasoning := oa.Reasoning
|
||||||
|
offReasoning.Mode = "off"
|
||||||
|
offOA.Reasoning = offReasoning
|
||||||
|
ApplyToEinoChatModelConfig(cfg, &offOA, nil)
|
||||||
|
clearReasoningFromChatModelConfig(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearReasoningFromChatModelConfig(cfg *einoopenai.ChatModelConfig) {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.ReasoningEffort = ""
|
||||||
|
if cfg.ExtraFields != nil {
|
||||||
|
for _, key := range []string{"thinking", "reasoning_effort", "output_config", "reasoning"} {
|
||||||
|
delete(cfg.ExtraFields, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
applyThinkingDisabled(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
// ApplyToEinoChatModelConfig merges reasoning-related options into cfg.
|
// ApplyToEinoChatModelConfig merges reasoning-related options into cfg.
|
||||||
// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set.
|
// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set.
|
||||||
func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) {
|
func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) {
|
||||||
|
|||||||
@@ -49,6 +49,30 @@ func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyPlanExecutePlannerModelConfig_stripsReasoningWhenGlobalOn(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
BaseURL: "https://antchat.example.com/v1",
|
||||||
|
Model: "minimax-m3",
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Profile: "openai_compat",
|
||||||
|
Mode: "on",
|
||||||
|
Effort: "high",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyPlanExecutePlannerModelConfig(cfg, oa)
|
||||||
|
if cfg.ReasoningEffort != "" {
|
||||||
|
t.Fatalf("expected ReasoningEffort cleared, got %q", cfg.ReasoningEffort)
|
||||||
|
}
|
||||||
|
th, ok := cfg.ExtraFields["thinking"].(map[string]any)
|
||||||
|
if !ok || th["type"] != "disabled" {
|
||||||
|
t.Fatalf("expected thinking disabled, got %#v", cfg.ExtraFields)
|
||||||
|
}
|
||||||
|
if _, ok := cfg.ExtraFields["reasoning_effort"]; ok {
|
||||||
|
t.Fatalf("expected reasoning_effort stripped, got %#v", cfg.ExtraFields)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyReasoningOff_disablesThinking(t *testing.T) {
|
func TestApplyReasoningOff_disablesThinking(t *testing.T) {
|
||||||
cfg := &einoopenai.ChatModelConfig{}
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
oa := &config.OpenAIConfig{
|
oa := &config.OpenAIConfig{
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FormatCommandFailureResult 与 exec 工具 ToolResult 文案一致(不含 ToolErrorPrefix)。
|
||||||
|
func FormatCommandFailureResult(exitCode int, output string) string {
|
||||||
|
output = strings.TrimSpace(output)
|
||||||
|
errMsg := fmt.Sprintf("exit status %d", exitCode)
|
||||||
|
if output == "" {
|
||||||
|
return fmt.Sprintf("命令执行失败: %s", errMsg)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(output, "命令执行失败:") {
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("命令执行失败: %s\n输出: %s", errMsg, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatCommandFailureFromErr 根据 exec/execute 返回的 error 生成统一失败文案(IsError 正文)。
|
||||||
|
func FormatCommandFailureFromErr(err error, output string) string {
|
||||||
|
if err == nil {
|
||||||
|
return strings.TrimSpace(output)
|
||||||
|
}
|
||||||
|
var exitError *exec.ExitError
|
||||||
|
if errors.As(err, &exitError) {
|
||||||
|
return FormatCommandFailureResult(exitError.ExitCode(), output)
|
||||||
|
}
|
||||||
|
output = strings.TrimSpace(output)
|
||||||
|
if output == "" {
|
||||||
|
return fmt.Sprintf("命令执行失败: %v", err)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(output, "命令执行失败:") {
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("命令执行失败: %v\n输出: %s", err, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteFailureStatusLine 流式 execute 结束时追加的单行状态(输出正文已在流中推送过)。
|
||||||
|
func ExecuteFailureStatusLine(exitCode int) string {
|
||||||
|
return fmt.Sprintf("\n命令执行失败: exit status %d", exitCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCommandFailureResult 判断工具结果正文是否表示命令非零退出(用于 execute / exec 对齐 isError)。
|
||||||
|
func IsCommandFailureResult(content string) bool {
|
||||||
|
return strings.Contains(content, "命令执行失败:")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLegacyShellExitNoise 过滤旧版 shell 流中冗余的 exit code 行。
|
||||||
|
func IsLegacyShellExitNoise(s string) bool {
|
||||||
|
trimmed := strings.TrimSpace(s)
|
||||||
|
return strings.HasPrefix(trimmed, "command exited with non-zero code ")
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatCommandFailureResult(t *testing.T) {
|
||||||
|
got := FormatCommandFailureResult(1, "sudo: password required")
|
||||||
|
want := "命令执行失败: exit status 1\n输出: sudo: password required"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if FormatCommandFailureResult(2, "") != "命令执行失败: exit status 2" {
|
||||||
|
t.Fatal("empty output format")
|
||||||
|
}
|
||||||
|
if FormatCommandFailureResult(1, "命令执行失败: exit status 1") != "命令执行失败: exit status 1" {
|
||||||
|
t.Fatal("should not double-wrap")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsCommandFailureResult(t *testing.T) {
|
||||||
|
if !IsCommandFailureResult("sudo: err\n命令执行失败: exit status 1") {
|
||||||
|
t.Fatal("expected true")
|
||||||
|
}
|
||||||
|
if IsCommandFailureResult("sudo: err only") {
|
||||||
|
t.Fatal("expected false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatCommandFailureFromErr(t *testing.T) {
|
||||||
|
cmd := exec.Command("sh", "-c", "exit 42")
|
||||||
|
err := cmd.Run()
|
||||||
|
got := FormatCommandFailureFromErr(err, "oops")
|
||||||
|
if got != "命令执行失败: exit status 42\n输出: oops" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
timeoutErr := errors.New("shell inactivity timeout (300s)")
|
||||||
|
got2 := FormatCommandFailureFromErr(timeoutErr, "already timed out")
|
||||||
|
if !strings.Contains(got2, "shell inactivity timeout") || !strings.Contains(got2, "already timed out") {
|
||||||
|
t.Fatalf("got %q", got2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsLegacyShellExitNoise(t *testing.T) {
|
||||||
|
if !IsLegacyShellExitNoise("command exited with non-zero code 1\n") {
|
||||||
|
t.Fatal("expected legacy noise")
|
||||||
|
}
|
||||||
|
if IsLegacyShellExitNoise("sudo: failed") {
|
||||||
|
t.Fatal("unexpected noise")
|
||||||
|
}
|
||||||
|
}
|
||||||
+139
-110
@@ -32,10 +32,11 @@ var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
|
|||||||
|
|
||||||
// Executor 安全工具执行器
|
// Executor 安全工具执行器
|
||||||
type Executor struct {
|
type Executor struct {
|
||||||
config *config.SecurityConfig
|
config *config.SecurityConfig
|
||||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||||
mcpServer *mcp.Server
|
mcpServer *mcp.Server
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300;-1=关闭(见 SetShellNoOutputTimeoutSeconds)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewExecutor 创建新的执行器
|
// NewExecutor 创建新的执行器
|
||||||
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
|
|||||||
return executor
|
return executor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetShellNoOutputTimeoutSeconds 配置 exec 工具无输出空闲终止(与 agent.shell_no_output_timeout_seconds 一致)。
|
||||||
|
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
|
||||||
|
e.shellNoOutputTimeoutSec = sec
|
||||||
|
}
|
||||||
|
|
||||||
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
||||||
func (e *Executor) buildToolIndex() {
|
func (e *Executor) buildToolIndex() {
|
||||||
e.toolIndex = make(map[string]*config.ToolConfig)
|
e.toolIndex = make(map[string]*config.ToolConfig)
|
||||||
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
// 执行命令
|
// 执行命令
|
||||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||||
applyDefaultTerminalEnv(cmd)
|
applyDefaultTerminalEnv(cmd)
|
||||||
|
attachNonInteractiveStdin(cmd)
|
||||||
_ = prepareShellCmdSession(cmd)
|
_ = prepareShellCmdSession(cmd)
|
||||||
|
|
||||||
e.logger.Info("执行安全工具",
|
e.logger.Info("执行安全工具",
|
||||||
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
var err error
|
var err error
|
||||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||||
zap.String("tool", toolName),
|
zap.String("tool", toolName),
|
||||||
@@ -155,9 +162,8 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outputBytes, err2 := cmd.CombinedOutput()
|
// 非流式:内存缓冲 + ctx 取消杀进程组;行为对齐原 CombinedOutput,避免双流管道 fan-in 死锁。
|
||||||
output = string(outputBytes)
|
output, err = combinedOutputCancellable(ctx, cmd)
|
||||||
err = err2
|
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||||
zap.String("tool", toolName),
|
zap.String("tool", toolName),
|
||||||
@@ -685,83 +691,21 @@ func (e *Executor) formatParamValue(param config.ParameterConfig, value interfac
|
|||||||
// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。
|
// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。
|
||||||
// command1 & command2 不算完全后台(command2 仍在前台执行)。
|
// command1 & command2 不算完全后台(command2 仍在前台执行)。
|
||||||
func IsBackgroundShellCommand(command string) bool {
|
func IsBackgroundShellCommand(command string) bool {
|
||||||
// 移除首尾空格
|
|
||||||
command = strings.TrimSpace(command)
|
command = strings.TrimSpace(command)
|
||||||
if command == "" {
|
if command == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
positions := findStandaloneAmpersandPositions(command)
|
||||||
// 检查命令中所有不在引号内的 & 符号
|
if len(positions) == 0 {
|
||||||
// 找到最后一个 & 符号,检查它是否在命令末尾
|
|
||||||
inSingleQuote := false
|
|
||||||
inDoubleQuote := false
|
|
||||||
escaped := false
|
|
||||||
lastAmpersandPos := -1
|
|
||||||
|
|
||||||
for i, r := range command {
|
|
||||||
if escaped {
|
|
||||||
escaped = false
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r == '\\' {
|
|
||||||
escaped = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r == '\'' && !inDoubleQuote {
|
|
||||||
inSingleQuote = !inSingleQuote
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r == '"' && !inSingleQuote {
|
|
||||||
inDoubleQuote = !inDoubleQuote
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r == '&' && !inSingleQuote && !inDoubleQuote {
|
|
||||||
// 检查 & 前后是否有空格或换行(确保是独立的 &,而不是变量名的一部分)
|
|
||||||
isStandalone := false
|
|
||||||
|
|
||||||
// 检查前面:空格、制表符、换行符,或者是命令开头
|
|
||||||
if i == 0 {
|
|
||||||
isStandalone = true
|
|
||||||
} else {
|
|
||||||
prev := command[i-1]
|
|
||||||
if prev == ' ' || prev == '\t' || prev == '\n' || prev == '\r' {
|
|
||||||
isStandalone = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查后面:空格、制表符、换行符,或者是命令末尾
|
|
||||||
if isStandalone {
|
|
||||||
if i == len(command)-1 {
|
|
||||||
// 在末尾,肯定是独立的 &
|
|
||||||
lastAmpersandPos = i
|
|
||||||
} else {
|
|
||||||
next := command[i+1]
|
|
||||||
if next == ' ' || next == '\t' || next == '\n' || next == '\r' {
|
|
||||||
// 后面有空格,是独立的 &
|
|
||||||
lastAmpersandPos = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果没有找到 & 符号,不是后台命令
|
|
||||||
if lastAmpersandPos == -1 {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
last := positions[len(positions)-1]
|
||||||
// 检查最后一个 & 后面是否还有非空内容
|
afterAmpersand := strings.TrimSpace(command[last+1:])
|
||||||
afterAmpersand := strings.TrimSpace(command[lastAmpersandPos+1:])
|
if afterAmpersand != "" {
|
||||||
if afterAmpersand == "" {
|
return false
|
||||||
// & 在末尾或后面只有空白字符,这是完全后台命令
|
|
||||||
// 检查 & 前面是否有内容
|
|
||||||
beforeAmpersand := strings.TrimSpace(command[:lastAmpersandPos])
|
|
||||||
return beforeAmpersand != ""
|
|
||||||
}
|
}
|
||||||
|
beforeAmpersand := strings.TrimSpace(command[:last])
|
||||||
// 如果 & 后面还有非空内容,说明是 command1 & command2 的情况
|
return beforeAmpersand != ""
|
||||||
// 这种情况下,command2会在前台执行,所以不算完全后台命令
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeSystemCommand 执行系统命令
|
// executeSystemCommand 执行系统命令
|
||||||
@@ -797,6 +741,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
zap.String("command", command),
|
zap.String("command", command),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
command = PrepareShellCommandForExecute(command)
|
||||||
|
|
||||||
// 获取shell类型(可选,默认为sh)
|
// 获取shell类型(可选,默认为sh)
|
||||||
shell := "sh"
|
shell := "sh"
|
||||||
if s, ok := args["shell"].(string); ok && s != "" {
|
if s, ok := args["shell"].(string); ok && s != "" {
|
||||||
@@ -820,8 +766,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
} else {
|
} else {
|
||||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd)
|
ConfigureShellCmdForAgentExecute(cmd)
|
||||||
_ = prepareShellCmdSession(cmd)
|
|
||||||
|
|
||||||
// 执行命令
|
// 执行命令
|
||||||
e.logger.Info("执行系统命令",
|
e.logger.Info("执行系统命令",
|
||||||
@@ -837,10 +782,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&")
|
commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&")
|
||||||
commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand)
|
commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand)
|
||||||
|
|
||||||
// 构建新命令:将用户命令置于独立重定向的后台作业,再 echo $pid。
|
// 构建新命令:后台作业重定向标准流后 echo $pid(与 RedirectBackgroundJobStdio 一致)。
|
||||||
// 若子进程与 echo 共享同一 stdout 管道,且长时间不向 stdout 写入换行,
|
pidCommand := RedirectBackgroundJobStdio(commandWithoutAmpersand+" &") + " pid=$!; echo $pid"
|
||||||
// bufio.ReadString('\n') 会永久阻塞(例如 beacon 持续写二进制/单行日志)。
|
|
||||||
pidCommand := fmt.Sprintf("%s </dev/null >/dev/null 2>&1 & pid=$!; echo $pid", commandWithoutAmpersand)
|
|
||||||
|
|
||||||
// 创建新命令来获取PID
|
// 创建新命令来获取PID
|
||||||
var pidCmd *exec.Cmd
|
var pidCmd *exec.Cmd
|
||||||
@@ -850,8 +793,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
} else {
|
} else {
|
||||||
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(pidCmd)
|
ConfigureShellCmdForAgentExecute(pidCmd)
|
||||||
_ = prepareShellCmdSession(pidCmd)
|
|
||||||
|
|
||||||
// 获取stdout管道
|
// 获取stdout管道
|
||||||
stdout, err := pidCmd.StdoutPipe()
|
stdout, err := pidCmd.StdoutPipe()
|
||||||
@@ -963,29 +905,25 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
var err error
|
var err error
|
||||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||||
if workDir != "" {
|
if workDir != "" {
|
||||||
cmd2.Dir = workDir
|
cmd2.Dir = workDir
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd2)
|
ConfigureShellCmdForAgentExecute(cmd2)
|
||||||
_ = prepareShellCmdSession(cmd2)
|
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
outputBytes, err2 := cmd.CombinedOutput()
|
output, err = combinedOutputCancellable(ctx, cmd)
|
||||||
output = string(outputBytes)
|
|
||||||
err = err2
|
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||||
if workDir != "" {
|
if workDir != "" {
|
||||||
cmd2.Dir = workDir
|
cmd2.Dir = workDir
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd2)
|
ConfigureShellCmdForAgentExecute(cmd2)
|
||||||
_ = prepareShellCmdSession(cmd2)
|
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -999,7 +937,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
Content: []mcp.Content{
|
Content: []mcp.Content{
|
||||||
{
|
{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
|
Text: FormatCommandFailureFromErr(err, output),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
IsError: true,
|
IsError: true,
|
||||||
@@ -1022,12 +960,58 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
// combinedOutputCancellable 行为对齐 cmd.CombinedOutput(stdout/stderr 写入内存缓冲),
|
||||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
// 但在 ctx 取消时 terminateCmdTree 终止整棵进程树。
|
||||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
// 非流式路径不使用双流管道 fan-in,避免 stderr 撑满管道缓冲区时与 stdout 互相阻塞导致死锁。
|
||||||
if err := prepareShellCmdSession(cmd); err != nil {
|
// 无输出空闲检测由上层 agent.tool_timeout_minutes 兜底,不改变原 CombinedOutput 语义。
|
||||||
|
func combinedOutputCancellable(ctx context.Context, cmd *exec.Cmd) (string, error) {
|
||||||
|
var stdoutBuf, stderrBuf strings.Builder
|
||||||
|
cmd.Stdout = &stdoutBuf
|
||||||
|
cmd.Stderr = &stderrBuf
|
||||||
|
|
||||||
|
session, err := StartShellSession(cmd)
|
||||||
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- session.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
stopWatch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
TerminateShellCmdSession(session)
|
||||||
|
case <-stopWatch:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(stopWatch)
|
||||||
|
|
||||||
|
var waitErr error
|
||||||
|
select {
|
||||||
|
case waitErr = <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
waitErr = <-done
|
||||||
|
return joinCommandOutput(stdoutBuf.String(), stderrBuf.String()), ctx.Err()
|
||||||
|
}
|
||||||
|
return joinCommandOutput(stdoutBuf.String(), stderrBuf.String()), waitErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinCommandOutput(stdout, stderr string) string {
|
||||||
|
if stderr == "" {
|
||||||
|
return stdout
|
||||||
|
}
|
||||||
|
if stdout == "" {
|
||||||
|
return stderr
|
||||||
|
}
|
||||||
|
return stdout + stderr
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||||
|
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||||
|
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
|
||||||
stdoutPipe, err := cmd.StdoutPipe()
|
stdoutPipe, err := cmd.StdoutPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1037,7 +1021,8 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
|||||||
_ = stdoutPipe.Close()
|
_ = stdoutPipe.Close()
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if err := cmd.Start(); err != nil {
|
session, err := StartShellSession(cmd)
|
||||||
|
if err != nil {
|
||||||
_ = stdoutPipe.Close()
|
_ = stdoutPipe.Close()
|
||||||
_ = stderrPipe.Close()
|
_ = stderrPipe.Close()
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1047,7 +1032,7 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
|||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
terminateCmdTree(cmd)
|
TerminateShellCmdSession(session)
|
||||||
case <-stopWatch:
|
case <-stopWatch:
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -1086,23 +1071,61 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
|||||||
if deltaBuilder.Len() == 0 {
|
if deltaBuilder.Len() == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cb(deltaBuilder.String())
|
if cb != nil {
|
||||||
|
cb(deltaBuilder.String())
|
||||||
|
}
|
||||||
deltaBuilder.Reset()
|
deltaBuilder.Reset()
|
||||||
lastFlush = time.Now()
|
lastFlush = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
for chunk := range chunks {
|
idleWatch := NewShellInactivityWatch(noOutputSec)
|
||||||
outBuilder.WriteString(chunk)
|
if idleWatch != nil {
|
||||||
deltaBuilder.WriteString(chunk)
|
defer idleWatch.Stop()
|
||||||
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
|
}
|
||||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
|
||||||
|
fireInactivity := func() {
|
||||||
|
TerminateShellCmdSession(session)
|
||||||
|
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||||
|
outBuilder.WriteString(msg)
|
||||||
|
if cb != nil {
|
||||||
|
cb(msg)
|
||||||
|
}
|
||||||
|
_ = session.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
chunksLoop:
|
||||||
|
for {
|
||||||
|
var idleCh <-chan struct{}
|
||||||
|
if idleWatch != nil {
|
||||||
|
idleCh = idleWatch.Expired
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
TerminateShellCmdSession(session)
|
||||||
flush()
|
flush()
|
||||||
|
_ = session.Wait()
|
||||||
|
return outBuilder.String(), ctx.Err()
|
||||||
|
case <-idleCh:
|
||||||
|
fireInactivity()
|
||||||
|
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||||
|
case chunk, ok := <-chunks:
|
||||||
|
if !ok {
|
||||||
|
break chunksLoop
|
||||||
|
}
|
||||||
|
if chunk != "" && idleWatch != nil {
|
||||||
|
idleWatch.Bump()
|
||||||
|
}
|
||||||
|
outBuilder.WriteString(chunk)
|
||||||
|
deltaBuilder.WriteString(chunk)
|
||||||
|
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||||
|
flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
// 等待命令结束,返回最终退出状态
|
// 等待命令结束,返回最终退出状态
|
||||||
waitErr := cmd.Wait()
|
waitErr := session.Wait()
|
||||||
return outBuilder.String(), waitErr
|
return outBuilder.String(), waitErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1116,6 +1139,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
|
|||||||
if cmd.Env == nil {
|
if cmd.Env == nil {
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
}
|
}
|
||||||
|
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
|
||||||
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
|
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
|
||||||
has := func(k string) bool {
|
has := func(k string) bool {
|
||||||
prefix := k + "="
|
prefix := k + "="
|
||||||
@@ -1159,7 +1183,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
// PTY 方案为类 Unix;Windows 走原逻辑
|
// PTY 方案为类 Unix;Windows 走原逻辑
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
return streamCommandOutput(ctx, cmd, cb)
|
return streamCommandOutput(ctx, cmd, cb, 0)
|
||||||
}
|
}
|
||||||
_ = prepareShellCmdSession(cmd)
|
_ = prepareShellCmdSession(cmd)
|
||||||
out, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
@@ -1173,13 +1197,18 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
|||||||
}
|
}
|
||||||
defer func() { _ = ptmx.Close() }()
|
defer func() { _ = ptmx.Close() }()
|
||||||
|
|
||||||
|
rootPID := 0
|
||||||
|
if cmd.Process != nil {
|
||||||
|
rootPID = cmd.Process.Pid
|
||||||
|
}
|
||||||
|
|
||||||
// ctx 取消时尽快终止子进程
|
// ctx 取消时尽快终止子进程
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
_ = ptmx.Close() // 触发读退出
|
_ = ptmx.Close() // 触发读退出
|
||||||
terminateCmdTree(cmd)
|
terminateProcessGroup(rootPID, cmd)
|
||||||
case <-done:
|
case <-done:
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user