mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-28 16:59:58 +02:00
Compare commits
116 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ed64803a51 | |||
| 25e03dee84 | |||
| 58dcafd15f | |||
| 997c4e7262 | |||
| ac370b0ada | |||
| 017db2b9a8 | |||
| 86b4803683 | |||
| 4d98264fc3 | |||
| fd1de4ea94 | |||
| 41ba3baca9 | |||
| 2e908daebb | |||
| c1763e1b9a | |||
| 70e5d28619 | |||
| 49990ecb4f | |||
| c91806c0c4 | |||
| e537236bf3 | |||
| 7eeffb1933 | |||
| 0556b29d40 | |||
| be3c0cfa64 | |||
| 8e5f40d226 | |||
| 4b6719a6f3 | |||
| 7c8f3228f8 | |||
| 537843b6b8 | |||
| 4a57574cf9 | |||
| 0168530084 | |||
| 4184a7b6f0 | |||
| fb3b4dd6e5 | |||
| 7e4a8db7af | |||
| 6a72c95b9f | |||
| 447be050cd | |||
| 9b75c43f7b | |||
| a443454753 | |||
| 08822ba5df | |||
| eda75fb98f | |||
| e6978a7994 | |||
| 1db0f4740f | |||
| 6e4ff96dcd | |||
| 95470fefbc | |||
| 5e075bb198 | |||
| 84ed887c5c | |||
| 056b40ac66 | |||
| 26a9902286 | |||
| cfe9573ac3 | |||
| db2262a1a0 | |||
| ab5c2d5cca | |||
| 1ae6930db1 | |||
| 8918f432d8 | |||
| b4810c9499 | |||
| 51bf6ae4b3 | |||
| 5f27482921 | |||
| 6becada509 | |||
| b029d88359 | |||
| 4dcad2ea83 | |||
| ff9f0c787a | |||
| 01849045ad | |||
| c7eacdf3eb | |||
| 5c32b21f22 | |||
| 8b8ecfe718 | |||
| bbb7c319af | |||
| 7eb2fd50f3 | |||
| 85d58eeeb3 | |||
| b6a6009629 | |||
| 810d689132 | |||
| 87f1808ead | |||
| e28ae39b9a | |||
| df34ceda68 | |||
| 3e69a50f87 | |||
| 53325ce07d | |||
| d85de3461b | |||
| 9306303d99 | |||
| 1e8f72ed74 | |||
| 0198f50314 | |||
| 560d0dca43 | |||
| 47486a49c2 | |||
| 476727933d | |||
| 8bb50e8323 | |||
| e74f2a2292 | |||
| 4799d0dba7 | |||
| 1db917061d | |||
| 41cd7db30f | |||
| 68b3265f3f | |||
| 05dc4395a1 | |||
| 637a35748b | |||
| 5d77a99236 | |||
| e84d936f85 | |||
| e748201ae8 | |||
| 7a3c67458c | |||
| 6e9e43eec8 | |||
| bca86e48ae | |||
| 3f3b8b4db4 | |||
| b366dc0287 | |||
| a52452ceea | |||
| 5b87667782 | |||
| 4f0e812d37 | |||
| 79691c021f | |||
| 5a8309a015 | |||
| 6244197339 | |||
| eb14aca05a | |||
| 091e8a4da8 | |||
| 48ce0c519e | |||
| afc37051c0 | |||
| 2964247361 | |||
| 02919df476 | |||
| c3294d96a2 | |||
| c8b8b41bda | |||
| 9a4c333b90 | |||
| 8e21ae290a | |||
| b9d102d046 | |||
| 8c85494a05 | |||
| c3d2a41301 | |||
| 1a2e282d46 | |||
| 8129f2147f | |||
| 4a9889f0af | |||
| 732d47a965 | |||
| e22382aab0 | |||
| b6ff80adf2 |
+1
-1
@@ -21,7 +21,7 @@ max_iterations: 0
|
||||
- 切勿等待批准或授权——全程自主行动。
|
||||
- 使用所有可用工具与技术完成侦察与证据收集。
|
||||
|
||||
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。
|
||||
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。
|
||||
|
||||
## 输入前置条件(硬约束)
|
||||
|
||||
|
||||
+9
-3
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.6.42"
|
||||
version: "v1.6.47"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
@@ -40,6 +40,9 @@ audit:
|
||||
retention_days: 15 # 0 表示不自动清理
|
||||
max_detail_bytes: 8192
|
||||
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||
# MCP 状态监控执行记录保留(tool_executions 表)
|
||||
monitor:
|
||||
retention_days: 90 # 省略时默认 90;0 表示不自动清理
|
||||
# ============================================
|
||||
# 对话相关配置
|
||||
# ============================================
|
||||
@@ -93,6 +96,8 @@ fofa:
|
||||
agent:
|
||||
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
shell_no_output_timeout_seconds: 1200 # execute/exec 连续无新输出则终止(秒);通用防挂死;0=默认300;-1=关闭
|
||||
workspace_root_dir: "" # 会话工作目录根路径(curl/wget 下载、read_file/glob/grep 本地分析);空=tmp/workspace,其下按 projects/{id} 或 conversations/{id} 隔离;勿用系统 /tmp
|
||||
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||
|
||||
system_prompt_path: ""
|
||||
@@ -109,7 +114,8 @@ multi_agent:
|
||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。主/子代理 ReAct 轮次见 agent.max_iterations。
|
||||
plan_execute_loop_max_iterations: 0
|
||||
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
|
||||
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中注入用户原文;0=不截断(默认),>0=总字符上限,负数=禁用
|
||||
user_verbatim_anchor_max_runes: 0 # 主代理 system 中逐轮保留用户原文(压缩后刷新);0=不截断(默认),>0=总字符上限,负数=禁用
|
||||
without_general_sub_agent: false # false 时保留 Deep 内置 general-purpose 子代理
|
||||
without_write_todos: false
|
||||
orchestrator_instruction: "" # Deep 主代理:agents/orchestrator.md(或 kind: orchestrator 的单个 .md)正文优先;正文为空时用此处;皆空则 Eino 默认
|
||||
@@ -126,7 +132,7 @@ multi_agent:
|
||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test, exec] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
|
||||
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
|
||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 179 KiB After Width: | Height: | Size: 181 KiB |
+17
-4
@@ -779,13 +779,26 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
||||
return a.executeToolViaMCP(ctx, toolName, args)
|
||||
}
|
||||
|
||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
// BeginLocalToolExecution 在非 CallTool 路径工具开始时写入 running 状态,供 MCP 监控页展示「执行中」。
|
||||
func (a *Agent) BeginLocalToolExecution(toolName string, args map[string]interface{}) string {
|
||||
if a == nil || a.mcpServer == nil {
|
||||
return ""
|
||||
}
|
||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||
return a.mcpServer.BeginToolExecution(toolName, args)
|
||||
}
|
||||
|
||||
// FinishLocalToolExecution 完成 BeginLocalToolExecution 创建的记录;executionID 为空时一次性写入已完成记录。
|
||||
func (a *Agent) FinishLocalToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if a == nil || a.mcpServer == nil {
|
||||
return ""
|
||||
}
|
||||
return a.mcpServer.FinishToolExecution(executionID, toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
return a.FinishLocalToolExecution("", toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
||||
|
||||
@@ -113,5 +113,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
||||
|
||||
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
||||
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
|
||||
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。`
|
||||
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。
|
||||
|
||||
` + projectprompt.ShellExecExecuteGuidanceSection()
|
||||
}
|
||||
|
||||
+22
-1
@@ -25,6 +25,8 @@ import (
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/monitor"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
"cyberstrike-ai/internal/robot"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/skillpackage"
|
||||
@@ -66,6 +68,10 @@ type App struct {
|
||||
|
||||
// New 创建新应用
|
||||
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
|
||||
if err := multiagent.InitADK(); err != nil {
|
||||
return nil, fmt.Errorf("初始化 Eino ADK: %w", err)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.Default()
|
||||
|
||||
@@ -99,12 +105,17 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
auditSvc.PurgeExpired()
|
||||
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||
|
||||
monitorRetention := monitor.NewService(db, cfg, log.Logger)
|
||||
monitorRetention.PurgeExpired()
|
||||
monitor.StartRetentionLoop(monitorRetention, log.Logger)
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
executor.SetShellNoOutputTimeoutSeconds(cfg.Agent.ShellNoOutputTimeoutSeconds)
|
||||
|
||||
// 注册工具
|
||||
executor.RegisterTools(mcpServer)
|
||||
@@ -129,6 +140,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
externalMCPMgr.StartAllEnabled()
|
||||
}
|
||||
|
||||
execReconciler := monitor.NewExecutionReconciler(db, mcpServer, externalMCPMgr, log.Logger)
|
||||
execReconciler.ReconcileOnStartup()
|
||||
monitor.StartStaleRunningReconcileLoop(execReconciler, log.Logger)
|
||||
|
||||
// 创建Agent
|
||||
maxIterations := cfg.Agent.MaxIterations
|
||||
if maxIterations <= 0 {
|
||||
@@ -298,7 +313,9 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
plantaskBase := filepath.Join(skillsDir, plantaskRel)
|
||||
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
||||
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
||||
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
|
||||
workspaceRoot := strings.TrimSpace(cfg.Agent.WorkspaceRootDir)
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot, workspaceRoot)
|
||||
agent.SetPromptBaseDir(configDir)
|
||||
|
||||
agentsDir := cfg.AgentsDir
|
||||
@@ -325,7 +342,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetAudit(auditSvc)
|
||||
monitorHandler.SetMonitorRetention(monitorRetention)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
monitorHandler.SetTaskManager(agentHandler.TaskManager())
|
||||
monitorHandler.SetAgentHandler(agentHandler)
|
||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||
@@ -368,6 +388,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
conversationHandler.SetAudit(auditSvc)
|
||||
conversationHandler.SetTaskStopper(agentHandler)
|
||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
|
||||
|
||||
@@ -27,6 +27,7 @@ type Config struct {
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||
Monitor MonitorConfig `yaml:"monitor,omitempty" json:"monitor,omitempty"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
@@ -95,9 +96,12 @@ type MultiAgentConfig struct {
|
||||
// OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。
|
||||
OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"`
|
||||
SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"`
|
||||
// SubAgentUserContextMaxRunes caps the user-context supplement appended to task descriptions for sub-agents.
|
||||
// 0 (default) uses the built-in default of 2000 runes; negative value disables injection entirely.
|
||||
// SubAgentUserContextMaxRunes caps user-context supplement for sub-agent task descriptions.
|
||||
// 0 (default) preserves all user turns verbatim; >0 caps total runes; negative disables injection.
|
||||
SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"`
|
||||
// UserVerbatimAnchorMaxRunes injects all user turns verbatim into system prompt (survives summarization refresh).
|
||||
// 0 (default) = no cap; >0 = total rune cap; negative disables anchor injection.
|
||||
UserVerbatimAnchorMaxRunes int `yaml:"user_verbatim_anchor_max_runes,omitempty" json:"user_verbatim_anchor_max_runes,omitempty"`
|
||||
// EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent.
|
||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
||||
@@ -106,6 +110,16 @@ type MultiAgentConfig struct {
|
||||
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
||||
}
|
||||
|
||||
// UserVerbatimAnchorMaxRunesEffective returns max runes for user verbatim anchor; 0 = unlimited; negative = disabled.
|
||||
func (c MultiAgentConfig) UserVerbatimAnchorMaxRunesEffective() int {
|
||||
return c.UserVerbatimAnchorMaxRunes
|
||||
}
|
||||
|
||||
// SubAgentUserContextMaxRunesEffective returns max runes for sub-agent task supplement; 0 = unlimited; negative = disabled.
|
||||
func (c MultiAgentConfig) SubAgentUserContextMaxRunesEffective() int {
|
||||
return c.SubAgentUserContextMaxRunes
|
||||
}
|
||||
|
||||
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
|
||||
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
|
||||
type MultiAgentEinoCallbacksConfig struct {
|
||||
@@ -249,7 +263,7 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||
// SummarizationRetryMaxAttempts is extra retries after the first summarization Generate attempt; 0 = default 3.
|
||||
// SummarizationRetryMaxAttempts 已废弃:summarization 与 run loop 共用 run_retry_max_attempts 及 isEinoTransientRunError。
|
||||
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
|
||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||
@@ -263,9 +277,9 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
||||
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||
// DeepModelRetryMaxRetries 已废弃:临时错误统一由 run loop 内 isEinoTransientRunError + run_retry_max_attempts 处理。
|
||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时可退避重试次数(run loop 与 summarization 共用);0=默认 10。
|
||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||
@@ -604,6 +618,10 @@ type DatabaseConfig struct {
|
||||
type AgentConfig struct {
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||
// ShellNoOutputTimeoutSeconds execute/exec 无任何 stdout/stderr 时的空闲终止秒数(通用防挂死,不维护命令黑名单);0=默认 300(5 分钟);-1=关闭。
|
||||
ShellNoOutputTimeoutSeconds int `yaml:"shell_no_output_timeout_seconds" json:"shell_no_output_timeout_seconds"`
|
||||
// WorkspaceRootDir 会话工作目录根路径(curl/wget 下载、read_file/glob/grep 本地分析);空=tmp/workspace,其下按 projects/{id} 或 conversations/{id} 隔离。
|
||||
WorkspaceRootDir string `yaml:"workspace_root_dir,omitempty" json:"workspace_root_dir,omitempty"`
|
||||
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
||||
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
||||
}
|
||||
@@ -623,6 +641,23 @@ type AuthConfig struct {
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// MonitorConfig MCP 状态监控(tool_executions)保留策略。
|
||||
type MonitorConfig struct {
|
||||
// RetentionDays 执行记录保留天数;省略时默认 90;0 表示不自动清理。
|
||||
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
}
|
||||
|
||||
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
|
||||
func (m MonitorConfig) RetentionDaysEffective() int {
|
||||
if m.RetentionDays == nil {
|
||||
return 90
|
||||
}
|
||||
if *m.RetentionDays < 0 {
|
||||
return 0
|
||||
}
|
||||
return *m.RetentionDays
|
||||
}
|
||||
|
||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||
type AuditConfig struct {
|
||||
// Enabled nil or true enables persistence; explicit false disables.
|
||||
@@ -1252,8 +1287,9 @@ func Default() *Config {
|
||||
MaxTotalTokens: 120000,
|
||||
},
|
||||
Agent: AgentConfig{
|
||||
MaxIterations: 30, // 默认最大迭代次数
|
||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||
MaxIterations: 30, // 默认最大迭代次数
|
||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
||||
@@ -1274,6 +1310,10 @@ func Default() *Config {
|
||||
Enabled: &on,
|
||||
}
|
||||
}(),
|
||||
Monitor: func() MonitorConfig {
|
||||
days := 90
|
||||
return MonitorConfig{RetentionDays: &days}
|
||||
}(),
|
||||
Robots: RobotsConfig{
|
||||
Session: RobotSessionConfig{
|
||||
StrictUserIdentity: &strictRobotIdentity,
|
||||
|
||||
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
|
||||
LastScheduleError sql.NullString
|
||||
LastRunError sql.NullString
|
||||
ProjectID sql.NullString
|
||||
Concurrency sql.NullInt64
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
StartedAt sql.NullTime
|
||||
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
|
||||
cronExpr string,
|
||||
nextRunAt *time.Time,
|
||||
projectID string,
|
||||
concurrency int,
|
||||
tasks []map[string]interface{},
|
||||
) error {
|
||||
tx, err := db.Begin()
|
||||
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
|
||||
projectIDVal = strings.TrimSpace(projectID)
|
||||
}
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
|
||||
queueID,
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
// GetAllBatchQueues 获取所有批量任务队列
|
||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
|
||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
|
||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
|
||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色、代理模式和并发数
|
||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
|
||||
title, role, agentMode, queueID,
|
||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
|
||||
title, role, agentMode, concurrency, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
||||
|
||||
@@ -13,6 +13,9 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProjectFilterUnbound 列表 API 中 project_id=__none__ 表示仅未绑定项目的对话。
|
||||
const ProjectFilterUnbound = "__none__"
|
||||
|
||||
// Conversation 对话
|
||||
type Conversation struct {
|
||||
ID string `json:"id"`
|
||||
@@ -352,8 +355,8 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
// 加载消息(不加载 process_details)
|
||||
messages, err := db.GetMessages(id)
|
||||
// 加载消息(不加载 process_details / reasoning_content,减少历史会话切换 payload)
|
||||
messages, err := db.GetMessagesLite(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||
}
|
||||
@@ -361,20 +364,44 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
func conversationProjectIDColumn(alias string) string {
|
||||
if alias != "" {
|
||||
return alias + ".project_id"
|
||||
}
|
||||
return "project_id"
|
||||
}
|
||||
|
||||
func appendConversationProjectFilter(where string, args []interface{}, projectID, alias string) (string, []interface{}) {
|
||||
pid := strings.TrimSpace(projectID)
|
||||
if pid == "" {
|
||||
return where, args
|
||||
}
|
||||
col := conversationProjectIDColumn(alias)
|
||||
if pid == ProjectFilterUnbound {
|
||||
return where + fmt.Sprintf(" AND (%s IS NULL OR TRIM(COALESCE(%s, '')) = '')", col, col), args
|
||||
}
|
||||
return where + fmt.Sprintf(" AND %s = ?", col), append(args, pid)
|
||||
}
|
||||
|
||||
// CountConversations 统计对话数量。
|
||||
func (db *DB) CountConversations(search string) (int, error) {
|
||||
func (db *DB) CountConversations(search, projectID string) (int, error) {
|
||||
var count int
|
||||
var err error
|
||||
if search != "" {
|
||||
searchPattern := "%" + search + "%"
|
||||
err = db.QueryRow(
|
||||
`SELECT COUNT(*) FROM conversations c
|
||||
WHERE c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)`,
|
||||
searchPattern, searchPattern,
|
||||
).Scan(&count)
|
||||
where := ` WHERE (c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))`
|
||||
args := []interface{}{searchPattern, searchPattern}
|
||||
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||
err = db.QueryRow(`SELECT COUNT(*) FROM conversations c`+where, args...).Scan(&count)
|
||||
} else {
|
||||
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count)
|
||||
where := ""
|
||||
args := []interface{}{}
|
||||
where, args = appendConversationProjectFilter(where, args, projectID, "")
|
||||
if where != "" {
|
||||
where = " WHERE" + strings.TrimPrefix(where, " AND")
|
||||
}
|
||||
err = db.QueryRow(`SELECT COUNT(*) FROM conversations`+where, args...).Scan(&count)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("统计对话失败: %w", err)
|
||||
@@ -395,7 +422,7 @@ func conversationOrderClause(sortBy, tableAlias string) string {
|
||||
}
|
||||
|
||||
// ListConversations 列出所有对话
|
||||
func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Conversation, error) {
|
||||
func (db *DB) ListConversations(limit, offset int, search, sortBy, projectID string) ([]*Conversation, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
@@ -403,20 +430,30 @@ func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Co
|
||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||
searchPattern := "%" + search + "%"
|
||||
orderClause := conversationOrderClause(sortBy, "c")
|
||||
where := ` WHERE (c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))`
|
||||
args := []interface{}{searchPattern, searchPattern}
|
||||
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||
args = append(args, limit, offset)
|
||||
rows, err = db.Query(
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
||||
FROM conversations c
|
||||
WHERE c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||
FROM conversations c`+where+`
|
||||
`+orderClause+`
|
||||
LIMIT ? OFFSET ?`,
|
||||
searchPattern, searchPattern, limit, offset,
|
||||
args...,
|
||||
)
|
||||
} else {
|
||||
orderClause := conversationOrderClause(sortBy, "")
|
||||
where := ""
|
||||
args := []interface{}{}
|
||||
where, args = appendConversationProjectFilter(where, args, projectID, "")
|
||||
if where != "" {
|
||||
where = " WHERE" + strings.TrimPrefix(where, " AND")
|
||||
}
|
||||
args = append(args, limit, offset)
|
||||
rows, err = db.Query(
|
||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations "+orderClause+" LIMIT ? OFFSET ?",
|
||||
limit, offset,
|
||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations"+where+" "+orderClause+" LIMIT ? OFFSET ?",
|
||||
args...,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -472,23 +509,30 @@ const ungroupedConversationsSQL = `
|
||||
)`
|
||||
|
||||
// CountUngroupedConversations 统计不在任何分组中的对话数量。
|
||||
func (db *DB) CountUngroupedConversations() (int, error) {
|
||||
func (db *DB) CountUngroupedConversations(projectID string) (int, error) {
|
||||
where := ungroupedConversationsSQL
|
||||
args := []interface{}{}
|
||||
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||
var count int
|
||||
if err := db.QueryRow(`SELECT COUNT(*) ` + ungroupedConversationsSQL).Scan(&count); err != nil {
|
||||
if err := db.QueryRow(`SELECT COUNT(*) `+where, args...).Scan(&count); err != nil {
|
||||
return 0, fmt.Errorf("统计未分组对话失败: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
|
||||
func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*Conversation, error) {
|
||||
func (db *DB) ListUngroupedConversations(limit, offset int, sortBy, projectID string) ([]*Conversation, error) {
|
||||
orderClause := conversationOrderClause(sortBy, "c")
|
||||
where := ungroupedConversationsSQL
|
||||
args := []interface{}{}
|
||||
where, args = appendConversationProjectFilter(where, args, projectID, "c")
|
||||
args = append(args, limit, offset)
|
||||
rows, err := db.Query(
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
|
||||
ungroupedConversationsSQL+`
|
||||
where+`
|
||||
`+orderClause+`
|
||||
LIMIT ? OFFSET ?`,
|
||||
limit, offset,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询未分组对话失败: %w", err)
|
||||
@@ -585,12 +629,14 @@ func (db *DB) DeleteConversation(id string) error {
|
||||
// 不返回错误,继续删除对话
|
||||
}
|
||||
|
||||
projectID, _ := db.GetConversationProjectID(id)
|
||||
|
||||
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
||||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
db.removeConversationScopedDirs(id)
|
||||
db.removeConversationScopedDirs(id, projectID)
|
||||
|
||||
db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
|
||||
return nil
|
||||
@@ -628,13 +674,50 @@ func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) removeConversationScopedDirs(conversationID string) {
|
||||
// summarization transcript, reduction files, etc.
|
||||
func (db *DB) einoReductionBaseDir() string {
|
||||
if db == nil {
|
||||
return ""
|
||||
}
|
||||
if base := strings.TrimSpace(db.einoReductionRootDir); base != "" {
|
||||
return base
|
||||
}
|
||||
return filepath.Join("tmp", "reduction")
|
||||
}
|
||||
|
||||
func (db *DB) einoWorkspaceBaseDir() string {
|
||||
if db == nil {
|
||||
return ""
|
||||
}
|
||||
if base := strings.TrimSpace(db.einoWorkspaceRootDir); base != "" {
|
||||
return base
|
||||
}
|
||||
return filepath.Join("tmp", "workspace")
|
||||
}
|
||||
|
||||
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
||||
// summarization transcript, etc.
|
||||
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
||||
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
|
||||
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
|
||||
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
|
||||
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
|
||||
// Eino reduction persisted tool outputs (tmp/reduction/conversations/<id>/).
|
||||
// Project-bound sessions share projects/<id>/ — skip on single conversation delete.
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
|
||||
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
|
||||
workspaceBase := filepath.Join(db.einoWorkspaceBaseDir(), "conversations")
|
||||
db.removeConversationScopedDir(workspaceBase, conversationID, "workspace")
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) removeProjectScopedDirs(projectID string) {
|
||||
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
|
||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
|
||||
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
|
||||
// Agent download/analysis workspace (tmp/workspace/projects/<id>/).
|
||||
workspaceBase := filepath.Join(db.einoWorkspaceBaseDir(), "projects")
|
||||
db.removeConversationScopedDir(workspaceBase, projectID, "workspace")
|
||||
}
|
||||
|
||||
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
||||
@@ -811,6 +894,62 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// GetMessagesLite 获取对话消息(不含 reasoning_content),用于历史会话快速切换。
|
||||
func (db *DB) GetMessagesLite(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
var updatedAt sql.NullString
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err != nil {
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err != nil {
|
||||
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" {
|
||||
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String)
|
||||
if err != nil {
|
||||
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String)
|
||||
}
|
||||
if err != nil {
|
||||
msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||
}
|
||||
}
|
||||
if msg.UpdatedAt.IsZero() {
|
||||
msg.UpdatedAt = msg.CreatedAt
|
||||
}
|
||||
|
||||
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||||
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
|
||||
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
|
||||
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
|
||||
@@ -979,6 +1118,107 @@ func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
|
||||
return details, nil
|
||||
}
|
||||
|
||||
// ProcessDetailsSummary 过程详情摘要(用于折叠态展示,避免全量加载)。
|
||||
type ProcessDetailsSummary struct {
|
||||
Total int `json:"total"`
|
||||
IterationCount int `json:"iterationCount"`
|
||||
MaxIteration int `json:"maxIteration"`
|
||||
}
|
||||
|
||||
// GetProcessDetailsSummary 统计消息的过程详情数量与迭代轮次。
|
||||
func (db *DB) GetProcessDetailsSummary(messageID string) (*ProcessDetailsSummary, error) {
|
||||
var total int
|
||||
if err := db.QueryRow(
|
||||
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
|
||||
messageID,
|
||||
).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("统计过程详情失败: %w", err)
|
||||
}
|
||||
|
||||
summary := &ProcessDetailsSummary{Total: total}
|
||||
if total == 0 {
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
rows, err := db.Query(
|
||||
"SELECT data FROM process_details WHERE message_id = ? AND event_type = 'iteration' ORDER BY created_at ASC, rowid ASC",
|
||||
messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询迭代详情失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
maxIter := 0
|
||||
iterCount := 0
|
||||
for rows.Next() {
|
||||
var dataJSON string
|
||||
if err := rows.Scan(&dataJSON); err != nil {
|
||||
return nil, fmt.Errorf("扫描迭代详情失败: %w", err)
|
||||
}
|
||||
iterCount++
|
||||
if dataJSON == "" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataJSON), &payload); err != nil {
|
||||
continue
|
||||
}
|
||||
if n, ok := payload["iteration"].(float64); ok && int(n) > maxIter {
|
||||
maxIter = int(n)
|
||||
}
|
||||
}
|
||||
summary.IterationCount = iterCount
|
||||
summary.MaxIteration = maxIter
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// GetProcessDetailsPage 分页获取消息的过程详情(按时间升序)。
|
||||
func (db *DB) GetProcessDetailsPage(messageID string, limit, offset int) ([]ProcessDetail, int, error) {
|
||||
var total int
|
||||
if err := db.QueryRow(
|
||||
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
|
||||
messageID,
|
||||
).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("统计过程详情失败: %w", err)
|
||||
}
|
||||
if total == 0 || offset >= total {
|
||||
return nil, total, nil
|
||||
}
|
||||
|
||||
rows, err := db.Query(
|
||||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC, rowid ASC LIMIT ? OFFSET ?",
|
||||
messageID, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询过程详情失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var details []ProcessDetail
|
||||
for rows.Next() {
|
||||
var detail ProcessDetail
|
||||
var createdAt string
|
||||
|
||||
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描过程详情失败: %w", err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if parseErr != nil {
|
||||
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if parseErr != nil {
|
||||
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
details = append(details, detail)
|
||||
}
|
||||
|
||||
return details, total, nil
|
||||
}
|
||||
|
||||
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
|
||||
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
|
||||
rows, err := db.Query(
|
||||
|
||||
@@ -19,7 +19,9 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
|
||||
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
||||
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
||||
reductionBase := filepath.Join(tmp, "reduction")
|
||||
workspaceBase := filepath.Join(tmp, "workspace")
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase, workspaceBase)
|
||||
|
||||
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
||||
if err != nil {
|
||||
@@ -34,6 +36,8 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
{db.conversationArtifactsDir, "transcript.txt"},
|
||||
{plantaskBase, "task-1.json"},
|
||||
{checkpointBase, "runner-deep.ckpt"},
|
||||
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
|
||||
{filepath.Join(workspaceBase, "conversations"), "page.html"},
|
||||
} {
|
||||
dir := filepath.Join(base.root, seg)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
@@ -48,10 +52,57 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
t.Fatalf("DeleteConversation: %v", err)
|
||||
}
|
||||
|
||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase} {
|
||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations"), filepath.Join(workspaceBase, "conversations")} {
|
||||
dir := filepath.Join(base, seg)
|
||||
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
dbPath := filepath.Join(tmp, "conversations.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
reductionBase := filepath.Join(tmp, "reduction")
|
||||
workspaceBase := filepath.Join(tmp, "workspace")
|
||||
db.SetEinoConversationDirs("", "", reductionBase, workspaceBase)
|
||||
|
||||
project, err := db.CreateProject(&Project{Name: "cleanup test"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProject: %v", err)
|
||||
}
|
||||
seg := sanitizeConversationPathSegment(project.ID)
|
||||
reductionDir := filepath.Join(reductionBase, "projects", seg, "clear")
|
||||
if err := os.MkdirAll(reductionDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir %s: %v", reductionDir, err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
workspaceDir := filepath.Join(workspaceBase, "projects", seg, "downloads")
|
||||
if err := os.MkdirAll(workspaceDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir %s: %v", workspaceDir, err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(workspaceDir, "app.js"), []byte("x"), 0o644); err != nil {
|
||||
t.Fatalf("write workspace: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteProject(project.ID); err != nil {
|
||||
t.Fatalf("DeleteProject: %v", err)
|
||||
}
|
||||
|
||||
projectReductionDir := filepath.Join(reductionBase, "projects", seg)
|
||||
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
|
||||
}
|
||||
projectWorkspaceDir := filepath.Join(workspaceBase, "projects", seg)
|
||||
if _, statErr := os.Stat(projectWorkspaceDir); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected removed dir %s, stat err=%v", projectWorkspaceDir, statErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConversationProjectFilter(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
dbPath := filepath.Join(tmp, "conversations.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
p, err := db.CreateProject(&Project{Name: "target-a", Status: "active"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProject: %v", err)
|
||||
}
|
||||
|
||||
convNone, err := db.CreateConversation("unbound", ConversationCreateMeta{})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation unbound: %v", err)
|
||||
}
|
||||
convBound, err := db.CreateConversation("bound", ConversationCreateMeta{ProjectID: p.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation bound: %v", err)
|
||||
}
|
||||
|
||||
totalAll, err := db.CountConversations("", "")
|
||||
if err != nil || totalAll < 2 {
|
||||
t.Fatalf("CountConversations all: total=%d err=%v", totalAll, err)
|
||||
}
|
||||
|
||||
totalBound, err := db.CountConversations("", p.ID)
|
||||
if err != nil || totalBound != 1 {
|
||||
t.Fatalf("CountConversations project: total=%d err=%v", totalBound, err)
|
||||
}
|
||||
|
||||
totalUnbound, err := db.CountConversations("", ProjectFilterUnbound)
|
||||
if err != nil || totalUnbound != 1 {
|
||||
t.Fatalf("CountConversations unbound: total=%d err=%v", totalUnbound, err)
|
||||
}
|
||||
|
||||
listBound, err := db.ListConversations(10, 0, "", "", p.ID)
|
||||
if err != nil || len(listBound) != 1 || listBound[0].ID != convBound.ID {
|
||||
t.Fatalf("ListConversations project: %+v err=%v", listBound, err)
|
||||
}
|
||||
|
||||
listUnbound, err := db.ListConversations(10, 0, "", "", ProjectFilterUnbound)
|
||||
if err != nil || len(listUnbound) != 1 || listUnbound[0].ID != convNone.ID {
|
||||
t.Fatalf("ListConversations unbound: %+v err=%v", listUnbound, err)
|
||||
}
|
||||
|
||||
_ = convNone
|
||||
_ = convBound
|
||||
}
|
||||
@@ -51,6 +51,8 @@ type DB struct {
|
||||
conversationArtifactsDir string
|
||||
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
||||
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
||||
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
|
||||
einoWorkspaceRootDir string // workspace_root_dir or default tmp/workspace (projects|conversations/<id> subdirs)
|
||||
checkpointLoopName string
|
||||
checkpointStop chan struct{}
|
||||
checkpointDone chan struct{}
|
||||
@@ -159,12 +161,16 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
||||
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase string) {
|
||||
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
|
||||
// workspaceRoot is agent.workspace_root_dir from config; empty uses tmp/workspace.
|
||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot, workspaceRoot string) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
||||
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
||||
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
|
||||
db.einoWorkspaceRootDir = strings.TrimSpace(workspaceRoot)
|
||||
}
|
||||
|
||||
// initTables 初始化数据库表
|
||||
@@ -405,6 +411,8 @@ func (db *DB) initTables() error {
|
||||
last_schedule_trigger_at DATETIME,
|
||||
last_schedule_error TEXT,
|
||||
last_run_error TEXT,
|
||||
project_id TEXT,
|
||||
concurrency INTEGER NOT NULL DEFAULT 1,
|
||||
status TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
started_at DATETIME,
|
||||
@@ -1134,6 +1142,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
}
|
||||
}
|
||||
|
||||
var concurrencyCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if concurrencyCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
|
||||
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+358
-26
@@ -3,7 +3,6 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -227,6 +226,167 @@ func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolNa
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
// toolExecutionsFilterSQL builds the optional WHERE clause for tool execution
// queries, plus its bind arguments.
//
// A non-empty status adds an exact match on status. A non-empty toolName adds a
// case-insensitive substring match on tool_name. With neither filter the clause
// is "" and the argument slice is empty but non-nil. A returned clause starts
// with " WHERE " and joins its conditions with AND.
func toolExecutionsFilterSQL(status, toolName string) (string, []interface{}) {
	var clauses []string
	params := []interface{}{}

	if status != "" {
		clauses = append(clauses, "status = ?")
		params = append(params, status)
	}
	if toolName != "" {
		pattern := "%" + strings.ToLower(toolName) + "%"
		clauses = append(clauses, "LOWER(tool_name) LIKE ?")
		params = append(params, pattern)
	}

	if len(clauses) > 0 {
		return ` WHERE ` + strings.Join(clauses, ` AND `), params
	}
	return "", params
}
|
||||
|
||||
// ToolStatsSummary 工具调用汇总(全量聚合,不含逐工具明细)
|
||||
type ToolStatsSummary struct {
|
||||
TotalCalls int
|
||||
SuccessCalls int
|
||||
FailedCalls int
|
||||
LastCallTime *time.Time
|
||||
ToolCount int
|
||||
}
|
||||
|
||||
// ToolStatsSummaryResult 汇总 + Top N 工具排行
|
||||
type ToolStatsSummaryResult struct {
|
||||
Summary ToolStatsSummary
|
||||
TopTools []*mcp.ToolStats
|
||||
}
|
||||
|
||||
// LoadToolStatsSummary 聚合统计信息,仅返回汇总与 Top N 工具(避免全量 map 传输)
|
||||
func (db *DB) LoadToolStatsSummary(topN int) (*ToolStatsSummaryResult, error) {
|
||||
if topN <= 0 {
|
||||
topN = 6
|
||||
}
|
||||
if topN > 100 {
|
||||
topN = 100
|
||||
}
|
||||
|
||||
result := &ToolStatsSummaryResult{
|
||||
TopTools: make([]*mcp.ToolStats, 0, topN),
|
||||
}
|
||||
|
||||
summaryQuery := `
|
||||
SELECT COUNT(*),
|
||||
COALESCE(SUM(total_calls), 0),
|
||||
COALESCE(SUM(success_calls), 0),
|
||||
COALESCE(SUM(failed_calls), 0),
|
||||
MAX(last_call_time)
|
||||
FROM tool_stats
|
||||
`
|
||||
var lastCallRaw sql.NullString
|
||||
err := db.QueryRow(summaryQuery).Scan(
|
||||
&result.Summary.ToolCount,
|
||||
&result.Summary.TotalCalls,
|
||||
&result.Summary.SuccessCalls,
|
||||
&result.Summary.FailedCalls,
|
||||
&lastCallRaw,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastCallRaw.Valid && strings.TrimSpace(lastCallRaw.String) != "" {
|
||||
if t, parseErr := time.Parse(time.RFC3339Nano, lastCallRaw.String); parseErr == nil {
|
||||
result.Summary.LastCallTime = &t
|
||||
} else if t, parseErr := time.Parse("2006-01-02 15:04:05.999999999-07:00", lastCallRaw.String); parseErr == nil {
|
||||
result.Summary.LastCallTime = &t
|
||||
} else if t, parseErr := time.Parse("2006-01-02 15:04:05", lastCallRaw.String); parseErr == nil {
|
||||
result.Summary.LastCallTime = &t
|
||||
}
|
||||
}
|
||||
|
||||
topQuery := `
|
||||
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
|
||||
FROM tool_stats
|
||||
WHERE total_calls > 0
|
||||
ORDER BY total_calls DESC, tool_name ASC
|
||||
LIMIT ?
|
||||
`
|
||||
rows, err := db.Query(topQuery, topN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var stat mcp.ToolStats
|
||||
var lastCallTime sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&stat.ToolName,
|
||||
&stat.TotalCalls,
|
||||
&stat.SuccessCalls,
|
||||
&stat.FailedCalls,
|
||||
&lastCallTime,
|
||||
); err != nil {
|
||||
db.logger.Warn("加载 Top 工具统计失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
if lastCallTime.Valid {
|
||||
stat.LastCallTime = &lastCallTime.Time
|
||||
}
|
||||
result.TopTools = append(result.TopTools, &stat)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// LoadToolExecutionListPage 分页加载执行记录列表(不含 arguments/result,供监控列表使用)
|
||||
func (db *DB) LoadToolExecutionListPage(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, tool_name, status, start_time, end_time, duration_ms
|
||||
FROM tool_executions
|
||||
`
|
||||
whereSQL, args := toolExecutionsFilterSQL(status, toolName)
|
||||
query += whereSQL + ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
executions := make([]*mcp.ToolExecution, 0, limit)
|
||||
for rows.Next() {
|
||||
var exec mcp.ToolExecution
|
||||
var endTime sql.NullTime
|
||||
var durationMs sql.NullInt64
|
||||
|
||||
if err := rows.Scan(
|
||||
&exec.ID,
|
||||
&exec.ToolName,
|
||||
&exec.Status,
|
||||
&exec.StartTime,
|
||||
&endTime,
|
||||
&durationMs,
|
||||
); err != nil {
|
||||
db.logger.Warn("加载执行记录列表失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
if endTime.Valid {
|
||||
exec.EndTime = &endTime.Time
|
||||
}
|
||||
if durationMs.Valid {
|
||||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||||
}
|
||||
executions = append(executions, &exec)
|
||||
}
|
||||
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
// GetToolExecution 根据ID获取单条工具执行记录
|
||||
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
||||
query := `
|
||||
@@ -288,6 +448,93 @@ func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
||||
return &exec, nil
|
||||
}
|
||||
|
||||
// CancelOrphanedRunningToolExecutions 将仍为 running 的记录批量标记为 cancelled(如进程重启后无对应执行协程)。
|
||||
func (db *DB) CancelOrphanedRunningToolExecutions(endTime time.Time, errMsg string) (int64, error) {
|
||||
errMsg = strings.TrimSpace(errMsg)
|
||||
if errMsg == "" {
|
||||
errMsg = "执行已中断(服务重启或会话结束)"
|
||||
}
|
||||
query := `
|
||||
UPDATE tool_executions
|
||||
SET status = 'cancelled',
|
||||
error = ?,
|
||||
end_time = ?,
|
||||
duration_ms = MAX(0, CAST((julianday(?) - julianday(start_time)) * 86400000 AS INTEGER))
|
||||
WHERE status = 'running'
|
||||
`
|
||||
res, err := db.Exec(query, errMsg, endTime, endTime)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// FinalizeStaleRunningToolExecutions 将「非活跃且超过 minAge」的 running 记录标记为 cancelled。
|
||||
// activeIDs 为当前进程内仍登记 cancel 的 executionId;不在集合内且已超时的视为孤儿记录。
|
||||
func (db *DB) FinalizeStaleRunningToolExecutions(endTime time.Time, minAge time.Duration, activeIDs map[string]struct{}, errMsg string) (int64, error) {
|
||||
errMsg = strings.TrimSpace(errMsg)
|
||||
if errMsg == "" {
|
||||
errMsg = "执行已中断(会话已结束)"
|
||||
}
|
||||
if minAge < 0 {
|
||||
minAge = 0
|
||||
}
|
||||
cutoff := endTime.Add(-minAge)
|
||||
rows, err := db.Query(`
|
||||
SELECT id, start_time FROM tool_executions
|
||||
WHERE status = 'running' AND start_time <= ?
|
||||
`, cutoff)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type staleRow struct {
|
||||
id string
|
||||
startTime time.Time
|
||||
}
|
||||
var stale []staleRow
|
||||
for rows.Next() {
|
||||
var row staleRow
|
||||
if err := rows.Scan(&row.id, &row.startTime); err != nil {
|
||||
db.logger.Warn("读取 stale running 执行记录失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
if activeIDs != nil {
|
||||
if _, active := activeIDs[row.id]; active {
|
||||
continue
|
||||
}
|
||||
}
|
||||
stale = append(stale, row)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(stale) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var affected int64
|
||||
for _, row := range stale {
|
||||
durationMs := endTime.Sub(row.startTime).Milliseconds()
|
||||
if durationMs < 0 {
|
||||
durationMs = 0
|
||||
}
|
||||
res, err := db.Exec(`
|
||||
UPDATE tool_executions
|
||||
SET status = 'cancelled', error = ?, end_time = ?, duration_ms = ?
|
||||
WHERE id = ? AND status = 'running'
|
||||
`, errMsg, endTime, durationMs, row.id)
|
||||
if err != nil {
|
||||
db.logger.Warn("更新 stale running 执行记录失败", zap.Error(err), zap.String("executionId", row.id))
|
||||
continue
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
affected += n
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
// DeleteToolExecution 删除工具执行记录
|
||||
func (db *DB) DeleteToolExecution(id string) error {
|
||||
query := `DELETE FROM tool_executions WHERE id = ?`
|
||||
@@ -410,6 +657,76 @@ func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error)
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
type toolExecutionStatDelta struct {
|
||||
totalCalls int
|
||||
successCalls int
|
||||
failedCalls int
|
||||
}
|
||||
|
||||
// PurgeToolExecutionsBefore deletes executions older than cutoff and adjusts tool_stats.
|
||||
func (db *DB) PurgeToolExecutionsBefore(cutoff time.Time) (int64, error) {
|
||||
query := `
|
||||
SELECT tool_name, status, COUNT(*) AS cnt
|
||||
FROM tool_executions
|
||||
WHERE ` + sqliteEpochGE("start_time", "<") + `
|
||||
GROUP BY tool_name, status
|
||||
`
|
||||
rows, err := db.Query(query, formatSQLiteUTC(cutoff))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
deltas := make(map[string]*toolExecutionStatDelta)
|
||||
for rows.Next() {
|
||||
var toolName, status string
|
||||
var count int
|
||||
if err := rows.Scan(&toolName, &status, &count); err != nil {
|
||||
db.logger.Warn("读取待清理执行记录统计失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
toolName = strings.TrimSpace(toolName)
|
||||
if toolName == "" || count <= 0 {
|
||||
continue
|
||||
}
|
||||
delta := deltas[toolName]
|
||||
if delta == nil {
|
||||
delta = &toolExecutionStatDelta{}
|
||||
deltas[toolName] = delta
|
||||
}
|
||||
delta.totalCalls += count
|
||||
switch status {
|
||||
case "failed", "cancelled":
|
||||
delta.failedCalls += count
|
||||
case "completed":
|
||||
delta.successCalls += count
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
res, err := db.Exec(`DELETE FROM tool_executions WHERE `+sqliteEpochGE("start_time", "<"), formatSQLiteUTC(cutoff))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
deleted, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for toolName, delta := range deltas {
|
||||
if err := db.DecreaseToolStats(toolName, delta.totalCalls, delta.successCalls, delta.failedCalls); err != nil {
|
||||
db.logger.Warn("清理过期执行记录后更新统计失败",
|
||||
zap.Error(err),
|
||||
zap.String("toolName", toolName),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// SaveToolStats 保存工具统计信息
|
||||
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
||||
var lastCallTime sql.NullTime
|
||||
@@ -530,13 +847,28 @@ func truncateCallsTimelineBucket(t time.Time, dailyBuckets bool) time.Time {
|
||||
|
||||
// LoadCallsTimeline 按时间范围加载调用趋势(since 起至今,含边界)
|
||||
func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTimelineBucket, error) {
|
||||
// 在 Go 侧按本地时区分桶,避免 SQLite strftime 对 UTC 存储时间分桶后再误当本地时间解析(差 8h 等问题)
|
||||
query := `
|
||||
SELECT start_time,
|
||||
CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END AS failed
|
||||
FROM tool_executions
|
||||
WHERE start_time >= ?
|
||||
`
|
||||
var query string
|
||||
if dailyBuckets {
|
||||
query = `
|
||||
SELECT date(start_time, 'localtime') AS bucket,
|
||||
COUNT(*) AS total,
|
||||
SUM(CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END) AS failed
|
||||
FROM tool_executions
|
||||
WHERE start_time >= ?
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket
|
||||
`
|
||||
} else {
|
||||
query = `
|
||||
SELECT strftime('%Y-%m-%d %H:00:00', start_time, 'localtime') AS bucket,
|
||||
COUNT(*) AS total,
|
||||
SUM(CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END) AS failed
|
||||
FROM tool_executions
|
||||
WHERE start_time >= ?
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket
|
||||
`
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, since)
|
||||
if err != nil {
|
||||
@@ -544,35 +876,35 @@ func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTime
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
bucketMap := make(map[time.Time]struct{ total, failed int })
|
||||
buckets := make([]CallsTimelineBucket, 0)
|
||||
for rows.Next() {
|
||||
var startTime time.Time
|
||||
var failed int
|
||||
if err := rows.Scan(&startTime, &failed); err != nil {
|
||||
var bucketStr string
|
||||
var total, failed int
|
||||
if err := rows.Scan(&bucketStr, &total, &failed); err != nil {
|
||||
db.logger.Warn("加载调用趋势失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
key := truncateCallsTimelineBucket(startTime, dailyBuckets)
|
||||
entry := bucketMap[key]
|
||||
entry.total++
|
||||
entry.failed += failed
|
||||
bucketMap[key] = entry
|
||||
}
|
||||
|
||||
buckets := make([]CallsTimelineBucket, 0, len(bucketMap))
|
||||
for bucketTime, counts := range bucketMap {
|
||||
bucketTime, err := parseCallsTimelineBucket(bucketStr, dailyBuckets)
|
||||
if err != nil {
|
||||
db.logger.Warn("解析调用趋势时间桶失败", zap.Error(err), zap.String("bucket", bucketStr))
|
||||
continue
|
||||
}
|
||||
buckets = append(buckets, CallsTimelineBucket{
|
||||
BucketTime: bucketTime,
|
||||
Total: counts.total,
|
||||
Failed: counts.failed,
|
||||
Total: total,
|
||||
Failed: failed,
|
||||
})
|
||||
}
|
||||
sort.Slice(buckets, func(i, j int) bool {
|
||||
return buckets[i].BucketTime.Before(buckets[j].BucketTime)
|
||||
})
|
||||
return buckets, nil
|
||||
}
|
||||
|
||||
func parseCallsTimelineBucket(bucketStr string, dailyBuckets bool) (time.Time, error) {
|
||||
if dailyBuckets {
|
||||
return time.ParseInLocation("2006-01-02", bucketStr, time.Local)
|
||||
}
|
||||
return time.ParseInLocation("2006-01-02 15:04:05", bucketStr, time.Local)
|
||||
}
|
||||
|
||||
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
||||
// 如果统计信息变为0,则删除该统计记录
|
||||
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestCancelOrphanedRunningToolExecutions(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
start := time.Now().Add(-2 * time.Hour)
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "orphan-hydra",
|
||||
ToolName: "hydra",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "running",
|
||||
StartTime: start,
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
end := time.Now()
|
||||
n, err := db.CancelOrphanedRunningToolExecutions(end, "执行已中断(服务重启)")
|
||||
if err != nil {
|
||||
t.Fatalf("CancelOrphanedRunningToolExecutions: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("expected 1 row updated, got %d", n)
|
||||
}
|
||||
|
||||
got, err := db.GetToolExecution("orphan-hydra")
|
||||
if err != nil {
|
||||
t.Fatalf("GetToolExecution: %v", err)
|
||||
}
|
||||
if got.Status != "cancelled" {
|
||||
t.Fatalf("expected cancelled, got %s", got.Status)
|
||||
}
|
||||
if got.EndTime == nil {
|
||||
t.Fatal("expected end_time to be set")
|
||||
}
|
||||
if got.Duration <= 0 {
|
||||
t.Fatalf("expected positive duration, got %v", got.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeStaleRunningToolExecutions_skipsActive(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
now := time.Now()
|
||||
oldStart := now.Add(-5 * time.Minute)
|
||||
if err := db.SaveToolExecution(&mcp.ToolExecution{
|
||||
ID: "stale", ToolName: "hydra", Status: "running", StartTime: oldStart,
|
||||
}); err != nil {
|
||||
t.Fatalf("SaveToolExecution stale: %v", err)
|
||||
}
|
||||
if err := db.SaveToolExecution(&mcp.ToolExecution{
|
||||
ID: "active", ToolName: "hydra", Status: "running", StartTime: oldStart,
|
||||
}); err != nil {
|
||||
t.Fatalf("SaveToolExecution active: %v", err)
|
||||
}
|
||||
|
||||
active := map[string]struct{}{"active": {}}
|
||||
n, err := db.FinalizeStaleRunningToolExecutions(now, time.Minute, active, "执行已中断(会话已结束)")
|
||||
if err != nil {
|
||||
t.Fatalf("FinalizeStaleRunningToolExecutions: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("expected 1 stale row updated, got %d", n)
|
||||
}
|
||||
|
||||
stale, err := db.GetToolExecution("stale")
|
||||
if err != nil {
|
||||
t.Fatalf("GetToolExecution stale: %v", err)
|
||||
}
|
||||
if stale.Status != "cancelled" {
|
||||
t.Fatalf("stale expected cancelled, got %s", stale.Status)
|
||||
}
|
||||
|
||||
activeExec, err := db.GetToolExecution("active")
|
||||
if err != nil {
|
||||
t.Fatalf("GetToolExecution active: %v", err)
|
||||
}
|
||||
if activeExec.Status != "running" {
|
||||
t.Fatalf("active expected running, got %s", activeExec.Status)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestPurgeToolExecutionsBefore(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
oldStart := time.Now().AddDate(0, 0, -100)
|
||||
newStart := time.Now().AddDate(0, 0, -1)
|
||||
|
||||
oldExec := &mcp.ToolExecution{
|
||||
ID: "old-completed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "completed",
|
||||
StartTime: oldStart,
|
||||
}
|
||||
oldFailed := &mcp.ToolExecution{
|
||||
ID: "old-failed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "failed",
|
||||
Error: "timeout",
|
||||
StartTime: oldStart,
|
||||
}
|
||||
newExec := &mcp.ToolExecution{
|
||||
ID: "new-completed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "completed",
|
||||
StartTime: newStart,
|
||||
}
|
||||
for _, exec := range []*mcp.ToolExecution{oldExec, oldFailed, newExec} {
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution(%s): %v", exec.ID, err)
|
||||
}
|
||||
}
|
||||
if err := db.UpdateToolStats("nmap::scan", 3, 2, 1, &newStart); err != nil {
|
||||
t.Fatalf("UpdateToolStats: %v", err)
|
||||
}
|
||||
|
||||
cutoff := time.Now().AddDate(0, 0, -90)
|
||||
deleted, err := db.PurgeToolExecutionsBefore(cutoff)
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||
}
|
||||
if deleted != 2 {
|
||||
t.Fatalf("deleted = %d, want 2", deleted)
|
||||
}
|
||||
|
||||
if _, err := db.GetToolExecution("old-completed"); err == nil {
|
||||
t.Fatal("old-completed should be deleted")
|
||||
}
|
||||
if _, err := db.GetToolExecution("old-failed"); err == nil {
|
||||
t.Fatal("old-failed should be deleted")
|
||||
}
|
||||
if _, err := db.GetToolExecution("new-completed"); err != nil {
|
||||
t.Fatalf("new-completed should remain: %v", err)
|
||||
}
|
||||
|
||||
stats, err := db.LoadToolStats()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadToolStats: %v", err)
|
||||
}
|
||||
stat := stats["nmap::scan"]
|
||||
if stat == nil {
|
||||
t.Fatal("expected stats for nmap::scan")
|
||||
}
|
||||
if stat.TotalCalls != 1 || stat.SuccessCalls != 1 || stat.FailedCalls != 0 {
|
||||
t.Fatalf("stats after purge = %+v, want total=1 success=1 failed=0", stat)
|
||||
}
|
||||
|
||||
total, err := db.CountToolExecutions("", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CountToolExecutions: %v", err)
|
||||
}
|
||||
if total != 1 {
|
||||
t.Fatalf("remaining executions = %d, want 1", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPurgeToolExecutionsBefore_zeroRetentionSkipsViaService(t *testing.T) {
|
||||
// RetentionDaysEffective: 0 means no purge at service layer; DB method still works when called directly.
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: time.Now().AddDate(-1, 0, 0),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
deleted, err := db.PurgeToolExecutionsBefore(time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("deleted = %d, want 1", deleted)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestLoadToolStatsSummaryAndListPage(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor-summary.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
now := time.Now()
|
||||
tools := []struct {
|
||||
name string
|
||||
calls int
|
||||
ok int
|
||||
fail int
|
||||
result string
|
||||
}{
|
||||
{"alpha::run", 10, 9, 1, `{"content":[{"type":"text","text":"` + string(make([]byte, 64*1024)) + `"}]}`},
|
||||
{"beta::scan", 5, 5, 0, `{"content":[{"type":"text","text":"ok"}]}`},
|
||||
{"gamma::ping", 1, 1, 0, `{"content":[{"type":"text","text":"pong"}]}`},
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if err := db.UpdateToolStats(tool.name, tool.calls, tool.ok, tool.fail, &now); err != nil {
|
||||
t.Fatalf("UpdateToolStats(%s): %v", tool.name, err)
|
||||
}
|
||||
for j := 0; j < tool.calls; j++ {
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: fmt.Sprintf("%s-exec-%d", tool.name, j),
|
||||
ToolName: tool.name,
|
||||
Arguments: map[string]interface{}{"n": j},
|
||||
Status: "completed",
|
||||
StartTime: now.Add(-time.Duration(j) * time.Minute),
|
||||
Result: &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: tool.result}}},
|
||||
}
|
||||
end := exec.StartTime.Add(time.Second)
|
||||
exec.EndTime = &end
|
||||
exec.Duration = time.Second
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
summary, err := db.LoadToolStatsSummary(2)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadToolStatsSummary: %v", err)
|
||||
}
|
||||
if summary.Summary.ToolCount != 3 {
|
||||
t.Fatalf("toolCount = %d, want 3", summary.Summary.ToolCount)
|
||||
}
|
||||
if summary.Summary.TotalCalls != 16 {
|
||||
t.Fatalf("totalCalls = %d, want 16", summary.Summary.TotalCalls)
|
||||
}
|
||||
if len(summary.TopTools) != 2 {
|
||||
t.Fatalf("top tools = %d, want 2", len(summary.TopTools))
|
||||
}
|
||||
if summary.TopTools[0].ToolName != "alpha::run" {
|
||||
t.Fatalf("top tool = %q, want alpha::run", summary.TopTools[0].ToolName)
|
||||
}
|
||||
|
||||
list, err := db.LoadToolExecutionListPage(0, 5, "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadToolExecutionListPage: %v", err)
|
||||
}
|
||||
if len(list) != 5 {
|
||||
t.Fatalf("list len = %d, want 5", len(list))
|
||||
}
|
||||
for _, exec := range list {
|
||||
if exec.Arguments != nil || exec.Result != nil || exec.Error != "" {
|
||||
t.Fatalf("expected lite execution row, got args/result/error on %s", exec.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,6 +195,7 @@ func (db *DB) DeleteProject(id string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除项目失败: %w", err)
|
||||
}
|
||||
db.removeProjectScopedDirs(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ package einomcp
|
||||
|
||||
import "sync"
|
||||
|
||||
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP/execute 桥在工具调用结束时 Fire,
|
||||
// 用于清除 pending tool_call(tool_result 由 ADK schema.Tool 事件推送,含流式工具与 reduction 后正文)。
|
||||
// ToolInvokeNotifyHolder 由 Eino run loop 与 MCP/execute 桥共享;Fire 在工具原始返回时触发。
|
||||
// UI 的 tool_result 须等 ADK schema.Tool 事件(reduction 后正文),不在此 holder 的回调里推送。
|
||||
type ToolInvokeNotifyHolder struct {
|
||||
mu sync.RWMutex
|
||||
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||
|
||||
+130
-443
@@ -21,7 +21,6 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
@@ -178,8 +177,6 @@ type AgentHandler struct {
|
||||
}
|
||||
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
||||
batchCronParser cron.Parser
|
||||
batchRunnerMu sync.Mutex
|
||||
batchRunning map[string]struct{}
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
audit *audit.Service
|
||||
@@ -190,6 +187,37 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// TaskManager 返回 Agent 任务管理器(供 MCP 监控页终止 Eino execute 等)。
|
||||
func (h *AgentHandler) TaskManager() *AgentTaskManager {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
return h.tasks
|
||||
}
|
||||
|
||||
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
||||
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
||||
if h == nil || conversationID == "" || h.tasks == nil {
|
||||
return
|
||||
}
|
||||
h.cancelActiveMCPToolForConversation(conversationID)
|
||||
h.tasks.AbortActiveEinoExecute(conversationID, "")
|
||||
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
|
||||
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
|
||||
} else if err != nil {
|
||||
h.logger.Warn("取消会话运行中任务失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) cancelActiveMCPToolForConversation(conversationID string) {
|
||||
if h == nil || h.tasks == nil || h.agent == nil {
|
||||
return
|
||||
}
|
||||
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
|
||||
h.agent.CancelMCPToolExecutionWithNote(execID, "")
|
||||
}
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
type HitlToolWhitelistSaver interface {
|
||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||
@@ -218,8 +246,8 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
config: cfg,
|
||||
hitlManager: NewHITLManager(db, logger),
|
||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||
batchRunning: make(map[string]struct{}),
|
||||
}
|
||||
tm.SetToolCanceler(handler.cancelActiveMCPToolForConversation)
|
||||
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
||||
logger.Warn("初始化 HITL 表失败", zap.Error(err))
|
||||
}
|
||||
@@ -631,40 +659,11 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
|
||||
assistantMessageID string,
|
||||
taskStatus *string,
|
||||
) (string, string, error) {
|
||||
curHist := history
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
errMA = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if errMA == nil {
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.agentSessionContextBlock(conversationID),
|
||||
)
|
||||
if errMA != nil {
|
||||
*taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
@@ -680,41 +679,12 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
|
||||
assistantMessageID string,
|
||||
taskStatus *string,
|
||||
) (string, string, error) {
|
||||
curHist := history
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
errMA = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if errMA == nil {
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
resultMA, errMA := multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, orchestration, nil, h.agentSessionContextBlock(conversationID),
|
||||
)
|
||||
if errMA != nil {
|
||||
*taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
@@ -1338,10 +1308,60 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
}
|
||||
|
||||
// cancelToolContinueAfter 仅终止当前工具调用,不停止整条 Agent 任务(对话「中断并继续」与 MCP 监控终止共用)。
|
||||
func (h *AgentHandler) cancelToolContinueAfter(conversationID, preferredExecID, note string) (bool, gin.H) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || h.tasks.GetTask(conversationID) == nil {
|
||||
return false, nil
|
||||
}
|
||||
note = strings.TrimSpace(note)
|
||||
execID := strings.TrimSpace(preferredExecID)
|
||||
if execID == "" {
|
||||
execID = h.tasks.ActiveMCPExecutionID(conversationID)
|
||||
}
|
||||
if execID != "" {
|
||||
if h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||
return true, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": conversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
}
|
||||
}
|
||||
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
|
||||
return true, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": conversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
|
||||
return true, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": conversationID,
|
||||
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// CancelAgentLoop 取消正在执行的任务
|
||||
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
var req struct {
|
||||
ConversationID string `json:"conversationId" binding:"required"`
|
||||
ExecutionID string `json:"executionId,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
ContinueAfter bool `json:"continueAfter,omitempty"`
|
||||
}
|
||||
@@ -1356,27 +1376,20 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||
return
|
||||
}
|
||||
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
||||
note := strings.TrimSpace(req.Reason)
|
||||
if execID != "" {
|
||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
||||
activeExec := strings.TrimSpace(h.tasks.ActiveMCPExecutionID(req.ConversationID))
|
||||
if ok, payload := h.cancelToolContinueAfter(req.ConversationID, strings.TrimSpace(req.ExecutionID), note); ok {
|
||||
execID, _ := payload["executionId"].(string)
|
||||
h.logger.Info("对话页仅终止当前工具",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.String("executionId", execID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
})
|
||||
c.JSON(http.StatusOK, payload)
|
||||
return
|
||||
}
|
||||
if activeExec != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||
return
|
||||
}
|
||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||
@@ -1408,6 +1421,8 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
|
||||
var cause error = ErrTaskCancelled
|
||||
msg := "已提交取消请求,任务将在当前步骤完成后停止。"
|
||||
h.cancelActiveMCPToolForConversation(req.ConversationID)
|
||||
h.tasks.AbortActiveEinoExecute(req.ConversationID, "")
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, cause)
|
||||
if err != nil {
|
||||
h.logger.Error("取消任务失败", zap.Error(err))
|
||||
@@ -1498,6 +1513,7 @@ type BatchTaskRequest struct {
|
||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
||||
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
||||
Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8
|
||||
}
|
||||
|
||||
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。
|
||||
@@ -1557,7 +1573,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
nextRunAt = &next
|
||||
}
|
||||
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks)
|
||||
if createErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
||||
return
|
||||
@@ -1747,15 +1763,16 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
||||
func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
Role string `json:"role"`
|
||||
AgentMode string `json:"agentMode"`
|
||||
Title string `json:"title"`
|
||||
Role string `json:"role"`
|
||||
AgentMode string `json:"agentMode"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil {
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1830,9 +1847,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
|
||||
// DeleteBatchQueue 删除批量任务队列
|
||||
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
success := h.batchTaskManager.DeleteQueue(queueID)
|
||||
if !success {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
if err := h.batchTaskManager.DeleteQueue(queueID); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrBatchQueueNotFound):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"})
|
||||
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
@@ -1926,7 +1951,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
|
||||
|
||||
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
||||
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
||||
h.forceUnmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.ForceUnmarkQueueExecutor(queueID)
|
||||
}
|
||||
|
||||
autoStarted := true
|
||||
@@ -1985,26 +2010,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
|
||||
h.batchRunnerMu.Lock()
|
||||
defer h.batchRunnerMu.Unlock()
|
||||
if _, exists := h.batchRunning[queueID]; exists {
|
||||
return false
|
||||
}
|
||||
h.batchRunning[queueID] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
|
||||
h.batchRunnerMu.Lock()
|
||||
defer h.batchRunnerMu.Unlock()
|
||||
delete(h.batchRunning, queueID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
||||
expr := strings.TrimSpace(cronExpr)
|
||||
if expr == "" {
|
||||
@@ -2020,43 +2025,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
|
||||
|
||||
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
||||
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
|
||||
if !h.markBatchQueueRunning(queueID) {
|
||||
if !h.batchTaskManager.TryMarkQueueExecutor(queueID) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if scheduled {
|
||||
if queue.ScheduleMode != "cron" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("队列未启用 cron 调度")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("当前队列状态不允许被调度执行")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
if !h.batchTaskManager.ResetQueueForRerun(queueID) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("重置队列失败")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
} else if queue.Status != "pending" && queue.Status != "paused" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
return true, fmt.Errorf("队列状态不允许启动")
|
||||
}
|
||||
|
||||
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
|
||||
if scheduled {
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
@@ -2108,324 +2113,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// executeBatchQueue 执行批量任务队列
|
||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
defer h.unmarkBatchQueueRunning(queueID)
|
||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
|
||||
|
||||
for {
|
||||
// 检查队列状态
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
|
||||
break
|
||||
}
|
||||
|
||||
// 获取下一个任务
|
||||
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
|
||||
if !hasNext {
|
||||
// 所有任务完成:汇总子任务失败信息便于排障
|
||||
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
lastRunErr := ""
|
||||
if ok {
|
||||
for _, t := range q.Tasks {
|
||||
if t.Status == "failed" && t.Error != "" {
|
||||
lastRunErr = t.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
|
||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||
break
|
||||
}
|
||||
|
||||
// 更新任务状态为运行中
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
|
||||
|
||||
// 创建新对话
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||
var conversationID string
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
conversationID = conv.ID
|
||||
|
||||
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
|
||||
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := task.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
if queue.Role != "" && queue.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
||||
// 应用用户提示词
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
||||
}
|
||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
||||
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 预先创建助手消息,以便关联过程详情
|
||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
// 如果创建失败,继续执行但不保存过程详情
|
||||
assistantMsg = nil
|
||||
}
|
||||
|
||||
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
|
||||
var assistantMessageID string
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
|
||||
// 需要把进度事件镜像到 TaskEventBus(GET /api/agent-loop/task-events 会订阅这里)。
|
||||
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
|
||||
var progressCallback func(eventType, message string, data interface{})
|
||||
|
||||
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||
|
||||
func() {
|
||||
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
||||
|
||||
registered := false
|
||||
finishStatus := "completed"
|
||||
|
||||
defer func() {
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
timeoutCancel()
|
||||
if registered {
|
||||
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
|
||||
if h.taskEventBus != nil {
|
||||
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
||||
if b, err := json.Marshal(ev); err == nil {
|
||||
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
||||
}
|
||||
}
|
||||
h.tasks.FinishTask(conversationID, finishStatus)
|
||||
}
|
||||
cancelWithCause(nil)
|
||||
}()
|
||||
|
||||
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if h.taskEventBus == nil {
|
||||
return
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
line := make([]byte, 0, len(b)+8)
|
||||
line = append(line, []byte("data: ")...)
|
||||
line = append(line, b...)
|
||||
line = append(line, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, line)
|
||||
}
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
||||
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
||||
zap.String("queueId", queueID),
|
||||
zap.String("taskId", task.ID),
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err))
|
||||
failMsg := err.Error()
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
|
||||
return
|
||||
}
|
||||
registered = true
|
||||
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
|
||||
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
|
||||
|
||||
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
|
||||
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
|
||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||
useBatchMulti := false
|
||||
batchOrch := "deep"
|
||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||
if am == "multi" {
|
||||
am = "deep"
|
||||
}
|
||||
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||
useBatchMulti = true
|
||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
|
||||
useBatchMulti = true
|
||||
batchOrch = "deep"
|
||||
}
|
||||
var resultMA *multiagent.RunResult
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||
default:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errStr := runErr.Error()
|
||||
partialResp := ""
|
||||
if resultMA != nil {
|
||||
partialResp = resultMA.Response
|
||||
}
|
||||
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
||||
errors.Is(runErr, context.Canceled) ||
|
||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
||||
|
||||
if isTimeout {
|
||||
finishStatus = "timeout"
|
||||
} else if isCancelled {
|
||||
finishStatus = "cancelled"
|
||||
} else {
|
||||
finishStatus = "failed"
|
||||
}
|
||||
|
||||
if isCancelled {
|
||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
// 如果执行结果中有更具体的取消消息,使用它
|
||||
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
||||
cancelMsg = partialResp
|
||||
}
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
// 保存取消详情到数据库
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
||||
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
// 如果没有预先创建的助手消息,创建一个新的
|
||||
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
|
||||
if errMsg != nil {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
||||
} else {
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
||||
errorMsg := "执行失败: " + runErr.Error()
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
// 保存错误详情到数据库
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
||||
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error())
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
|
||||
resText := resultMA.Response
|
||||
mcpIDs := resultMA.MCPExecutionIDs
|
||||
lastIn := resultMA.LastAgentTraceInput
|
||||
lastOut := resultMA.LastAgentTraceOutput
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
// 如果更新失败,尝试创建新消息
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果没有预先创建的助手消息,创建一个新的
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存代理轨迹
|
||||
if lastIn != "" || lastOut != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存结果
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
|
||||
}
|
||||
}()
|
||||
|
||||
// 移动到下一个任务
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
|
||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||
break
|
||||
}
|
||||
|
||||
// 检查是否被取消或暂停
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if queue.Status == "cancelled" || queue.Status == "paused" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
||||
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
||||
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// batchQueueWorkerIdlePoll is how long a worker sleeps before looking at the
// queue again when it could not claim a task while other tasks are still
// running.
const batchQueueWorkerIdlePoll = 200 * time.Millisecond
|
||||
|
||||
// executeBatchQueue 使用并发 worker 池执行批量任务队列。
|
||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
defer h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
concurrency := normalizeBatchQueueConcurrency(queue.Concurrency)
|
||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", concurrency))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
h.runBatchQueueWorker(queueID)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
h.tryFinalizeBatchQueue(queueID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) runBatchQueueWorker(queueID string) {
|
||||
for {
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if batchQueueExecutionShouldStop(queue, exists) {
|
||||
return
|
||||
}
|
||||
|
||||
task, ok := h.batchTaskManager.ClaimNextPendingTask(queueID)
|
||||
if !ok {
|
||||
if !h.batchTaskManager.HasRunningTasks(queueID) {
|
||||
return
|
||||
}
|
||||
time.Sleep(batchQueueWorkerIdlePoll)
|
||||
continue
|
||||
}
|
||||
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if queue == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusRunning, "", "")
|
||||
h.executeOneBatchSubTask(queueID, queue, task)
|
||||
|
||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused)
|
||||
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||
return
|
||||
}
|
||||
|
||||
queue, exists = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if batchQueueExecutionShouldStop(queue, exists) {
|
||||
if !exists {
|
||||
h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) {
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists || queue == nil {
|
||||
return
|
||||
}
|
||||
if queue.Status != BatchQueueStatusRunning {
|
||||
return
|
||||
}
|
||||
if h.batchTaskManager.HasPendingOrRunningTasks(queueID) {
|
||||
return
|
||||
}
|
||||
|
||||
lastRunErr := ""
|
||||
for _, t := range queue.Tasks {
|
||||
if t != nil && t.Status == BatchTaskStatusFailed && t.Error != "" {
|
||||
lastRunErr = t.Error
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted)
|
||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||
}
|
||||
|
||||
// executeOneBatchSubTask 执行单条批量子任务(各自独立会话)。
|
||||
func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) {
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
conversationID := conv.ID
|
||||
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID)
|
||||
|
||||
finalMessage := task.Message
|
||||
var roleTools []string
|
||||
if queue.Role != "" && queue.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
||||
}
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
assistantMsg = nil
|
||||
}
|
||||
|
||||
var assistantMessageID string
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
||||
|
||||
registered := false
|
||||
finishStatus := "completed"
|
||||
|
||||
defer func() {
|
||||
h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil)
|
||||
timeoutCancel()
|
||||
if registered {
|
||||
if h.taskEventBus != nil {
|
||||
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
||||
if b, err := json.Marshal(ev); err == nil {
|
||||
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
||||
}
|
||||
}
|
||||
h.tasks.FinishTask(conversationID, finishStatus)
|
||||
}
|
||||
cancelWithCause(nil)
|
||||
}()
|
||||
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if h.taskEventBus == nil {
|
||||
return
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
line := make([]byte, 0, len(b)+8)
|
||||
line = append(line, []byte("data: ")...)
|
||||
line = append(line, b...)
|
||||
line = append(line, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, line)
|
||||
}
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
||||
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
||||
zap.String("queueId", queueID),
|
||||
zap.String("taskId", task.ID),
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err))
|
||||
failMsg := err.Error()
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg)
|
||||
return
|
||||
}
|
||||
registered = true
|
||||
h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel)
|
||||
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
||||
|
||||
useBatchMulti := false
|
||||
batchOrch := "deep"
|
||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||
if am == "multi" {
|
||||
am = "deep"
|
||||
}
|
||||
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||
useBatchMulti = true
|
||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||
useBatchMulti = true
|
||||
batchOrch = "deep"
|
||||
}
|
||||
|
||||
var resultMA *multiagent.RunResult
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.agentSessionContextBlock(conversationID))
|
||||
default:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.agentSessionContextBlock(conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus)
|
||||
return
|
||||
}
|
||||
|
||||
if resultMA == nil {
|
||||
h.logger.Error("批量任务执行成功但无结果对象",
|
||||
zap.String("queueId", queueID),
|
||||
zap.String("taskId", task.ID),
|
||||
zap.String("conversationId", conversationID))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果")
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
|
||||
resText := resultMA.Response
|
||||
mcpIDs := resultMA.MCPExecutionIDs
|
||||
lastIn := resultMA.LastAgentTraceInput
|
||||
lastOut := resultMA.LastAgentTraceOutput
|
||||
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
} else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
if lastIn != "" || lastOut != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) handleBatchSubTaskRunError(
|
||||
queueID string,
|
||||
task *BatchTask,
|
||||
conversationID, assistantMessageID string,
|
||||
baseCtx, taskCtx context.Context,
|
||||
resultMA *multiagent.RunResult,
|
||||
runErr error,
|
||||
finishStatus *string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errStr := runErr.Error()
|
||||
partialResp := ""
|
||||
if resultMA != nil {
|
||||
partialResp = resultMA.Response
|
||||
}
|
||||
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
||||
errors.Is(runErr, context.Canceled) ||
|
||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
||||
|
||||
if isTimeout {
|
||||
*finishStatus = "timeout"
|
||||
} else if isCancelled {
|
||||
*finishStatus = "cancelled"
|
||||
} else {
|
||||
*finishStatus = "failed"
|
||||
}
|
||||
|
||||
if isCancelled {
|
||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
||||
cancelMsg = partialResp
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
||||
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
} else if _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); errMsg != nil {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
||||
errorMsg := "执行失败: " + runErr.Error()
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
||||
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error())
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -17,6 +18,15 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Sentinel errors for batch queue operations.
var (
	// ErrBatchQueueNotFound: the queue does not exist or has been unloaded
	// from memory.
	ErrBatchQueueNotFound = errors.New("batch queue not found")

	// ErrBatchQueueExecutorActive: the executeBatchQueue goroutine is still
	// winding down, so the queue must not be deleted.
	ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active")

	// ErrBatchQueueStillRunning: the queue status is still "running"
	// (fallback protection for the case without an active executor).
	ErrBatchQueueStillRunning = errors.New("batch queue is still running")
)
|
||||
|
||||
// 批量任务状态常量
|
||||
const (
|
||||
BatchQueueStatusPending = "pending"
|
||||
@@ -39,6 +49,12 @@ const (
|
||||
|
||||
// MaxBatchQueueRoleLen 角色名最大长度
|
||||
MaxBatchQueueRoleLen = 100
|
||||
|
||||
// DefaultBatchQueueConcurrency 批量队列默认并发数(串行)
|
||||
DefaultBatchQueueConcurrency = 1
|
||||
|
||||
// MaxBatchQueueConcurrency 批量队列最大并发数
|
||||
MaxBatchQueueConcurrency = 8
|
||||
)
|
||||
|
||||
// BatchTask 批量任务项
|
||||
@@ -67,6 +83,7 @@ type BatchTaskQueue struct {
|
||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||
LastRunError string `json:"lastRunError,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
@@ -80,8 +97,9 @@ type BatchTaskManager struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
queues map[string]*BatchTaskQueue
|
||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||
taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数
|
||||
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
||||
queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦)
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
|
||||
return &BatchTaskManager{
|
||||
logger: logger,
|
||||
queues: make(map[string]*BatchTaskQueue),
|
||||
taskCancels: make(map[string]context.CancelFunc),
|
||||
taskCancels: make(map[string]map[string]context.CancelFunc),
|
||||
singleRunTasks: make(map[string]string),
|
||||
queueExecutors: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// batchQueueExecutionShouldStop 判断 executeBatchQueue 主循环是否应退出。
|
||||
func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool {
|
||||
if !exists || queue == nil {
|
||||
return true
|
||||
}
|
||||
switch queue.Status {
|
||||
case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// TryMarkQueueExecutor 标记队列执行协程已启动;若已有执行协程则返回 false。
|
||||
func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if _, exists := m.queueExecutors[queueID]; exists {
|
||||
return false
|
||||
}
|
||||
m.queueExecutors[queueID] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
// UnmarkQueueExecutor 清除队列执行协程标记(executeBatchQueue defer 调用)。
|
||||
func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.queueExecutors, queueID)
|
||||
}
|
||||
|
||||
// ForceUnmarkQueueExecutor 强制清除执行协程标记(暂停态单条重跑等场景回收陈旧槽位)。
|
||||
func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) {
|
||||
m.UnmarkQueueExecutor(queueID)
|
||||
}
|
||||
|
||||
// IsQueueExecutorActive 队列 executeBatchQueue 协程是否仍在运行。
|
||||
func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.queueExecutors[queueID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// SetDB 设置数据库连接
|
||||
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
m.mu.Lock()
|
||||
@@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
m.db = db
|
||||
}
|
||||
|
||||
// normalizeBatchQueueConcurrency 规范化队列并发数。
|
||||
func normalizeBatchQueueConcurrency(n int) int {
|
||||
if n < 1 {
|
||||
return DefaultBatchQueueConcurrency
|
||||
}
|
||||
if n > MaxBatchQueueConcurrency {
|
||||
return MaxBatchQueueConcurrency
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(
|
||||
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||
nextRunAt *time.Time,
|
||||
concurrency int,
|
||||
tasks []string,
|
||||
) (*BatchTaskQueue, error) {
|
||||
// 输入校验
|
||||
@@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
CronExpr: strings.TrimSpace(cronExpr),
|
||||
NextRunAt: nextRunAt,
|
||||
ScheduleEnabled: true,
|
||||
Concurrency: normalizeBatchQueueConcurrency(concurrency),
|
||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||
Status: BatchQueueStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
queue.CronExpr,
|
||||
queue.NextRunAt,
|
||||
queue.ProjectID,
|
||||
queue.Concurrency,
|
||||
dbTasks,
|
||||
); err != nil {
|
||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
@@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用)
|
||||
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error {
|
||||
// batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。
|
||||
func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int {
|
||||
if row == nil || !row.Concurrency.Valid {
|
||||
return DefaultBatchQueueConcurrency
|
||||
}
|
||||
return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64))
|
||||
}
|
||||
|
||||
// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用)
|
||||
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error {
|
||||
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
|
||||
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
|
||||
}
|
||||
@@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
|
||||
queue.Title = title
|
||||
queue.Role = role
|
||||
queue.AgentMode = agentMode
|
||||
if concurrency != nil {
|
||||
queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency)
|
||||
}
|
||||
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil {
|
||||
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil {
|
||||
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
@@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
|
||||
|
||||
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
||||
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||
var cancelFunc context.CancelFunc
|
||||
var siblingRunningIDs []string
|
||||
|
||||
m.mu.Lock()
|
||||
@@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||
}
|
||||
|
||||
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
||||
var cancelFuncs []context.CancelFunc
|
||||
if queue.Status == BatchQueueStatusPaused {
|
||||
if c, ok := m.taskCancels[queueID]; ok {
|
||||
cancelFunc = c
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||
for _, t := range queue.Tasks {
|
||||
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
||||
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
||||
@@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
||||
m.mu.Unlock()
|
||||
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
for _, c := range cancelFuncs {
|
||||
if c != nil {
|
||||
c()
|
||||
}
|
||||
}
|
||||
const staleRunMsg = "为单条执行其它任务,已中止"
|
||||
for _, sid := range siblingRunningIDs {
|
||||
@@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool
|
||||
}
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个待执行的任务
|
||||
// ClaimNextPendingTask 原子领取下一个待执行子任务(并发 worker 安全)。
|
||||
func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists || queue == nil {
|
||||
return nil, false
|
||||
}
|
||||
if queue.Status == BatchQueueStatusCancelled || queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusPaused {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
onlyTaskID := ""
|
||||
if m.singleRunTasks != nil {
|
||||
onlyTaskID = m.singleRunTasks[queueID]
|
||||
}
|
||||
|
||||
for i, task := range queue.Tasks {
|
||||
if task == nil || task.Status != BatchTaskStatusPending {
|
||||
continue
|
||||
}
|
||||
if onlyTaskID != "" && task.ID != onlyTaskID {
|
||||
continue
|
||||
}
|
||||
task.Status = BatchTaskStatusRunning
|
||||
queue.CurrentIndex = i
|
||||
return task, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// HasRunningTasks 队列是否仍有 running 状态的子任务。
|
||||
func (m *BatchTaskManager) HasRunningTasks(queueID string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists || queue == nil {
|
||||
return false
|
||||
}
|
||||
for _, task := range queue.Tasks {
|
||||
if task != nil && task.Status == BatchTaskStatusRunning {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasPendingOrRunningTasks 队列是否仍有未完成的子任务。
|
||||
func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists || queue == nil {
|
||||
return false
|
||||
}
|
||||
for _, task := range queue.Tasks {
|
||||
if task == nil {
|
||||
continue
|
||||
}
|
||||
if task.Status == BatchTaskStatusPending || task.Status == BatchTaskStatusRunning {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// drainTaskCancelsLocked 取出并清空队列下所有子任务取消函数(调用方须已持 m.mu)。
|
||||
func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc {
|
||||
taskMap, ok := m.taskCancels[queueID]
|
||||
if !ok || len(taskMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
cancels := make([]context.CancelFunc, 0, len(taskMap))
|
||||
for _, c := range taskMap {
|
||||
if c != nil {
|
||||
cancels = append(cancels, c)
|
||||
}
|
||||
}
|
||||
delete(m.taskCancels, queueID)
|
||||
return cancels
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask)
|
||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetTaskCancel 设置当前任务的取消函数
|
||||
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) {
|
||||
// SetTaskCancel 设置子任务的取消函数
|
||||
func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
m.taskCancels[queueID] = cancel
|
||||
} else {
|
||||
delete(m.taskCancels, queueID)
|
||||
if cancel == nil {
|
||||
if taskMap, ok := m.taskCancels[queueID]; ok {
|
||||
delete(taskMap, taskID)
|
||||
if len(taskMap) == 0 {
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if m.taskCancels[queueID] == nil {
|
||||
m.taskCancels[queueID] = make(map[string]context.CancelFunc)
|
||||
}
|
||||
m.taskCancels[queueID][taskID] = cancel
|
||||
}
|
||||
|
||||
// PauseQueue 暂停队列
|
||||
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
var cancelFunc context.CancelFunc
|
||||
var cancelFuncs []context.CancelFunc
|
||||
|
||||
m.mu.Lock()
|
||||
queue, exists := m.queues[queueID]
|
||||
@@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
}
|
||||
|
||||
queue.Status = BatchQueueStatusPaused
|
||||
|
||||
// 取消当前正在执行的任务(通过取消context)
|
||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
||||
cancelFunc = cancel
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
for _, c := range cancelFuncs {
|
||||
c()
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
now := time.Now()
|
||||
var cancelFunc context.CancelFunc
|
||||
var cancelFuncs []context.CancelFunc
|
||||
|
||||
m.mu.Lock()
|
||||
queue, exists := m.queues[queueID]
|
||||
@@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// 取消当前正在执行的任务
|
||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
||||
cancelFunc = cancel
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
for _, c := range cancelFuncs {
|
||||
c()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteQueue 删除队列(运行中的队列不允许删除)
|
||||
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
||||
// DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。
|
||||
func (m *BatchTaskManager) DeleteQueue(queueID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return false
|
||||
return ErrBatchQueueNotFound
|
||||
}
|
||||
|
||||
if _, exec := m.queueExecutors[queueID]; exec {
|
||||
return ErrBatchQueueExecutorActive
|
||||
}
|
||||
|
||||
// 运行中的队列不允许删除,防止孤儿协程和数据丢失
|
||||
if queue.Status == BatchQueueStatusRunning {
|
||||
return false
|
||||
return ErrBatchQueueStillRunning
|
||||
}
|
||||
|
||||
// 清理取消函数
|
||||
@@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
||||
}
|
||||
|
||||
delete(m.queues, queueID)
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateShortID 生成短ID
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNormalizeBatchQueueConcurrency(t *testing.T) {
|
||||
if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency {
|
||||
t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got)
|
||||
}
|
||||
if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency {
|
||||
t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimNextPendingTaskParallel(t *testing.T) {
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBatchQueue: %v", err)
|
||||
}
|
||||
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||
|
||||
t1, ok1 := m.ClaimNextPendingTask(queue.ID)
|
||||
t2, ok2 := m.ClaimNextPendingTask(queue.ID)
|
||||
if !ok1 || !ok2 || t1.ID == t2.ID {
|
||||
t.Fatalf("expected two distinct claims, got ok1=%v ok2=%v t1=%v t2=%v", ok1, ok2, t1, t2)
|
||||
}
|
||||
if t1.Status != BatchTaskStatusRunning || t2.Status != BatchTaskStatusRunning {
|
||||
t.Fatalf("claimed tasks should be running")
|
||||
}
|
||||
t3, ok3 := m.ClaimNextPendingTask(queue.ID)
|
||||
if !ok3 {
|
||||
t.Fatal("expected third claim")
|
||||
}
|
||||
_, ok4 := m.ClaimNextPendingTask(queue.ID)
|
||||
if ok4 {
|
||||
t.Fatal("expected no fourth pending task")
|
||||
}
|
||||
_ = t3
|
||||
}
|
||||
|
||||
func TestBatchQueueExecutionShouldStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !batchQueueExecutionShouldStop(nil, false) {
|
||||
t.Fatal("expected stop when queue missing")
|
||||
}
|
||||
if !batchQueueExecutionShouldStop(nil, true) {
|
||||
t.Fatal("expected stop when queue is nil but exists=true")
|
||||
}
|
||||
q := &BatchTaskQueue{Status: BatchQueueStatusRunning}
|
||||
if batchQueueExecutionShouldStop(q, true) {
|
||||
t.Fatal("expected continue when running")
|
||||
}
|
||||
q.Status = BatchQueueStatusCancelled
|
||||
if !batchQueueExecutionShouldStop(q, true) {
|
||||
t.Fatal("expected stop when cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBatchQueue: %v", err)
|
||||
}
|
||||
if !m.TryMarkQueueExecutor(queue.ID) {
|
||||
t.Fatal("expected to mark executor")
|
||||
}
|
||||
m.UpdateQueueStatus(queue.ID, BatchQueueStatusCancelled)
|
||||
|
||||
err = m.DeleteQueue(queue.ID)
|
||||
if !errors.Is(err, ErrBatchQueueExecutorActive) {
|
||||
t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", err)
|
||||
}
|
||||
if _, ok := m.GetBatchQueue(queue.ID); !ok {
|
||||
t.Fatal("queue should still exist while executor active")
|
||||
}
|
||||
|
||||
m.UnmarkQueueExecutor(queue.ID)
|
||||
if err := m.DeleteQueue(queue.ID); err != nil {
|
||||
t.Fatalf("expected delete after executor unmarked, got %v", err)
|
||||
}
|
||||
if _, ok := m.GetBatchQueue(queue.ID); ok {
|
||||
t.Fatal("queue should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteQueueBlockedWhileRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBatchQueue: %v", err)
|
||||
}
|
||||
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||
|
||||
err = m.DeleteQueue(queue.ID)
|
||||
if !errors.Is(err, ErrBatchQueueStillRunning) {
|
||||
t.Fatalf("expected ErrBatchQueueStillRunning, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryMarkQueueExecutorDedupes(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
if !m.TryMarkQueueExecutor("q-1") {
|
||||
t.Fatal("first mark should succeed")
|
||||
}
|
||||
if m.TryMarkQueueExecutor("q-1") {
|
||||
t.Fatal("second mark should fail")
|
||||
}
|
||||
m.UnmarkQueueExecutor("q-1")
|
||||
if !m.TryMarkQueueExecutor("q-1") {
|
||||
t.Fatal("mark after unmark should succeed")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
"type": "string",
|
||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||
},
|
||||
"concurrency": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
@@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
executeNow = false
|
||||
}
|
||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
||||
concurrency := int(mcpArgFloat(args, "concurrency"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks)
|
||||
if createErr != nil {
|
||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||
}
|
||||
@@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.DeleteQueue(qid) {
|
||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||
if err := h.batchTaskManager.DeleteQueue(qid); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrBatchQueueNotFound):
|
||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||
return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil
|
||||
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||
return batchMCPTextResult("删除失败:队列正在运行中", true), nil
|
||||
default:
|
||||
return batchMCPTextResult("删除失败:"+err.Error(), true), nil
|
||||
}
|
||||
}
|
||||
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("队列已删除。", false), nil
|
||||
@@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
||||
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||
},
|
||||
"concurrency": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "同时执行的子任务数,默认 1,最大 8",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
@@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
title := mcpArgString(args, "title")
|
||||
role := mcpArgString(args, "role")
|
||||
agentMode := mcpArgString(args, "agent_mode")
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil {
|
||||
var concurrency *int
|
||||
if raw, ok := args["concurrency"]; ok && raw != nil {
|
||||
v := int(mcpArgFloat(args, "concurrency"))
|
||||
concurrency = &v
|
||||
}
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
@@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct {
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
CurrentIndex int `json:"currentIndex"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
TaskTotal int `json:"task_total"`
|
||||
TaskCounts map[string]int `json:"task_counts"`
|
||||
Tasks []batchTaskMCPListSummary `json:"tasks"`
|
||||
@@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
|
||||
StartedAt: q.StartedAt,
|
||||
CompletedAt: q.CompletedAt,
|
||||
CurrentIndex: q.CurrentIndex,
|
||||
Concurrency: q.Concurrency,
|
||||
TaskTotal: len(tasks),
|
||||
TaskCounts: counts,
|
||||
Tasks: tasks,
|
||||
|
||||
@@ -12,11 +12,17 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConversationTaskStopper cancels in-flight agent work when a conversation is removed.
|
||||
type ConversationTaskStopper interface {
|
||||
CancelRunningTaskForConversation(conversationID string)
|
||||
}
|
||||
|
||||
// ConversationHandler 对话处理器
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
taskStopper ConversationTaskStopper
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
@@ -24,6 +30,11 @@ func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// SetTaskStopper wires cancellation of in-flight agent tasks on conversation delete.
|
||||
func (h *ConversationHandler) SetTaskStopper(stopper ConversationTaskStopper) {
|
||||
h.taskStopper = stopper
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||
return &ConversationHandler{
|
||||
@@ -92,6 +103,7 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
search := c.Query("search") // 获取搜索参数
|
||||
projectID := strings.TrimSpace(c.Query("project_id"))
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
@@ -103,7 +115,7 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
excludeGrouped := strings.TrimSpace(search) == "" &&
|
||||
excludeGrouped := strings.TrimSpace(search) == "" && projectID == "" &&
|
||||
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
|
||||
sortBy := strings.TrimSpace(c.Query("sort_by"))
|
||||
|
||||
@@ -111,14 +123,14 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
var total int
|
||||
var err error
|
||||
if excludeGrouped {
|
||||
conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy)
|
||||
conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy, projectID)
|
||||
if err == nil {
|
||||
total, err = h.db.CountUngroupedConversations()
|
||||
total, err = h.db.CountUngroupedConversations(projectID)
|
||||
}
|
||||
} else {
|
||||
conversations, err = h.db.ListConversations(limit, offset, search, sortBy)
|
||||
conversations, err = h.db.ListConversations(limit, offset, search, sortBy, projectID)
|
||||
if err == nil {
|
||||
total, err = h.db.CountConversations(search)
|
||||
total, err = h.db.CountConversations(search, projectID)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -165,6 +177,9 @@ func (h *ConversationHandler) GetConversation(c *gin.Context) {
|
||||
}
|
||||
|
||||
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
|
||||
// 查询参数:
|
||||
// - summary=1:仅返回摘要(total / iterationCount / maxIteration)
|
||||
// - limit + offset:分页返回 processDetails(未指定 limit 时保持全量兼容)
|
||||
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
messageID := c.Param("id")
|
||||
if messageID == "" {
|
||||
@@ -172,6 +187,51 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
summaryStr := strings.TrimSpace(c.Query("summary"))
|
||||
if summaryStr == "1" || strings.EqualFold(summaryStr, "true") || strings.EqualFold(summaryStr, "yes") {
|
||||
summary, err := h.db.GetProcessDetailsSummary(messageID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取过程详情摘要失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"summary": summary})
|
||||
return
|
||||
}
|
||||
|
||||
limitStr := strings.TrimSpace(c.Query("limit"))
|
||||
if limitStr != "" {
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid limit"})
|
||||
return
|
||||
}
|
||||
if limit > 500 {
|
||||
limit = 500
|
||||
}
|
||||
offset, _ := strconv.Atoi(strings.TrimSpace(c.Query("offset")))
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
details, total, err := h.db.GetProcessDetailsPage(messageID, limit, offset)
|
||||
if err != nil {
|
||||
h.logger.Error("分页获取过程详情失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
details = database.DedupeConsecutiveProcessDetails(details)
|
||||
out := processDetailsToJSON(h.logger, details)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"processDetails": out,
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"hasMore": offset+len(out) < total,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
details, err := h.db.GetProcessDetails(messageID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取过程详情失败", zap.Error(err))
|
||||
@@ -180,14 +240,17 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
}
|
||||
|
||||
details = database.DedupeConsecutiveProcessDetails(details)
|
||||
out := processDetailsToJSON(h.logger, details)
|
||||
c.JSON(http.StatusOK, gin.H{"processDetails": out, "total": len(out)})
|
||||
}
|
||||
|
||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
||||
func processDetailsToJSON(logger *zap.Logger, details []database.ProcessDetail) []map[string]interface{} {
|
||||
out := make([]map[string]interface{}, 0, len(details))
|
||||
for _, d := range details {
|
||||
var data interface{}
|
||||
if d.Data != "" {
|
||||
if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
|
||||
h.logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||
logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
out = append(out, map[string]interface{}{
|
||||
@@ -200,8 +263,7 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
"createdAt": d.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"processDetails": out})
|
||||
return out
|
||||
}
|
||||
|
||||
// UpdateConversationRequest 更新对话请求
|
||||
@@ -245,6 +307,10 @@ func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
|
||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if h.taskStopper != nil {
|
||||
h.taskStopper.CancelRunningTaskForConversation(id)
|
||||
}
|
||||
|
||||
if err := h.db.DeleteConversation(id); err != nil {
|
||||
h.logger.Error("删除对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConversationHandlerDeleteConversationCancelsRunningTask(t *testing.T) {
|
||||
tm := NewAgentTaskManager()
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
|
||||
h := &AgentHandler{tasks: tm, logger: zap.NewNop()}
|
||||
h.CancelRunningTaskForConversation("conv-1")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("task context was not cancelled")
|
||||
}
|
||||
if cause := context.Cause(ctx); cause != ErrTaskCancelled {
|
||||
t.Fatalf("expected ErrTaskCancelled, got %v", cause)
|
||||
}
|
||||
}
|
||||
@@ -2,31 +2,11 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||
if h.config != nil {
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||
}
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
conversationID string,
|
||||
@@ -45,136 +25,3 @@ func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
transientAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, fatal error) {
|
||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||
return false, nil
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*transientAttempts++
|
||||
if *transientAttempts > maxAttempts {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||
}
|
||||
attemptNo := *transientAttempts
|
||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-baseCtx.Done():
|
||||
return false, context.Cause(baseCtx)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
})
|
||||
}
|
||||
if sendProgress != nil {
|
||||
sendProgress("正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "transient_retry",
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
|
||||
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
|
||||
func (h *AgentHandler) handleEinoEmptyResponseContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
emptyResponseAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, exhausted bool) {
|
||||
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
|
||||
return false, false
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*emptyResponseAttempts++
|
||||
if *emptyResponseAttempts > maxAttempts {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("eino empty response auto resume exhausted",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, true
|
||||
}
|
||||
attemptNo := *emptyResponseAttempts
|
||||
if h.logger != nil {
|
||||
h.logger.Info("eino empty response, auto resume from trace",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("attempt", attemptNo),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if sendProgress != nil {
|
||||
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "empty_response_continue",
|
||||
})
|
||||
}
|
||||
return true, false
|
||||
}
|
||||
|
||||
@@ -119,7 +119,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
@@ -177,8 +176,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
taskOwned = true
|
||||
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
@@ -215,6 +212,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
@@ -233,61 +231,18 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
roleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
h.agentSessionContextBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -312,8 +267,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
@@ -448,8 +401,6 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx,
|
||||
@@ -465,30 +416,11 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
h.agentSessionContextBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
|
||||
+323
-25
@@ -5,13 +5,16 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/monitor"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -19,12 +22,20 @@ import (
|
||||
|
||||
// MonitorHandler 监控处理器
|
||||
type MonitorHandler struct {
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
taskManager *AgentTaskManager
|
||||
agentHandler *AgentHandler
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
monitorRetention *monitor.Service
|
||||
}
|
||||
|
||||
// SetMonitorRetention wires MCP execution retention settings.
|
||||
func (h *MonitorHandler) SetMonitorRetention(s *monitor.Service) {
|
||||
h.monitorRetention = s
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
@@ -48,15 +59,44 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
||||
h.externalMCPMgr = mgr
|
||||
}
|
||||
|
||||
// SetTaskManager 设置 Agent 任务管理器(用于 Eino execute 等按 executionId 终止)。
|
||||
func (h *MonitorHandler) SetTaskManager(mgr *AgentTaskManager) {
|
||||
h.taskManager = mgr
|
||||
}
|
||||
|
||||
// SetAgentHandler 设置 Agent 处理器(MCP 监控终止与对话页「中断并继续」共用逻辑)。
|
||||
func (h *MonitorHandler) SetAgentHandler(ah *AgentHandler) {
|
||||
h.agentHandler = ah
|
||||
}
|
||||
|
||||
const monitorPageTopTools = 6
|
||||
|
||||
// MonitorStatsSummary is the aggregated tool-call summary returned by the
// monitor endpoints.
type MonitorStatsSummary struct {
	// TotalCalls is the sum of calls across all tools.
	TotalCalls int `json:"totalCalls"`
	// SuccessCalls is the sum of successful calls across all tools.
	SuccessCalls int `json:"successCalls"`
	// FailedCalls is the sum of failed calls across all tools.
	FailedCalls int `json:"failedCalls"`
	// LastCallTime is the most recent call time, omitted from JSON when nil.
	LastCallTime *time.Time `json:"lastCallTime,omitempty"`
	// ToolCount is the number of tools included in the summary.
	ToolCount int `json:"toolCount"`
}
|
||||
|
||||
// MonitorResponse 监控响应
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Summary *MonitorStatsSummary `json:"summary"`
|
||||
TopTools []*mcp.ToolStats `json:"topTools"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"pageSize"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
RetentionDays int `json:"retentionDays"`
|
||||
}
|
||||
|
||||
// StatsResponse 统计信息响应(Dashboard 等)
|
||||
type StatsResponse struct {
|
||||
Summary *MonitorStatsSummary `json:"summary"`
|
||||
TopTools []*mcp.ToolStats `json:"topTools"`
|
||||
}
|
||||
|
||||
// Monitor 获取监控信息
|
||||
@@ -80,8 +120,9 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
// 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool)
|
||||
toolName := normalizeToolNameFilter(c.Query("tool"))
|
||||
|
||||
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
||||
stats := h.loadStats()
|
||||
executions, total := h.loadExecutionListWithPagination(page, pageSize, status, toolName)
|
||||
h.enrichExecutionsConversationID(executions)
|
||||
summary, topTools := h.loadStatsSummary(monitorPageTopTools)
|
||||
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
@@ -89,21 +130,136 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, MonitorResponse{
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Timestamp: time.Now(),
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
Executions: executions,
|
||||
Summary: summary,
|
||||
TopTools: topTools,
|
||||
Timestamp: time.Now(),
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
RetentionDays: h.monitorRetentionDays(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) monitorRetentionDays() int {
|
||||
if h.monitorRetention != nil {
|
||||
return h.monitorRetention.RetentionDays()
|
||||
}
|
||||
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
||||
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
|
||||
return executions
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutionListWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||
if h.db == nil {
|
||||
allExecutions := h.mcpServer.GetAllExecutions()
|
||||
if status != "" || toolName != "" {
|
||||
filtered := make([]*mcp.ToolExecution, 0)
|
||||
for _, exec := range allExecutions {
|
||||
matchStatus := status == "" || exec.Status == status
|
||||
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||
if matchStatus && matchTool {
|
||||
filtered = append(filtered, exec)
|
||||
}
|
||||
}
|
||||
allExecutions = filtered
|
||||
}
|
||||
total := len(allExecutions)
|
||||
offset := (page - 1) * pageSize
|
||||
end := offset + pageSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
if offset >= total {
|
||||
return []*mcp.ToolExecution{}, total
|
||||
}
|
||||
pageSlice := allExecutions[offset:end]
|
||||
out := make([]*mcp.ToolExecution, 0, len(pageSlice))
|
||||
for _, exec := range pageSlice {
|
||||
if exec == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, slimToolExecution(exec))
|
||||
}
|
||||
return out, total
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
executions, err := h.db.LoadToolExecutionListPage(offset, pageSize, status, toolName)
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载执行记录列表失败,回退到内存数据", zap.Error(err))
|
||||
return h.loadExecutionListWithPaginationFromMemory(page, pageSize, status, toolName)
|
||||
}
|
||||
|
||||
total, err := h.db.CountToolExecutions(status, toolName)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取执行记录总数失败", zap.Error(err))
|
||||
total = offset + len(executions)
|
||||
if len(executions) == pageSize {
|
||||
total = offset + len(executions) + 1
|
||||
}
|
||||
}
|
||||
|
||||
return executions, total
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutionListWithPaginationFromMemory(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||
allExecutions := h.mcpServer.GetAllExecutions()
|
||||
if status != "" || toolName != "" {
|
||||
filtered := make([]*mcp.ToolExecution, 0)
|
||||
for _, exec := range allExecutions {
|
||||
matchStatus := status == "" || exec.Status == status
|
||||
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||
if matchStatus && matchTool {
|
||||
filtered = append(filtered, exec)
|
||||
}
|
||||
}
|
||||
allExecutions = filtered
|
||||
}
|
||||
total := len(allExecutions)
|
||||
offset := (page - 1) * pageSize
|
||||
end := offset + pageSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
if offset >= total {
|
||||
return []*mcp.ToolExecution{}, total
|
||||
}
|
||||
pageSlice := allExecutions[offset:end]
|
||||
out := make([]*mcp.ToolExecution, 0, len(pageSlice))
|
||||
for _, exec := range pageSlice {
|
||||
if exec == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, slimToolExecution(exec))
|
||||
}
|
||||
return out, total
|
||||
}
|
||||
|
||||
func slimToolExecution(exec *mcp.ToolExecution) *mcp.ToolExecution {
|
||||
if exec == nil {
|
||||
return nil
|
||||
}
|
||||
slim := &mcp.ToolExecution{
|
||||
ID: exec.ID,
|
||||
ToolName: exec.ToolName,
|
||||
Status: exec.Status,
|
||||
StartTime: exec.StartTime,
|
||||
}
|
||||
if exec.EndTime != nil {
|
||||
end := *exec.EndTime
|
||||
slim.EndTime = &end
|
||||
}
|
||||
if exec.Duration > 0 {
|
||||
slim.Duration = exec.Duration
|
||||
}
|
||||
return slim
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||
if h.db == nil {
|
||||
allExecutions := h.mcpServer.GetAllExecutions()
|
||||
@@ -176,7 +332,78 @@ func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status
|
||||
return executions, total
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
|
||||
func (h *MonitorHandler) loadStatsSummary(topN int) (*MonitorStatsSummary, []*mcp.ToolStats) {
|
||||
if topN <= 0 {
|
||||
topN = monitorPageTopTools
|
||||
}
|
||||
|
||||
if h.db != nil {
|
||||
result, err := h.db.LoadToolStatsSummary(topN)
|
||||
if err == nil {
|
||||
return dbStatsSummaryToMonitor(result), result.TopTools
|
||||
}
|
||||
h.logger.Warn("从数据库加载统计汇总失败,回退到内存数据", zap.Error(err))
|
||||
}
|
||||
|
||||
stats := h.loadStatsMap()
|
||||
return summarizeToolStats(stats, topN)
|
||||
}
|
||||
|
||||
func dbStatsSummaryToMonitor(result *database.ToolStatsSummaryResult) *MonitorStatsSummary {
|
||||
if result == nil {
|
||||
return &MonitorStatsSummary{}
|
||||
}
|
||||
summary := &MonitorStatsSummary{
|
||||
TotalCalls: result.Summary.TotalCalls,
|
||||
SuccessCalls: result.Summary.SuccessCalls,
|
||||
FailedCalls: result.Summary.FailedCalls,
|
||||
ToolCount: result.Summary.ToolCount,
|
||||
}
|
||||
if result.Summary.LastCallTime != nil {
|
||||
t := *result.Summary.LastCallTime
|
||||
summary.LastCallTime = &t
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
func summarizeToolStats(stats map[string]*mcp.ToolStats, topN int) (*MonitorStatsSummary, []*mcp.ToolStats) {
|
||||
summary := &MonitorStatsSummary{}
|
||||
if len(stats) == 0 {
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
all := make([]*mcp.ToolStats, 0, len(stats))
|
||||
for _, stat := range stats {
|
||||
if stat == nil {
|
||||
continue
|
||||
}
|
||||
summary.ToolCount++
|
||||
summary.TotalCalls += stat.TotalCalls
|
||||
summary.SuccessCalls += stat.SuccessCalls
|
||||
summary.FailedCalls += stat.FailedCalls
|
||||
if stat.LastCallTime != nil && (summary.LastCallTime == nil || stat.LastCallTime.After(*summary.LastCallTime)) {
|
||||
t := *stat.LastCallTime
|
||||
summary.LastCallTime = &t
|
||||
}
|
||||
if stat.TotalCalls > 0 {
|
||||
statCopy := *stat
|
||||
all = append(all, &statCopy)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(all, func(i, j int) bool {
|
||||
if all[i].TotalCalls == all[j].TotalCalls {
|
||||
return all[i].ToolName < all[j].ToolName
|
||||
}
|
||||
return all[i].TotalCalls > all[j].TotalCalls
|
||||
})
|
||||
if len(all) > topN {
|
||||
all = all[:topN]
|
||||
}
|
||||
return summary, all
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadStatsMap() map[string]*mcp.ToolStats {
|
||||
// 合并内部MCP服务器和外部MCP管理器的统计信息
|
||||
stats := make(map[string]*mcp.ToolStats)
|
||||
|
||||
@@ -230,6 +457,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
// 先从内部MCP服务器查找
|
||||
exec, exists := h.mcpServer.GetExecution(id)
|
||||
if exists {
|
||||
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
@@ -238,6 +466,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
if h.externalMCPMgr != nil {
|
||||
exec, exists = h.externalMCPMgr.GetExecution(id)
|
||||
if exists {
|
||||
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
@@ -247,6 +476,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
if h.db != nil {
|
||||
exec, err := h.db.GetToolExecution(id)
|
||||
if err == nil && exec != nil {
|
||||
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
@@ -273,6 +503,19 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
note = strings.TrimSpace(body.Note)
|
||||
|
||||
convID := h.conversationIDForRunningExecution(id)
|
||||
if convID != "" && h.agentHandler != nil {
|
||||
if ok, payload := h.agentHandler.cancelToolContinueAfter(convID, id, note); ok {
|
||||
h.logger.Info("MCP 监控页终止工具(与对话中断并继续一致)",
|
||||
zap.String("executionId", id),
|
||||
zap.String("conversationId", convID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, payload)
|
||||
return
|
||||
}
|
||||
}
|
||||
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
|
||||
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
||||
@@ -286,6 +529,52 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) enrichExecutionsConversationID(executions []*mcp.ToolExecution) {
|
||||
for _, exec := range executions {
|
||||
if exec == nil || exec.Status != "running" {
|
||||
continue
|
||||
}
|
||||
exec.ConversationID = h.conversationIDForRunningExecution(exec.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) conversationIDForRunningExecution(executionID string) string {
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
if executionID == "" || h.taskManager == nil {
|
||||
return ""
|
||||
}
|
||||
if conv := h.taskManager.ConversationIDForActiveMCPExecution(executionID); conv != "" {
|
||||
return conv
|
||||
}
|
||||
exec := h.lookupExecution(executionID)
|
||||
if exec == nil || exec.Status != "running" {
|
||||
return ""
|
||||
}
|
||||
if strings.TrimSpace(exec.ToolName) == "execute" {
|
||||
if onlyConv, ok := h.taskManager.ConversationIDForActiveEinoExecute(); ok {
|
||||
return onlyConv
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) lookupExecution(id string) *mcp.ToolExecution {
|
||||
if exec, ok := h.mcpServer.GetExecution(id); ok {
|
||||
return exec
|
||||
}
|
||||
if h.externalMCPMgr != nil {
|
||||
if exec, ok := h.externalMCPMgr.GetExecution(id); ok {
|
||||
return exec
|
||||
}
|
||||
}
|
||||
if h.db != nil {
|
||||
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
|
||||
return exec
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
||||
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
||||
var req struct {
|
||||
@@ -323,8 +612,17 @@ func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (h *MonitorHandler) GetStats(c *gin.Context) {
|
||||
stats := h.loadStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
topN := 30
|
||||
if topStr := c.Query("top"); topStr != "" {
|
||||
if t, err := strconv.Atoi(topStr); err == nil && t > 0 && t <= 100 {
|
||||
topN = t
|
||||
}
|
||||
}
|
||||
summary, topTools := h.loadStatsSummary(topN)
|
||||
c.JSON(http.StatusOK, StatsResponse{
|
||||
Summary: summary,
|
||||
TopTools: topTools,
|
||||
})
|
||||
}
|
||||
|
||||
// CallsTimelinePoint 调用趋势数据点
|
||||
|
||||
@@ -136,7 +136,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
@@ -187,8 +186,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
@@ -225,6 +222,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
@@ -245,61 +243,18 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
h.agentSessionContextBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -324,8 +279,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
@@ -460,8 +413,6 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtx,
|
||||
@@ -479,30 +430,11 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
h.agentSessionContextBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
|
||||
@@ -740,14 +740,21 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"properties": map[string]interface{}{
|
||||
"executions": map[string]interface{}{
|
||||
"type": "array",
|
||||
"description": "执行记录列表",
|
||||
"description": "执行记录列表(轻量字段,不含 arguments/result)",
|
||||
"items": map[string]interface{}{
|
||||
"$ref": "#/components/schemas/ToolExecution",
|
||||
},
|
||||
},
|
||||
"stats": map[string]interface{}{
|
||||
"summary": map[string]interface{}{
|
||||
"type": "object",
|
||||
"description": "统计信息",
|
||||
"description": "工具调用汇总",
|
||||
},
|
||||
"topTools": map[string]interface{}{
|
||||
"type": "array",
|
||||
"description": "调用量 Top N 工具",
|
||||
"items": map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
"timestamp": map[string]interface{}{
|
||||
"type": "string",
|
||||
@@ -756,20 +763,24 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
"total": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "总数",
|
||||
"description": "执行记录总数",
|
||||
},
|
||||
"page": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "当前页",
|
||||
},
|
||||
"page_size": map[string]interface{}{
|
||||
"pageSize": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "每页数量",
|
||||
},
|
||||
"total_pages": map[string]interface{}{
|
||||
"totalPages": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "总页数",
|
||||
},
|
||||
"retentionDays": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "执行记录保留天数",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ConfigResponse": map[string]interface{}{
|
||||
@@ -1232,6 +1243,34 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "project_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"description": "按项目筛选;传 __none__ 表示仅未绑定项目的对话",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "exclude_grouped",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"description": "为 true 时排除已加入分组的对话(默认在未搜索且未按项目筛选时启用)",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "sort_by",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"description": "排序字段:updated_at(默认)或 created_at",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"updated_at", "created_at"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
|
||||
@@ -7,6 +7,45 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// agentSessionContextBlock 注入会话工作目录、项目黑板与用户原文锚点(用于 system prompt 追加块)。
|
||||
func (h *AgentHandler) agentSessionContextBlock(conversationID string) string {
|
||||
var parts []string
|
||||
if ws := h.buildWorkspaceBlock(conversationID); ws != "" {
|
||||
parts = append(parts, ws)
|
||||
}
|
||||
if bb := h.projectBlackboardBlock(conversationID); bb != "" {
|
||||
parts = append(parts, bb)
|
||||
}
|
||||
if uv := h.userVerbatimAnchorBlock(conversationID); uv != "" {
|
||||
parts = append(parts, uv)
|
||||
}
|
||||
return strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
func (h *AgentHandler) buildWorkspaceBlock(conversationID string) string {
|
||||
if h == nil || h.config == nil {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
projectID := h.conversationProjectID(conversationID)
|
||||
rel := project.WorkspaceRootDir(h.config.Agent.WorkspaceRootDir, projectID, conversationID)
|
||||
abs, err := project.EnsureWorkspace(rel)
|
||||
if err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("创建会话工作目录失败",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("projectId", projectID),
|
||||
zap.String("path", rel),
|
||||
zap.Error(err))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return project.BuildWorkspaceBlock(abs)
|
||||
}
|
||||
|
||||
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
||||
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||
if h == nil || h.db == nil || h.config == nil {
|
||||
@@ -31,6 +70,29 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||
return strings.TrimSpace(block)
|
||||
}
|
||||
|
||||
// userVerbatimAnchorBlock 从 messages 表构建用户各轮原文锚点(压缩后仍由 summarization Finalize 刷新)。
|
||||
func (h *AgentHandler) userVerbatimAnchorBlock(conversationID string) string {
|
||||
if h == nil || h.db == nil || h.config == nil {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
maxRunes := h.config.MultiAgent.UserVerbatimAnchorMaxRunesEffective()
|
||||
if maxRunes < 0 {
|
||||
return ""
|
||||
}
|
||||
msgs, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("构建用户原文锚点失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return project.BuildUserVerbatimAnchorBlockFromMessages(msgs, maxRunes)
|
||||
}
|
||||
|
||||
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
|
||||
func (h *AgentHandler) conversationProjectID(conversationID string) string {
|
||||
if h == nil || h.db == nil {
|
||||
|
||||
@@ -447,7 +447,7 @@ func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
|
||||
}
|
||||
|
||||
func (h *RobotHandler) cmdList() string {
|
||||
convs, err := h.db.ListConversations(50, 0, "", "")
|
||||
convs, err := h.db.ListConversations(50, 0, "", "", "")
|
||||
if err != nil {
|
||||
return "获取对话列表失败: " + err.Error()
|
||||
}
|
||||
@@ -594,6 +594,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
|
||||
h.mu.Unlock()
|
||||
h.deleteSessionBinding(sk)
|
||||
}
|
||||
if h.agentHandler != nil {
|
||||
h.agentHandler.CancelRunningTaskForConversation(convID)
|
||||
}
|
||||
if err := h.db.DeleteConversation(convID); err != nil {
|
||||
return "删除失败: " + err.Error()
|
||||
}
|
||||
|
||||
@@ -37,6 +37,11 @@ type AgentTask struct {
|
||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||
InterruptContinueNote string `json:"-"`
|
||||
|
||||
// activeEinoExecuteCancel 当前进行中的 Eino filesystem execute 取消函数(与 MCP 工具并行,供中断并继续)
|
||||
activeEinoExecuteCancel context.CancelFunc
|
||||
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
||||
activeEinoExecuteAbortNote string
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
@@ -70,6 +75,103 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterActiveEinoExecute 登记进行中的 Eino filesystem execute(每会话同时仅一条)。
|
||||
func (m *AgentTaskManager) RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || cancel == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.activeEinoExecuteCancel = cancel
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterActiveEinoExecute execute 正常结束或已取消后清除登记。
|
||||
func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.activeEinoExecuteCancel = nil
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
}
|
||||
}
|
||||
|
||||
// ConversationIDForActiveMCPExecution 根据当前登记的工具 executionId 反查会话 ID(供 MCP 监控页按 executionId 终止)。
|
||||
func (m *AgentTaskManager) ConversationIDForActiveMCPExecution(executionID string) string {
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
if executionID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for convID, t := range m.tasks {
|
||||
if t != nil && t.ActiveMCPExecutionID == executionID {
|
||||
return convID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConversationIDForActiveEinoExecute 返回当前唯一进行 Eino execute 的会话 ID;多会话并行时返回空。
|
||||
func (m *AgentTaskManager) ConversationIDForActiveEinoExecute() (string, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
var found string
|
||||
count := 0
|
||||
for convID, t := range m.tasks {
|
||||
if t != nil && t.activeEinoExecuteCancel != nil {
|
||||
found = convID
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count == 1 {
|
||||
return found, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
||||
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return false
|
||||
}
|
||||
m.mu.Lock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil || t.activeEinoExecuteCancel == nil {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
t.activeEinoExecuteAbortNote = strings.TrimSpace(note)
|
||||
cancel := t.activeEinoExecuteCancel
|
||||
m.mu.Unlock()
|
||||
cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
// TakeEinoExecuteAbortNote 读取并清空 execute 终止说明(execute 收尾时调用一次)。
|
||||
func (m *AgentTaskManager) TakeEinoExecuteAbortNote(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
n := t.activeEinoExecuteAbortNote
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
return n
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
@@ -145,6 +247,8 @@ type AgentTaskManager struct {
|
||||
maxHistorySize int // 最大历史记录数
|
||||
historyRetention time.Duration // 历史记录保留时间
|
||||
eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅
|
||||
// toolCanceler 在用户整轮停止任务时终止当前 MCP 工具(非「中断并继续」)。
|
||||
toolCanceler func(conversationID string)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -175,6 +279,13 @@ func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) {
|
||||
m.eventBus = b
|
||||
}
|
||||
|
||||
// SetToolCanceler 设置整轮停止任务时终止当前 MCP 工具的回调(由 AgentHandler 注入)。
|
||||
func (m *AgentTaskManager) SetToolCanceler(fn func(conversationID string)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.toolCanceler = fn
|
||||
}
|
||||
|
||||
// GetTask 返回运行中任务(无则 nil)。
|
||||
func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask {
|
||||
m.mu.RLock()
|
||||
@@ -270,14 +381,21 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
|
||||
task.InterruptContinueNote = ""
|
||||
}
|
||||
cancel := task.cancel
|
||||
m.mu.Unlock()
|
||||
|
||||
if cause == nil {
|
||||
cause = ErrTaskCancelled
|
||||
}
|
||||
var toolCanceler func(string)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
toolCanceler = m.toolCanceler
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel(cause)
|
||||
}
|
||||
if toolCanceler != nil {
|
||||
toolCanceler(conversationID)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAbortActiveEinoExecute(t *testing.T) {
|
||||
m := NewAgentTaskManager()
|
||||
conv := "conv-eino-exec-abort"
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_, err := m.StartTask(conv, "test", func(error) {})
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
m.RegisterActiveEinoExecute(conv, cancel)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if !m.AbortActiveEinoExecute(conv, "跳过域名收集") {
|
||||
t.Fatal("expected abort to succeed")
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("execute cancel did not propagate")
|
||||
}
|
||||
if got := m.TakeEinoExecuteAbortNote(conv); got != "跳过域名收集" {
|
||||
t.Fatalf("abort note = %q, want 跳过域名收集", got)
|
||||
}
|
||||
m.UnregisterActiveEinoExecute(conv)
|
||||
if m.AbortActiveEinoExecute(conv, "") {
|
||||
t.Fatal("second abort should fail when no active execute")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConversationIDForActiveMCPExecution(t *testing.T) {
|
||||
m := NewAgentTaskManager()
|
||||
conv := "conv-mcp-exec"
|
||||
_, err := m.StartTask(conv, "test", func(error) {})
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
m.RegisterRunningTool(conv, "exec-123")
|
||||
if got := m.ConversationIDForActiveMCPExecution("exec-123"); got != conv {
|
||||
t.Fatalf("got %q, want %q", got, conv)
|
||||
}
|
||||
if got := m.ConversationIDForActiveMCPExecution("missing"); got != "" {
|
||||
t.Fatalf("missing should be empty, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
func TestCancelTaskInvokesToolCancelerOnFullStop(t *testing.T) {
|
||||
tm := NewAgentTaskManager()
|
||||
called := false
|
||||
tm.SetToolCanceler(func(conversationID string) {
|
||||
if conversationID == "conv-1" {
|
||||
called = true
|
||||
}
|
||||
})
|
||||
|
||||
_, cancel := context.WithCancelCause(context.Background())
|
||||
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
|
||||
ok, err := tm.CancelTask("conv-1", ErrTaskCancelled)
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("CancelTask: ok=%v err=%v", ok, err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected tool canceler to be invoked on full task cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelTaskSkipsToolCancelerOnInterruptContinue(t *testing.T) {
|
||||
tm := NewAgentTaskManager()
|
||||
called := false
|
||||
tm.SetToolCanceler(func(conversationID string) {
|
||||
called = true
|
||||
})
|
||||
|
||||
_, cancel := context.WithCancelCause(context.Background())
|
||||
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
|
||||
ok, err := tm.CancelTask("conv-1", multiagent.ErrInterruptContinue)
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("CancelTask: ok=%v err=%v", ok, err)
|
||||
}
|
||||
if called {
|
||||
t.Fatal("tool canceler must not run for interrupt-continue")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelTaskDefaultCauseIsTaskCancelled(t *testing.T) {
|
||||
tm := NewAgentTaskManager()
|
||||
var gotCause error
|
||||
tm.SetToolCanceler(func(conversationID string) {
|
||||
if conversationID == "conv-2" {
|
||||
gotCause = ErrTaskCancelled
|
||||
}
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
if _, err := tm.StartTask("conv-2", "hello", cancel); err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
|
||||
if _, err := tm.CancelTask("conv-2", nil); err != nil {
|
||||
t.Fatalf("CancelTask: %v", err)
|
||||
}
|
||||
if !errors.Is(context.Cause(ctx), ErrTaskCancelled) {
|
||||
t.Fatalf("expected ErrTaskCancelled cause, got %v", context.Cause(ctx))
|
||||
}
|
||||
if gotCause != ErrTaskCancelled {
|
||||
t.Fatalf("expected tool canceler path for default cancel cause")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build windows
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RunCommandWS 交互式 PTY 终端依赖 Unix PTY(见 terminal_ws_unix.go);Windows 暂不支持。
|
||||
func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"error": "Interactive WebSocket terminal is not supported on Windows; use POST /terminal/run or /terminal/run/stream instead.",
|
||||
})
|
||||
}
|
||||
@@ -814,6 +814,23 @@ func (m *ExternalMCPManager) CancelToolExecution(id string) bool {
|
||||
return m.CancelToolExecutionWithNote(id, "")
|
||||
}
|
||||
|
||||
// ActiveRunningExecutionIDs 返回当前进程内仍登记 cancel 的外部 MCP executionId 快照。
|
||||
func (m *ExternalMCPManager) ActiveRunningExecutionIDs() map[string]struct{} {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if len(m.runningCancels) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]struct{}, len(m.runningCancels))
|
||||
for id := range m.runningCancels {
|
||||
out[id] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// updateStats 更新统计信息
|
||||
func (m *ExternalMCPManager) updateStats(toolName string, failed bool) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -11,7 +11,16 @@ type ToolRunRegistry interface {
|
||||
UnregisterRunningTool(conversationID, executionID string)
|
||||
}
|
||||
|
||||
// EinoExecuteRunRegistry tracks the Eino filesystem execute that is currently
// running for a conversation, so that "interrupt and continue" can terminate
// long-running commands (the original note mentions amass as an example).
type EinoExecuteRunRegistry interface {
	// RegisterActiveEinoExecute records the cancel function of the execute
	// now running for the conversation.
	RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc)

	// UnregisterActiveEinoExecute removes the registration for the
	// conversation.
	UnregisterActiveEinoExecute(conversationID string)

	// AbortActiveEinoExecute cancels the registered execute and stores note;
	// it reports whether an execute was registered.
	AbortActiveEinoExecute(conversationID, note string) bool

	// TakeEinoExecuteAbortNote returns the stored abort note and clears it.
	TakeEinoExecuteAbortNote(conversationID string) string
}
|
||||
|
||||
type toolRunRegistryCtxKey struct{}
|
||||
type einoExecuteRunRegistryCtxKey struct{}
|
||||
type mcpConversationIDCtxKey struct{}
|
||||
|
||||
// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。
|
||||
@@ -31,6 +40,23 @@ func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
|
||||
return v
|
||||
}
|
||||
|
||||
// WithEinoExecuteRunRegistry 将 Eino execute 取消登记器注入 ctx。
|
||||
func WithEinoExecuteRunRegistry(ctx context.Context, reg EinoExecuteRunRegistry) context.Context {
|
||||
if ctx == nil || reg == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, einoExecuteRunRegistryCtxKey{}, reg)
|
||||
}
|
||||
|
||||
// EinoExecuteRunRegistryFromContext 取出 Eino execute 登记器(无则 nil)。
|
||||
func EinoExecuteRunRegistryFromContext(ctx context.Context) EinoExecuteRunRegistry {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
v, _ := ctx.Value(einoExecuteRunRegistryCtxKey{}).(EinoExecuteRunRegistry)
|
||||
return v
|
||||
}
|
||||
|
||||
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
|
||||
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
|
||||
if ctx == nil {
|
||||
|
||||
+100
-16
@@ -921,9 +921,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
||||
return finalResult, executionID, nil
|
||||
}
|
||||
|
||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
// BeginToolExecution 创建 running 状态的执行记录,供 Eino 等非 CallTool 路径在工具开始时落库。
|
||||
func (s *Server) BeginToolExecution(toolName string, args map[string]interface{}) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
@@ -931,21 +930,73 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
executionID := uuid.New().String()
|
||||
now := time.Now()
|
||||
failed := invokeErr != nil
|
||||
exec := &ToolExecution{
|
||||
execution := &ToolExecution{
|
||||
ID: executionID,
|
||||
ToolName: toolName,
|
||||
Arguments: args,
|
||||
StartTime: now,
|
||||
EndTime: &now,
|
||||
Duration: 0,
|
||||
Status: "running",
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.executions[executionID] = execution
|
||||
s.cleanupOldExecutions()
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.storage != nil {
|
||||
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
return executionID
|
||||
}
|
||||
|
||||
// FinishToolExecution 完成先前 BeginToolExecution 创建的记录;executionID 为空时等同 RecordCompletedToolInvocation。
|
||||
func (s *Server) FinishToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
id := strings.TrimSpace(executionID)
|
||||
if id == "" {
|
||||
return s.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
failed := invokeErr != nil
|
||||
var finalResult *ToolResult
|
||||
|
||||
s.mu.Lock()
|
||||
exec, inMem := s.executions[id]
|
||||
if !inMem || exec == nil {
|
||||
exec = &ToolExecution{
|
||||
ID: id,
|
||||
ToolName: toolName,
|
||||
Arguments: args,
|
||||
StartTime: now,
|
||||
}
|
||||
s.executions[id] = exec
|
||||
} else if toolName != "" {
|
||||
exec.ToolName = toolName
|
||||
}
|
||||
if len(args) > 0 {
|
||||
exec.Arguments = args
|
||||
}
|
||||
exec.EndTime = &now
|
||||
if exec.StartTime.IsZero() {
|
||||
exec.StartTime = now
|
||||
}
|
||||
exec.Duration = now.Sub(exec.StartTime)
|
||||
|
||||
if failed {
|
||||
exec.Status = "failed"
|
||||
exec.Error = invokeErr.Error()
|
||||
st, msg := executionStatusAndMessage(invokeErr)
|
||||
exec.Status = st
|
||||
exec.Error = msg
|
||||
if strings.TrimSpace(resultText) != "" {
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||
exec.Result = finalResult
|
||||
}
|
||||
} else {
|
||||
exec.Status = "completed"
|
||||
@@ -953,15 +1004,31 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
||||
if strings.TrimSpace(text) == "" {
|
||||
text = "(无输出)"
|
||||
}
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||
exec.Result = finalResult
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.storage != nil {
|
||||
if err := s.storage.SaveToolExecution(exec); err != nil {
|
||||
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err))
|
||||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
s.updateStats(toolName, failed)
|
||||
return executionID
|
||||
|
||||
s.updateStats(exec.ToolName, failed)
|
||||
|
||||
if s.storage != nil {
|
||||
s.mu.Lock()
|
||||
delete(s.executions, id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
return s.FinishToolExecution("", toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
||||
@@ -1103,6 +1170,23 @@ func (s *Server) CancelToolExecution(id string) bool {
|
||||
return s.CancelToolExecutionWithNote(id, "")
|
||||
}
|
||||
|
||||
// ActiveRunningExecutionIDs 返回当前进程内仍登记 cancel 的 executionId 快照。
|
||||
func (s *Server) ActiveRunningExecutionIDs() map[string]struct{} {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.runningCancelsMu.Lock()
|
||||
defer s.runningCancelsMu.Unlock()
|
||||
if len(s.runningCancels) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]struct{}, len(s.runningCancels))
|
||||
for id := range s.runningCancels {
|
||||
out[id] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// initDefaultPrompts 初始化默认提示词模板
|
||||
func (s *Server) initDefaultPrompts() {
|
||||
s.mu.Lock()
|
||||
|
||||
@@ -199,6 +199,8 @@ type ToolExecution struct {
|
||||
StartTime time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
// ConversationID 仅 API 展示用(进行中的 Agent 任务),不写入 tool_executions 表。
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
}
|
||||
|
||||
// ToolStats 工具统计信息
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// staleRunningMinAge is the minimum age passed to
// FinalizeStaleRunningToolExecutions by ReconcileStaleRunning.
const staleRunningMinAge = 45 * time.Second

// staleRunningReconcileGap is the ticker interval of
// StartStaleRunningReconcileLoop.
const staleRunningReconcileGap = 2 * time.Minute
|
||||
|
||||
// ExecutionReconciler 在启动或运行期将无对应协程的 running 执行记录收尾为 cancelled。
|
||||
type ExecutionReconciler struct {
|
||||
db *database.DB
|
||||
mcpServer *mcp.Server
|
||||
externalMgr *mcp.ExternalMCPManager
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewExecutionReconciler creates a reconciler for orphaned MCP tool executions.
|
||||
func NewExecutionReconciler(db *database.DB, mcpServer *mcp.Server, externalMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ExecutionReconciler {
|
||||
return &ExecutionReconciler{
|
||||
db: db,
|
||||
mcpServer: mcpServer,
|
||||
externalMgr: externalMgr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ReconcileOnStartup marks every persisted running row as cancelled (safe right after process start).
|
||||
func (r *ExecutionReconciler) ReconcileOnStartup() {
|
||||
if r == nil || r.db == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
n, err := r.db.CancelOrphanedRunningToolExecutions(now, "执行已中断(服务重启)")
|
||||
if err != nil {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("启动时清理孤儿 running 工具执行记录失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && r.logger != nil {
|
||||
r.logger.Info("启动时已收尾孤儿 running 工具执行记录", zap.Int64("count", n))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ExecutionReconciler) activeExecutionIDs() map[string]struct{} {
|
||||
ids := make(map[string]struct{})
|
||||
if r.mcpServer != nil {
|
||||
for id := range r.mcpServer.ActiveRunningExecutionIDs() {
|
||||
ids[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
if r.externalMgr != nil {
|
||||
for id := range r.externalMgr.ActiveRunningExecutionIDs() {
|
||||
ids[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// ReconcileStaleRunning finalizes running rows that are not tracked in-memory and older than staleRunningMinAge.
|
||||
func (r *ExecutionReconciler) ReconcileStaleRunning() {
|
||||
if r == nil || r.db == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
n, err := r.db.FinalizeStaleRunningToolExecutions(now, staleRunningMinAge, r.activeExecutionIDs(), "执行已中断(会话已结束)")
|
||||
if err != nil {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("定期收尾 stale running 工具执行记录失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && r.logger != nil {
|
||||
r.logger.Info("已收尾 stale running 工具执行记录", zap.Int64("count", n))
|
||||
}
|
||||
}
|
||||
|
||||
// StartStaleRunningReconcileLoop periodically reconciles orphaned running tool executions.
|
||||
func StartStaleRunningReconcileLoop(r *ExecutionReconciler, logger *zap.Logger) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(staleRunningReconcileGap)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
r.ReconcileStaleRunning()
|
||||
if logger != nil {
|
||||
logger.Debug("monitor stale running reconcile tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestExecutionReconciler_ReconcileOnStartup(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := db.SaveToolExecution(&mcp.ToolExecution{
|
||||
ID: "run-1", ToolName: "hydra", Status: "running", StartTime: time.Now().Add(-time.Hour),
|
||||
}); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
r := NewExecutionReconciler(db, mcp.NewServer(zap.NewNop()), nil, zap.NewNop())
|
||||
r.ReconcileOnStartup()
|
||||
|
||||
got, err := db.GetToolExecution("run-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetToolExecution: %v", err)
|
||||
}
|
||||
if got.Status != "cancelled" {
|
||||
t.Fatalf("expected cancelled after startup reconcile, got %s", got.Status)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// retentionPurgeInterval is the ticker interval of StartRetentionLoop.
const (
	retentionPurgeInterval = time.Hour
)
|
||||
|
||||
// Service manages MCP tool execution monitor retention.
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewService creates a monitor retention service.
|
||||
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||
return &Service{db: db, cfg: cfg, logger: logger}
|
||||
}
|
||||
|
||||
// RetentionDays returns configured retention; 0 means keep forever.
|
||||
func (s *Service) RetentionDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
return s.cfg.Monitor.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
// PurgeExpired deletes tool execution rows older than retention_days when configured.
|
||||
func (s *Service) PurgeExpired() {
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
days := s.cfg.Monitor.RetentionDaysEffective()
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
n, err := s.db.PurgeToolExecutionsBefore(cutoff)
|
||||
if err != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("清理过期 MCP 执行记录失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && s.logger != nil {
|
||||
s.logger.Info("已清理过期 MCP 执行记录", zap.Int64("deleted", n), zap.Int("retention_days", days))
|
||||
}
|
||||
}
|
||||
|
||||
// StartRetentionLoop periodically purges expired tool execution rows.
|
||||
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(retentionPurgeInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.PurgeExpired()
|
||||
if logger != nil {
|
||||
logger.Debug("monitor retention tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
zero := 0
|
||||
svc := NewService(db, &config.Config{
|
||||
Monitor: config.MonitorConfig{RetentionDays: &zero},
|
||||
}, zap.NewNop())
|
||||
svc.PurgeExpired()
|
||||
|
||||
if _, err := db.GetToolExecution("ancient"); err != nil {
|
||||
t.Fatalf("record should remain when retention_days=0: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServicePurgeExpired_deletesOldRows(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
days := 90
|
||||
svc := NewService(db, &config.Config{
|
||||
Monitor: config.MonitorConfig{RetentionDays: &days},
|
||||
}, zap.NewNop())
|
||||
svc.PurgeExpired()
|
||||
|
||||
if _, err := db.GetToolExecution("ancient"); err == nil {
|
||||
t.Fatal("record should be purged when older than retention_days")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetentionDaysEffective_defaults(t *testing.T) {
|
||||
got := config.MonitorConfig{}.RetentionDaysEffective()
|
||||
if got != 90 {
|
||||
t.Fatalf("default = %d, want 90", got)
|
||||
}
|
||||
zero := 0
|
||||
cfg := config.MonitorConfig{RetentionDays: &zero}
|
||||
if cfg.RetentionDaysEffective() != 0 {
|
||||
t.Fatalf("zero = %d, want 0", cfg.RetentionDaysEffective())
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseTime(t *testing.T, value string) time.Time {
|
||||
t.Helper()
|
||||
parsed, err := time.Parse(time.RFC3339, value)
|
||||
if err != nil {
|
||||
t.Fatalf("parse time: %v", err)
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
)
|
||||
|
||||
// InitADK configures global Eino ADK settings. Call once at process startup before
|
||||
// any ADK middleware or agents are created.
|
||||
func InitADK() error {
|
||||
if err := adk.SetLanguage(adk.LanguageChinese); err != nil {
|
||||
return fmt.Errorf("adk set language: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// continuationSessionMarker matches Cursor / IDE session-resume user injections.
|
||||
const continuationSessionMarker = "This session is being continued from a previous conversation"
|
||||
|
||||
// continuationUserDedupMiddleware keeps only the latest session-resume user message when
|
||||
// multiple continuation injections were stacked (e.g. after repeated out-of-context resumes).
|
||||
type continuationUserDedupMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
func newContinuationUserDedupMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &continuationUserDedupMiddleware{logger: logger, phase: phase}
|
||||
}
|
||||
|
||||
func (m *continuationUserDedupMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
deduped, dropped := dedupContinuationUserMessages(state.Messages)
|
||||
if dropped == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
if m.logger != nil {
|
||||
m.logger.Info("eino continuation user messages deduplicated",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("dropped", dropped),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(deduped)),
|
||||
)
|
||||
}
|
||||
out := *state
|
||||
out.Messages = deduped
|
||||
return ctx, &out, nil
|
||||
}
|
||||
|
||||
func adkUserMessageText(msg adk.Message) string {
|
||||
if msg == nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
if s := strings.TrimSpace(msg.Content); s != "" {
|
||||
b.WriteString(s)
|
||||
}
|
||||
for _, part := range msg.UserInputMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText {
|
||||
if s := strings.TrimSpace(part.Text); s != "" {
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isContinuationUserMessage(msg adk.Message) bool {
|
||||
if msg == nil || msg.Role != schema.User {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(adkUserMessageText(msg), continuationSessionMarker)
|
||||
}
|
||||
|
||||
func dedupContinuationUserMessages(msgs []adk.Message) ([]adk.Message, int) {
|
||||
lastIdx := -1
|
||||
contCount := 0
|
||||
for i, msg := range msgs {
|
||||
if !isContinuationUserMessage(msg) {
|
||||
continue
|
||||
}
|
||||
contCount++
|
||||
lastIdx = i
|
||||
}
|
||||
if contCount <= 1 {
|
||||
return msgs, 0
|
||||
}
|
||||
out := make([]adk.Message, 0, len(msgs)-(contCount-1))
|
||||
dropped := 0
|
||||
for i, msg := range msgs {
|
||||
if isContinuationUserMessage(msg) && i != lastIdx {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out, dropped
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func continuationUser(text string) adk.Message {
|
||||
return &schema.Message{
|
||||
Role: schema.User,
|
||||
UserInputMultiContent: []schema.MessageInputPart{
|
||||
{Type: schema.ChatMessagePartTypeText, Text: continuationSessionMarker + "\n" + text},
|
||||
{Type: schema.ChatMessagePartTypeText, Text: "Please continue the conversation from where we left it off."},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupContinuationUserMessages_KeepsLatest(t *testing.T) {
|
||||
msgs := []adk.Message{
|
||||
continuationUser("summary old"),
|
||||
schema.UserMessage("real task"),
|
||||
continuationUser("summary new"),
|
||||
}
|
||||
out, dropped := dedupContinuationUserMessages(msgs)
|
||||
if dropped != 1 {
|
||||
t.Fatalf("dropped=%d want 1", dropped)
|
||||
}
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("len=%d want 2", len(out))
|
||||
}
|
||||
if out[0].Role != schema.User || adkUserMessageText(out[0]) != "real task" {
|
||||
t.Fatalf("first should remain real task, got %q", adkUserMessageText(out[0]))
|
||||
}
|
||||
if !strings.Contains(adkUserMessageText(out[1]), "summary new") {
|
||||
t.Fatalf("latest continuation not kept: %q", adkUserMessageText(out[1]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupContinuationUserMessages_NoOpSingle(t *testing.T) {
|
||||
msgs := []adk.Message{continuationUser("only"), schema.UserMessage("task")}
|
||||
out, dropped := dedupContinuationUserMessages(msgs)
|
||||
if dropped != 0 || len(out) != 2 {
|
||||
t.Fatalf("unexpected change dropped=%d len=%d", dropped, len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinuationUserDedupMiddleware(t *testing.T) {
|
||||
mw := newContinuationUserDedupMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
continuationUser("old"),
|
||||
continuationUser("new"),
|
||||
schema.UserMessage("task"),
|
||||
}}
|
||||
_, out, err := mw.(*continuationUserDedupMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out.Messages) != 2 {
|
||||
t.Fatalf("want 2 messages after dedup, got %d", len(out.Messages))
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@@ -90,7 +91,7 @@ type einoADKRunLoopArgs struct {
|
||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||
MCPExecutionBinder *MCPExecutionBinder
|
||||
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,execute/MCP 桥 Fire 时立即推送 tool_result(ADK 晚到经 toolResultSent 去重)。
|
||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||
|
||||
DA adk.Agent
|
||||
@@ -196,6 +197,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
pendingByID[tc.ToolCallID] = tc
|
||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||
}
|
||||
markPendingWithMonitor := func(tc toolCallPendingInfo) {
|
||||
markPending(tc)
|
||||
beginEinoADKFilesystemToolMonitor(
|
||||
args.FilesystemMonitorAgent,
|
||||
args.FilesystemMonitorRecord,
|
||||
args.MCPExecutionBinder,
|
||||
tc.ToolCallID,
|
||||
tc.ToolName,
|
||||
)
|
||||
}
|
||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||
pendingMu.Lock()
|
||||
defer pendingMu.Unlock()
|
||||
@@ -288,6 +299,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
|
||||
var toolResultSent sync.Map // toolCallID -> struct{};ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
|
||||
tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
|
||||
// 仅由 ADK schema.Tool 事件调用;MCP/execute 桥在 reduction 前的 ToolInvokeNotify 不得推送 tool_result,
|
||||
// 否则全量输出会先占位并触发 toolResultSent 去重,导致 UI/监控展示与 agent 实际收到的截断正文不一致。
|
||||
if progress == nil {
|
||||
return
|
||||
}
|
||||
@@ -305,6 +318,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"isError": isErr,
|
||||
"result": content,
|
||||
"resultPreview": preview,
|
||||
"agentFacing": true, // 与 reduction 后送入 ChatModel 的正文一致,供前端展示
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": agentName,
|
||||
"einoRole": einoRoleTag(agentName),
|
||||
@@ -331,7 +345,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
toolCallID = tid
|
||||
}
|
||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, args.MCPExecutionBinder, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
|
||||
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
|
||||
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
|
||||
@@ -339,12 +353,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||
}
|
||||
if args.ToolInvokeNotify != nil {
|
||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
removePendingByID(strings.TrimSpace(toolCallID))
|
||||
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
|
||||
})
|
||||
}
|
||||
|
||||
if args.EinoCallbacks != nil {
|
||||
ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{
|
||||
@@ -383,6 +391,12 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
runner := adk.NewRunner(ctx, runnerCfg)
|
||||
startRunnerIter := func(runMsgs []adk.Message) *adk.AsyncIterator[*adk.AgentEvent] {
|
||||
if checkPointID != "" {
|
||||
return runner.Run(ctx, runMsgs, adk.WithCheckPointID(checkPointID))
|
||||
}
|
||||
return runner.Run(ctx, runMsgs)
|
||||
}
|
||||
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
||||
if cpStore != nil && checkPointID != "" {
|
||||
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
||||
@@ -422,12 +436,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
if iter == nil {
|
||||
if checkPointID != "" {
|
||||
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
|
||||
} else {
|
||||
iter = runner.Run(ctx, msgs)
|
||||
}
|
||||
iter = startRunnerIter(msgs)
|
||||
}
|
||||
transientRetrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicyFromArgs(args))
|
||||
handleRunErr := func(runErr error) error {
|
||||
if runErr == nil {
|
||||
return nil
|
||||
@@ -480,26 +491,67 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
return runErr
|
||||
}
|
||||
|
||||
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
||||
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
maybeRetryTransientRun := func(runErr error) (restarted bool, fatal error) {
|
||||
if runErr == nil {
|
||||
return false, nil
|
||||
}
|
||||
if !isEinoTransientRunError(runErr) {
|
||||
return false, handleRunErr(runErr)
|
||||
}
|
||||
restarted, restartMsgs, ctxSource, backoff, retErr := transientRetrier.tryRetry(
|
||||
ctx, runErr, args, baseMsgs, runAccumulatedMsgs, baseAccumulatedCount,
|
||||
)
|
||||
if retErr != nil {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient retry exhausted",
|
||||
zap.Error(retErr),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("maxAttempts", transientRetrier.maxAttempts()))
|
||||
}
|
||||
return false, retErr
|
||||
}
|
||||
if !restarted {
|
||||
return false, nil
|
||||
}
|
||||
attemptNo := transientRetrier.attempt()
|
||||
maxAttempts := transientRetrier.maxAttempts()
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient error, ending run segment for handler resume",
|
||||
logger.Warn("eino transient error, retrying after backoff",
|
||||
zap.Error(runErr),
|
||||
zap.String("orchestration", orchMode))
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("attempt", attemptNo),
|
||||
zap.Int("maxAttempts", maxAttempts),
|
||||
zap.Duration("backoff", backoff))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
||||
progress("eino_run_retry", fmt.Sprintf("遇到临时错误(限流或网络波动),%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"error": runErr.Error(),
|
||||
"resumeKind": "trace_segment",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
progress("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"attempt": attemptNo,
|
||||
"contextSource": string(ctxSource),
|
||||
})
|
||||
}
|
||||
return false, ErrTransientRetryContinue
|
||||
msgs = restartMsgs
|
||||
iter = startRunnerIter(msgs)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 仅在退避重试后真正收到数据/完成一步时清零,避免重启后首个无错 ADK 事件误把计数打回 0。
|
||||
confirmTransientRetryRecovery := func() {
|
||||
if transientRetrier.attempt() > 0 {
|
||||
transientRetrier.reset()
|
||||
}
|
||||
}
|
||||
|
||||
takePartial := func(runErr error) (*RunResult, error) {
|
||||
@@ -514,10 +566,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
|
||||
for {
|
||||
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
flushAllPendingAsFailed(ctx.Err())
|
||||
// iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending。
|
||||
ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
|
||||
if iterCtxErr != nil {
|
||||
flushAllPendingAsFailed(iterCtxErr)
|
||||
if progress != nil {
|
||||
if isInterruptContinue(ctx) {
|
||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||
@@ -526,17 +578,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"kind": "interrupt_continue",
|
||||
})
|
||||
} else {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
progress("error", iterCtxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
return takePartial(ctx.Err())
|
||||
default:
|
||||
return takePartial(iterCtxErr)
|
||||
}
|
||||
|
||||
ev, ok := iter.Next()
|
||||
if !ok {
|
||||
// iter 结束并不总是“正常完成”:
|
||||
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
||||
@@ -583,9 +632,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
||||
restarted, retErr := maybeRetryTransientRun(ev.Err)
|
||||
if retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
if restarted {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if ev.AgentName != "" && progress != nil {
|
||||
iterEinoAgent := orchestratorName
|
||||
@@ -648,34 +701,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
|
||||
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
||||
toolName := strings.TrimSpace(mv.ToolName)
|
||||
var toolBuf strings.Builder
|
||||
streamToolCallID := ""
|
||||
var toolStreamRecvErr error
|
||||
for {
|
||||
chunk, rerr := mv.MessageStream.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
toolStreamRecvErr = rerr
|
||||
break
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
toolBuf.WriteString(chunk.Content)
|
||||
}
|
||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||
streamToolCallID = tid
|
||||
}
|
||||
}
|
||||
content := toolBuf.String()
|
||||
isErr := false
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||
}
|
||||
content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
|
||||
isErr := einoToolResultIsError(toolName, content)
|
||||
content = einoToolResultBody(content)
|
||||
if streamToolCallID != "" {
|
||||
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
|
||||
@@ -687,6 +715,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
zap.String("agent", ev.AgentName),
|
||||
zap.String("tool", toolName))
|
||||
}
|
||||
if toolStreamRecvErr == nil {
|
||||
confirmTransientRetryRecovery()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -934,7 +965,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
||||
}
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
|
||||
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
||||
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
||||
@@ -951,9 +982,15 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
})
|
||||
}
|
||||
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
||||
restarted, retErr := maybeRetryTransientRun(streamRecvErr)
|
||||
if retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
if restarted {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
confirmTransientRetryRecovery()
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -963,7 +1000,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
continue
|
||||
}
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
|
||||
|
||||
if mv.Role == schema.Assistant {
|
||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||
@@ -1038,15 +1075,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
isErr := false
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||
}
|
||||
isErr := einoToolResultIsError(toolName, content)
|
||||
content = einoToolResultBody(content)
|
||||
|
||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
|
||||
}
|
||||
confirmTransientRetryRecovery()
|
||||
}
|
||||
|
||||
mcpIDsMu.Lock()
|
||||
@@ -1057,32 +1092,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
)
|
||||
if shouldEinoEmptyResponseContinue(out, emptyHint, len(runAccumulatedMsgs), baseAccumulatedCount) {
|
||||
if logger != nil {
|
||||
logger.Info("eino empty response, ending run segment for handler resume",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("traceMessages", len(runAccumulatedMsgs)))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_empty_response_continue", "会话已结束但未产生助手正文,正在基于轨迹自动续跑…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
return out, ErrEmptyResponseContinue
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func shouldEinoEmptyResponseContinue(out *RunResult, emptyHint string, accumulatedLen, baseCount int) bool {
|
||||
if out == nil || accumulatedLen <= baseCount {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(out.Response) == strings.TrimSpace(emptyHint)
|
||||
}
|
||||
|
||||
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
||||
if args != nil && args.ModelFacingTrace != nil {
|
||||
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
||||
@@ -1097,17 +1109,119 @@ func einoPartialRunLastOutputHint() string {
|
||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||
}
|
||||
|
||||
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本。
|
||||
// friendlyEinoExecuteInvokeTail 将 Eino execute 超时/中断/流异常转为简短提示。
|
||||
// 命令非零退出(ExecuteExitError)已有 exec 对齐的正文,不再追加「执行未正常结束」。
|
||||
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||
if invokeErr == nil {
|
||||
return ""
|
||||
}
|
||||
var exitErr *ExecuteExitError
|
||||
if errors.As(invokeErr, &exitErr) {
|
||||
return ""
|
||||
}
|
||||
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
return einoExecuteTimeoutUserHint()
|
||||
}
|
||||
if errors.Is(invokeErr, context.Canceled) {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(invokeErr.Error(), "shell inactivity timeout") {
|
||||
return ""
|
||||
}
|
||||
return "[执行未正常结束] " + invokeErr.Error()
|
||||
}
|
||||
|
||||
// einoToolResultIsError 统一判断 Eino 工具结果是否应标记为错误(与 MCP exec 的 IsError 对齐)。
|
||||
func einoToolResultIsError(toolName, content string) bool {
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
return true
|
||||
}
|
||||
if strings.TrimSpace(toolName) == "execute" && security.IsCommandFailureResult(content) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// einoToolResultBody 去掉工具错误前缀,返回展示/持久化正文。
|
||||
func einoToolResultBody(content string) string {
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
return strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
// nextAgentEventWithContext 在 ctx 取消时不再无限阻塞于 iter.Next()(工具执行/模型推理期间常见)。
|
||||
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
|
||||
if iter == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
type nextRes struct {
|
||||
ev *adk.AgentEvent
|
||||
ok bool
|
||||
}
|
||||
ch := make(chan nextRes, 1)
|
||||
go func() {
|
||||
e, o := iter.Next()
|
||||
ch <- nextRes{e, o}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false, ctx.Err()
|
||||
case res := <-ch:
|
||||
return res.ev, res.ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
// recvSchemaMessageStream 消费 ADK Tool 流式结果;ctx 取消时立即返回,避免 amass 等无输出时永久阻塞。
|
||||
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
|
||||
if stream == nil {
|
||||
return "", "", nil
|
||||
}
|
||||
type streamMsg struct {
|
||||
chunk *schema.Message
|
||||
err error
|
||||
}
|
||||
recvCh := make(chan streamMsg, 8)
|
||||
go func() {
|
||||
defer close(recvCh)
|
||||
for {
|
||||
ch, rerr := stream.Recv()
|
||||
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||
if rerr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
var buf strings.Builder
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return buf.String(), toolCallID, ctx.Err()
|
||||
case sm, open := <-recvCh:
|
||||
if !open {
|
||||
return buf.String(), toolCallID, nil
|
||||
}
|
||||
rerr := sm.err
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return buf.String(), toolCallID, nil
|
||||
}
|
||||
if rerr != nil {
|
||||
return buf.String(), toolCallID, rerr
|
||||
}
|
||||
chunk := sm.chunk
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
buf.WriteString(chunk.Content)
|
||||
}
|
||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||
toolCallID = tid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildEinoRunResultFromAccumulated(
|
||||
orchMode string,
|
||||
runAccumulatedMsgs []adk.Message,
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
_ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil)
|
||||
sw.Close()
|
||||
|
||||
content, tid, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if content != "hello" {
|
||||
t.Fatalf("content=%q want hello", content)
|
||||
}
|
||||
if tid != "tc-1" {
|
||||
t.Fatalf("toolCallID=%q want tc-1", tid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
t.Cleanup(func() { sw.Close() })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
content, _, err := recvSchemaMessageStream(ctx, sr)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("want context.Canceled, got %v content=%q", err, content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
want := errors.New("stream broken")
|
||||
_ = sw.Send(nil, want)
|
||||
sw.Close()
|
||||
|
||||
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if !errors.Is(err, want) {
|
||||
t.Fatalf("want %v, got %v", want, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
|
||||
content, tid, err := recvSchemaMessageStream(context.Background(), nil)
|
||||
if err != nil || content != "" || tid != "" {
|
||||
t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
_ = sw.Send(nil, io.EOF)
|
||||
sw.Close()
|
||||
|
||||
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if err != nil {
|
||||
t.Fatalf("EOF should not surface as error, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// einoChatModelTailConfig configures middleware appended after reduction/skill/plantask
|
||||
// and immediately before each ChatModel invocation pipeline completes.
|
||||
//
|
||||
// Order (best practice):
|
||||
// 1. system merge — accurate token count for summarization
|
||||
// 2. continuation user dedup — drop stale session-resume injections
|
||||
// 3. summarization
|
||||
// 4. orphan tool prune
|
||||
// 5. telemetry
|
||||
// 6. model-facing trace snapshot
|
||||
type einoChatModelTailConfig struct {
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
summarization adk.ChatModelAgentMiddleware
|
||||
modelName string
|
||||
conversationID string
|
||||
trace *modelFacingTraceHolder
|
||||
skipOrphanPruner bool
|
||||
skipTelemetry bool
|
||||
skipTrace bool
|
||||
}
|
||||
|
||||
// appendEinoChatModelTailMiddlewares appends the tail middlewares described by
// cfg to handlers and returns the extended slice.
//
// The append order is fixed: system message normalizer, continuation user
// dedup, summarization (if cfg.summarization is non-nil), orphan tool pruner,
// telemetry, and model-facing trace. The last three can be disabled through
// the cfg.skip* flags; telemetry and trace are also skipped when their
// constructor returns nil.
func appendEinoChatModelTailMiddlewares(handlers []adk.ChatModelAgentMiddleware, cfg einoChatModelTailConfig) []adk.ChatModelAgentMiddleware {
|
||||
// The first two middlewares are unconditional.
handlers = append(handlers, newSystemMessageNormalizerMiddleware(cfg.logger, cfg.phase))
|
||||
handlers = append(handlers, newContinuationUserDedupMiddleware(cfg.logger, cfg.phase))
|
||||
if cfg.summarization != nil {
|
||||
handlers = append(handlers, cfg.summarization)
|
||||
}
|
||||
if !cfg.skipOrphanPruner {
|
||||
handlers = append(handlers, newOrphanToolPrunerMiddleware(cfg.logger, cfg.phase))
|
||||
}
|
||||
if !cfg.skipTelemetry {
|
||||
// The constructor may return nil, in which case nothing is appended.
if teleMw := newEinoModelInputTelemetryMiddleware(cfg.logger, cfg.modelName, cfg.conversationID, cfg.phase); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
}
|
||||
// Trace needs both the flag cleared and a non-nil holder.
if !cfg.skipTrace && cfg.trace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(cfg.trace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
}
|
||||
return handlers
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package multiagent
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestShouldEinoEmptyResponseContinue checks four cases of
// shouldEinoEmptyResponseContinue: it must return true only when the result's
// Response equals the hint and the integer arguments differ (3 vs 1), and
// false for equal integers, for a non-hint response, and for a nil result.
// NOTE(review): the two integers are presumably trace lengths after/before the
// run (the failure messages speak of the trace growing) — confirm against the
// function's definition.
func TestShouldEinoEmptyResponseContinue(t *testing.T) {
|
||||
t.Parallel()
|
||||
hint := "(empty hint)"
|
||||
out := &RunResult{Response: hint}
|
||||
if !shouldEinoEmptyResponseContinue(out, hint, 3, 1) {
|
||||
t.Fatal("expected continue when response is empty hint and trace grew")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(out, hint, 1, 1) {
|
||||
t.Fatal("expected no continue when trace did not grow")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(&RunResult{Response: "hello"}, hint, 3, 1) {
|
||||
t.Fatal("expected no continue when response has content")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(nil, hint, 3, 1) {
|
||||
t.Fatal("expected no continue for nil result")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// mockStreamingShellExitFail is a test double whose ExecuteStreaming emits a
// fixed output chunk followed by a fixed exit code.
type mockStreamingShellExitFail struct {
|
||||
// output is sent as a single chunk; an empty string sends no output chunk.
output string
|
||||
// code is reported as the ExitCode of the final response.
code int
|
||||
}
|
||||
|
||||
// ExecuteStreaming returns a pipe reader fed by a goroutine that sends
// m.output (when non-empty), then a response carrying only the exit code, and
// finally closes the writer. ctx and input are not consulted. The returned
// error is always nil.
func (m *mockStreamingShellExitFail) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||
go func() {
|
||||
defer outW.Close()
|
||||
if m.output != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
|
||||
}
|
||||
// Copy to a local so the response points at its own int, not at the mock's field.
code := m.code
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{ExitCode: &code}, nil)
|
||||
}()
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
// TestEinoStreamingShellWrap_CommandFailureFormat runs einoStreamingShellWrap
// over a mock shell that prints one line and exits with code 1, then checks
// both the notify callback arguments and the text forwarded on the stream.
//
// Notify side: success must be false, the error must be an *ExecuteExitError
// with Code 1, and the content must be einomcp.ToolErrorPrefix followed by
// security.FormatCommandFailureResult(1, <output>).
// Stream side: the command output must be present, two legacy/abnormal
// phrases must be absent, and security.IsCommandFailureResult must accept it.
func TestEinoStreamingShellWrap_CommandFailureFormat(t *testing.T) {
|
||||
inner := &mockStreamingShellExitFail{
|
||||
output: "sudo: a password is required\n",
|
||||
code: 1,
|
||||
}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
var firedBody string
|
||||
var firedSuccess bool
|
||||
var firedErr error
|
||||
// Capture what the wrapper reports through the notify holder.
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
firedBody = content
|
||||
firedSuccess = success
|
||||
firedErr = invokeErr
|
||||
})
|
||||
wrap := &einoStreamingShellWrap{inner: inner, invokeNotify: notify}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
// Drain the stream to EOF, concatenating every Output chunk.
var stream strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
if resp != nil {
|
||||
stream.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// Assertions on the notify callback arguments.
if firedSuccess {
|
||||
t.Fatal("expected success=false")
|
||||
}
|
||||
var exitErr *ExecuteExitError
|
||||
if !errors.As(firedErr, &exitErr) || exitErr.Code != 1 {
|
||||
t.Fatalf("expected ExecuteExitError code 1, got %v", firedErr)
|
||||
}
|
||||
if !strings.HasPrefix(firedBody, einomcp.ToolErrorPrefix) {
|
||||
t.Fatalf("missing tool error prefix: %q", firedBody)
|
||||
}
|
||||
body := strings.TrimPrefix(firedBody, einomcp.ToolErrorPrefix)
|
||||
if body != security.FormatCommandFailureResult(1, "sudo: a password is required\n") {
|
||||
t.Fatalf("fire body = %q", body)
|
||||
}
|
||||
// Assertions on the text forwarded to the stream consumer.
if !strings.Contains(stream.String(), "sudo:") {
|
||||
t.Fatalf("stream missing sudo output: %q", stream.String())
|
||||
}
|
||||
if strings.Contains(stream.String(), "command exited with non-zero") {
|
||||
t.Fatalf("stream has legacy noise: %q", stream.String())
|
||||
}
|
||||
if strings.Contains(stream.String(), "执行未正常结束") {
|
||||
t.Fatalf("stream has abnormal tail: %q", stream.String())
|
||||
}
|
||||
if !security.IsCommandFailureResult(stream.String()) {
|
||||
t.Fatalf("stream missing failure status line: %q", stream.String())
|
||||
}
|
||||
// An exit-code error must not produce an invoke tail, yet must count as a tool error.
if tail := friendlyEinoExecuteInvokeTail(firedErr); tail != "" {
|
||||
t.Fatalf("unexpected invoke tail: %q", tail)
|
||||
}
|
||||
if !einoToolResultIsError("execute", firedBody) {
|
||||
t.Fatal("expected isError for execute failure")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFriendlyEinoExecuteInvokeTail checks three inputs of
// friendlyEinoExecuteInvokeTail: an *ExecuteExitError yields an empty tail,
// context.DeadlineExceeded yields a tail containing "Timed out", and any
// other error yields a non-empty tail.
func TestFriendlyEinoExecuteInvokeTail(t *testing.T) {
|
||||
if friendlyEinoExecuteInvokeTail(&ExecuteExitError{Code: 1}) != "" {
|
||||
t.Fatal("exit error should not get abnormal tail")
|
||||
}
|
||||
if !strings.Contains(friendlyEinoExecuteInvokeTail(context.DeadlineExceeded), "Timed out") {
|
||||
t.Fatal("deadline should get timeout hint")
|
||||
}
|
||||
if friendlyEinoExecuteInvokeTail(errors.New("broken pipe")) == "" {
|
||||
t.Fatal("unexpected error should get tail")
|
||||
}
|
||||
}
|
||||
@@ -7,11 +7,25 @@ import (
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
)
|
||||
|
||||
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
||||
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
||||
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||
return func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||
if ag == nil || recorder == nil {
|
||||
// newEinoExecuteMonitorCallbacks 在 Eino filesystem execute 开始/结束时写入 MCP 监控库并 recorder(executionId),
|
||||
// 与 CallTool 路径一致,使监控页能展示「执行中」状态。
|
||||
func newEinoExecuteMonitorCallbacks(ag *agent.Agent, recorder einomcp.ExecutionRecorder) (
|
||||
begin func(toolCallID, command string) string,
|
||||
finish func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
|
||||
) {
|
||||
begin = func(toolCallID, command string) string {
|
||||
if ag == nil {
|
||||
return ""
|
||||
}
|
||||
args := map[string]interface{}{"command": command}
|
||||
id := ag.BeginLocalToolExecution("execute", args)
|
||||
if id != "" && recorder != nil {
|
||||
recorder(id, toolCallID)
|
||||
}
|
||||
return id
|
||||
}
|
||||
finish = func(executionID, toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||
if ag == nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
@@ -23,9 +37,10 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
|
||||
}
|
||||
}
|
||||
args := map[string]interface{}{"command": command}
|
||||
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
||||
if id != "" {
|
||||
id := ag.FinishLocalToolExecution(executionID, "execute", args, stdout, err)
|
||||
if id != "" && recorder != nil && executionID == "" {
|
||||
recorder(id, toolCallID)
|
||||
}
|
||||
}
|
||||
return begin, finish
|
||||
}
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
@@ -49,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||
//
|
||||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
||||
// run loop 收到 Fire 后立即推送 tool_result(toolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」。
|
||||
//
|
||||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||
@@ -61,8 +63,11 @@ type einoStreamingShellWrap struct {
|
||||
outputChunk func(toolName, toolCallID, chunk string)
|
||||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||
toolTimeoutMinutes int
|
||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||||
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
|
||||
// shellNoOutputTimeoutSec:无任何输出时的空闲秒数;0=关闭。
|
||||
shellNoOutputTimeoutSec int
|
||||
// beginMonitor 在 execute 开始时写入 running 状态;finishMonitor 在流结束后更新为 completed/failed。
|
||||
beginMonitor func(toolCallID, command string) string
|
||||
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error)
|
||||
}
|
||||
|
||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
@@ -74,43 +79,65 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
}
|
||||
req := *input
|
||||
userCmd := strings.TrimSpace(req.Command)
|
||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||||
req.RunInBackendGround = true
|
||||
}
|
||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||
convID := mcp.MCPConversationIDFromContext(ctx)
|
||||
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
||||
|
||||
execCtx := ctx
|
||||
var execCancel context.CancelFunc
|
||||
var monitorExecID string
|
||||
if w.beginMonitor != nil {
|
||||
monitorExecID = w.beginMonitor(tid, userCmd)
|
||||
}
|
||||
if monitorExecID != "" && convID != "" {
|
||||
if toolReg := mcp.ToolRunRegistryFromContext(ctx); toolReg != nil {
|
||||
toolReg.RegisterRunningTool(convID, monitorExecID)
|
||||
}
|
||||
}
|
||||
toolRunReg := mcp.ToolRunRegistryFromContext(ctx)
|
||||
|
||||
execCtx, execCancel := context.WithCancel(ctx)
|
||||
var timeoutCancel context.CancelFunc
|
||||
if w.toolTimeoutMinutes > 0 {
|
||||
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
}
|
||||
if execReg != nil && convID != "" {
|
||||
execReg.RegisterActiveEinoExecute(convID, execCancel)
|
||||
}
|
||||
|
||||
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||
if err != nil {
|
||||
if timeoutCancel != nil {
|
||||
timeoutCancel()
|
||||
}
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
|
||||
if w.finishMonitor != nil {
|
||||
w.finishMonitor(monitorExecID, tid, userCmd, hint, false, context.DeadlineExceeded)
|
||||
}
|
||||
if w.invokeNotify != nil && tid != "" {
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
||||
}
|
||||
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(tid, userCmd, "", false, err)
|
||||
if w.finishMonitor != nil {
|
||||
w.finishMonitor(monitorExecID, tid, userCmd, "", false, err)
|
||||
}
|
||||
if w.invokeNotify != nil && tid != "" {
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if sr == nil || w.invokeNotify == nil {
|
||||
if sr == nil {
|
||||
if timeoutCancel != nil {
|
||||
timeoutCancel()
|
||||
}
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
@@ -119,11 +146,35 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
||||
defer inner.Close()
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry, toolReg mcp.ToolRunRegistry, execID string, toolCallID string, noOutputSec int) {
|
||||
var innerCloseOnce sync.Once
|
||||
closeInner := func() {
|
||||
innerCloseOnce.Do(func() { inner.Close() })
|
||||
}
|
||||
defer closeInner()
|
||||
if timeoutCleanup != nil {
|
||||
defer timeoutCleanup()
|
||||
}
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if reg != nil && conversationID != "" {
|
||||
defer reg.UnregisterActiveEinoExecute(conversationID)
|
||||
}
|
||||
if toolReg != nil && conversationID != "" && execID != "" {
|
||||
defer toolReg.UnregisterRunningTool(conversationID, execID)
|
||||
}
|
||||
|
||||
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
||||
stopWatch := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-tctx.Done():
|
||||
closeInner()
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
defer close(stopWatch)
|
||||
|
||||
var sb strings.Builder
|
||||
success := true
|
||||
@@ -131,46 +182,103 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
exitCode := 0
|
||||
hasExitCode := false
|
||||
|
||||
idleWatch := security.NewShellInactivityWatch(noOutputSec)
|
||||
if idleWatch != nil {
|
||||
defer idleWatch.Stop()
|
||||
}
|
||||
|
||||
type execRecvMsg struct {
|
||||
resp *filesystem.ExecuteResponse
|
||||
err error
|
||||
}
|
||||
recvCh := make(chan execRecvMsg, 1)
|
||||
go func() {
|
||||
for {
|
||||
resp, rerr := inner.Recv()
|
||||
recvCh <- execRecvMsg{resp: resp, err: rerr}
|
||||
if rerr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
fireInactivityTimeout := func() {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||
msg := security.ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: msg}, nil)
|
||||
sb.WriteString(msg)
|
||||
if w.outputChunk != nil && toolCallID != "" {
|
||||
w.outputChunk("execute", toolCallID, msg)
|
||||
}
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
closeInner()
|
||||
}
|
||||
|
||||
recvLoop:
|
||||
for {
|
||||
resp, rerr := inner.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
var idleCh <-chan struct{}
|
||||
if idleWatch != nil {
|
||||
idleCh = idleWatch.Expired
|
||||
}
|
||||
if rerr != nil {
|
||||
success = false
|
||||
invokeErr = rerr
|
||||
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
|
||||
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||||
invokeErr = context.DeadlineExceeded
|
||||
break
|
||||
select {
|
||||
case <-idleCh:
|
||||
fireInactivityTimeout()
|
||||
break recvLoop
|
||||
case msg := <-recvCh:
|
||||
rerr := msg.err
|
||||
resp := msg.resp
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break recvLoop
|
||||
}
|
||||
_ = outW.Send(nil, rerr)
|
||||
break
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.ExitCode != nil {
|
||||
hasExitCode = true
|
||||
exitCode = *resp.ExitCode
|
||||
}
|
||||
var appended string
|
||||
if resp.Output != "" {
|
||||
sb.WriteString(resp.Output)
|
||||
appended = resp.Output
|
||||
}
|
||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||
w.outputChunk("execute", tid, appended)
|
||||
}
|
||||
if outW.Send(resp, nil) {
|
||||
if rerr != nil {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||||
break
|
||||
invokeErr = rerr
|
||||
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||||
invokeErr = context.DeadlineExceeded
|
||||
break recvLoop
|
||||
}
|
||||
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
||||
invokeErr = context.Canceled
|
||||
break recvLoop
|
||||
}
|
||||
_ = outW.Send(nil, rerr)
|
||||
break recvLoop
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.ExitCode != nil {
|
||||
hasExitCode = true
|
||||
exitCode = *resp.ExitCode
|
||||
continue
|
||||
}
|
||||
var appended string
|
||||
if resp.Output != "" {
|
||||
if security.IsLegacyShellExitNoise(resp.Output) {
|
||||
continue
|
||||
}
|
||||
if idleWatch != nil {
|
||||
idleWatch.Bump()
|
||||
}
|
||||
sb.WriteString(resp.Output)
|
||||
appended = resp.Output
|
||||
}
|
||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||
w.outputChunk("execute", toolCallID, appended)
|
||||
}
|
||||
if outW.Send(resp, nil) {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||||
break recvLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if success && hasExitCode && exitCode != 0 {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
||||
invokeErr = &ExecuteExitError{Code: exitCode}
|
||||
}
|
||||
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||||
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||||
@@ -178,6 +286,21 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
success = false
|
||||
invokeErr = context.DeadlineExceeded
|
||||
}
|
||||
// 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。
|
||||
partialStreamed := sb.String()
|
||||
var abortNote string
|
||||
if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) {
|
||||
if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" {
|
||||
abortNote = note
|
||||
merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note)
|
||||
sb.Reset()
|
||||
sb.WriteString(merged)
|
||||
if invokeErr == nil {
|
||||
success = false
|
||||
invokeErr = context.Canceled
|
||||
}
|
||||
}
|
||||
}
|
||||
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||
@@ -187,12 +310,32 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
}
|
||||
sb.WriteString(hint)
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
||||
// 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" {
|
||||
if partialStreamed != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil)
|
||||
} else if text := strings.TrimSpace(sb.String()); text != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
||||
}
|
||||
}
|
||||
rawOutput := sb.String()
|
||||
fireBody := rawOutput
|
||||
if !success && hasExitCode && exitCode != 0 {
|
||||
statusLine := security.ExecuteFailureStatusLine(exitCode)
|
||||
if !strings.Contains(rawOutput, "命令执行失败:") {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: statusLine}, nil)
|
||||
sb.WriteString(statusLine)
|
||||
}
|
||||
fireBody = einomcp.ToolErrorPrefix + security.FormatCommandFailureResult(exitCode, rawOutput)
|
||||
}
|
||||
if w.finishMonitor != nil {
|
||||
w.finishMonitor(execID, toolCallID, command, sb.String(), success, invokeErr)
|
||||
}
|
||||
if w.invokeNotify != nil {
|
||||
w.invokeNotify.Fire(toolCallID, "execute", agentTag, success, fireBody, invokeErr)
|
||||
}
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||
outW.Close()
|
||||
}(sr, userCmd, execCancel, execCtx)
|
||||
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg, toolRunReg, monitorExecID, tid, w.shellNoOutputTimeoutSec)
|
||||
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@@ -18,9 +19,15 @@ type mockStreamingShell struct {
|
||||
immediateErr error
|
||||
recvErr error
|
||||
output string
|
||||
called bool
|
||||
lastCommand string
|
||||
}
|
||||
|
||||
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
m.called = true
|
||||
if input != nil {
|
||||
m.lastCommand = input.Command
|
||||
}
|
||||
if m.immediateErr != nil {
|
||||
return nil, m.immediateErr
|
||||
}
|
||||
@@ -37,6 +44,129 @@ func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesy
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
// TestEinoStreamingShellWrap_PreparesNonInteractiveCommand verifies that the
// command received by the inner shell contains "PYTHONUNBUFFERED=1" after the
// wrapper has processed an "echo ok" request.
func TestEinoStreamingShellWrap_PreparesNonInteractiveCommand(t *testing.T) {
|
||||
inner := &mockStreamingShell{output: "ok\n"}
|
||||
wrap := &einoStreamingShellWrap{inner: inner}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo ok"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
// Drain the stream to EOF; the payload itself is not inspected.
for {
|
||||
_, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
}
|
||||
// lastCommand is what the mock recorded from the request it was given.
if !strings.Contains(inner.lastCommand, "PYTHONUNBUFFERED=1") {
|
||||
t.Fatalf("missing python unbuffer in inner command: %q", inner.lastCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEinoStreamingShellWrap_NoOutputTimeout wraps a shell that never produces
// output, with shellNoOutputTimeoutSec set to 1, and requires that the inner
// shell was still invoked and that the stream ends with an inactivity message
// (matched in either its Chinese or English wording).
func TestEinoStreamingShellWrap_NoOutputTimeout(t *testing.T) {
|
||||
inner := &mockStreamingShellHanging{}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
// fired is only used to enrich the failure message below.
var fired string
|
||||
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
fired = content
|
||||
})
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
shellNoOutputTimeoutSec: 1,
|
||||
}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
var got strings.Builder
|
||||
// Drain the stream to EOF, concatenating every Output chunk.
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
if resp != nil {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if !inner.called {
|
||||
t.Fatal("inner shell should run (no command blacklist)")
|
||||
}
|
||||
out := got.String()
|
||||
if !strings.Contains(out, "没有新的输出") && !strings.Contains(out, "no new output") {
|
||||
t.Fatalf("expected inactivity timeout message, got: %q notify=%q", out, fired)
|
||||
}
|
||||
}
|
||||
|
||||
// mockStreamingShellPartialThenHang is a test double whose ExecuteStreaming
// emits one prompt line and then stays silent until its context is done.
type mockStreamingShellPartialThenHang struct {
|
||||
// called is set to true when ExecuteStreaming is invoked.
called bool
|
||||
}
|
||||
|
||||
// ExecuteStreaming marks the mock as called and returns a pipe reader. A
// goroutine sends a single "[sudo] password:\n" chunk, blocks until ctx is
// done, and then closes the writer. input is not consulted. The returned
// error is always nil.
func (m *mockStreamingShellPartialThenHang) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
m.called = true
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||
go func() {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: "[sudo] password:\n"}, nil)
|
||||
// Simulate a process waiting on input: nothing more is sent until cancellation.
<-ctx.Done()
|
||||
outW.Close()
|
||||
}()
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
// TestEinoStreamingShellWrap_InactivityAfterPartialOutput wraps a shell that
// prints one line and then hangs, with shellNoOutputTimeoutSec set to 1. The
// stream must reach EOF within 5 seconds of the call and must contain an
// inactivity message (matched in either its Chinese or English wording).
func TestEinoStreamingShellWrap_InactivityAfterPartialOutput(t *testing.T) {
|
||||
inner := &mockStreamingShellPartialThenHang{}
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
shellNoOutputTimeoutSec: 1,
|
||||
}
|
||||
// Timestamp taken before the call so the elapsed-time bound covers the whole run.
start := time.Now()
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
var got strings.Builder
|
||||
// Drain the stream to EOF, concatenating every Output chunk.
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
if resp != nil {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
// Generous upper bound around the 1-second inactivity setting.
if time.Since(start) > 5*time.Second {
|
||||
t.Fatalf("expected inactivity timeout ~1s, took %v", time.Since(start))
|
||||
}
|
||||
if !strings.Contains(got.String(), "没有新的输出") && !strings.Contains(got.String(), "no new output") {
|
||||
t.Fatalf("expected inactivity message, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
// mockStreamingShellHanging is a test double whose ExecuteStreaming produces
// no output at all until its context is done.
type mockStreamingShellHanging struct {
|
||||
// called is set to true when ExecuteStreaming is invoked.
called bool
|
||||
}
|
||||
|
||||
// ExecuteStreaming marks the mock as called and returns a pipe reader whose
// writer is closed, without sending anything, once ctx is done. input is not
// consulted. The returned error is always nil.
func (m *mockStreamingShellHanging) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
m.called = true
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||
go func() {
|
||||
// Block until the caller's context ends, then signal end of stream.
<-ctx.Done()
|
||||
outW.Close()
|
||||
}()
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
||||
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
@@ -122,6 +252,94 @@ func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout runs the wrapper
// with toolTimeoutMinutes set to 60 over a mock that outputs "100\n" and
// requires that "100" appears both in the forwarded stream and in the content
// passed to the notify callback.
func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) {
|
||||
inner := &mockStreamingShell{output: "100\n"}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
var firedContent string
|
||||
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
firedContent = content
|
||||
})
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
toolTimeoutMinutes: 60,
|
||||
}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
// Drain the stream to EOF, keeping only non-empty Output chunks.
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("unexpected stream error: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(got.String(), "100") {
|
||||
t.Fatalf("stream output = %q, want contains 100", got.String())
|
||||
}
|
||||
if !strings.Contains(firedContent, "100") {
|
||||
t.Fatalf("notify content = %q, want contains 100", firedContent)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput runs the
// wrapper with a conversation id and a stub registry that returns an abort
// note, over a mock configured with two output lines and recvErr set to
// context.Canceled. The forwarded stream must finish without a non-EOF error,
// contain "line1" and "line2" exactly once each, and contain the abort note.
// NOTE(review): how the mock delivers recvErr relative to its output is not
// visible in this hunk — confirm against mockStreamingShell.ExecuteStreaming.
func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) {
|
||||
inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
}
|
||||
reg := &abortNoteTestRegistry{note: "改成20次"}
|
||||
// Attach both the conversation id and the registry to the context.
ctx := mcp.WithEinoExecuteRunRegistry(
|
||||
mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"),
|
||||
reg,
|
||||
)
|
||||
sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
// Drain the stream to EOF, keeping only non-empty Output chunks.
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("unexpected stream error: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
out := got.String()
|
||||
// Each stdout line must appear exactly once in the forwarded stream.
if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 {
|
||||
t.Fatalf("stream duplicated stdout: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "改成20次") {
|
||||
t.Fatalf("stream missing abort note: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// abortNoteTestRegistry is a stub registry for tests; it is the value handed
// to mcp.WithEinoExecuteRunRegistry by the abort-note test in this file. Only
// TakeEinoExecuteAbortNote has behavior: it returns the configured note.
type abortNoteTestRegistry struct {
|
||||
// note is returned by every TakeEinoExecuteAbortNote call.
note string
|
||||
}
|
||||
|
||||
// RegisterActiveEinoExecute is a no-op.
func (r *abortNoteTestRegistry) RegisterActiveEinoExecute(string, context.CancelFunc) {}
|
||||
// UnregisterActiveEinoExecute is a no-op.
func (r *abortNoteTestRegistry) UnregisterActiveEinoExecute(string) {}
|
||||
// AbortActiveEinoExecute always returns false.
func (r *abortNoteTestRegistry) AbortActiveEinoExecute(string, string) bool { return false }
|
||||
// TakeEinoExecuteAbortNote returns r.note regardless of the argument; the note is not cleared.
func (r *abortNoteTestRegistry) TakeEinoExecuteAbortNote(string) string { return r.note }
|
||||
|
||||
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
|
||||
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
|
||||
wrap := &einoStreamingShellWrap{inner: inner}
|
||||
|
||||
@@ -63,10 +63,43 @@ func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
// beginEinoADKFilesystemToolMonitor records a running state when an Eino ADK
// filesystem tool call starts.
//
// It does nothing when ag or rec is nil, when toolName is blank or equals
// "execute" (case-insensitive), when the name is not a builtin Eino ADK
// filesystem tool, or when toolCallID is blank. Otherwise it begins a local
// tool execution under the name "eino_fs::<lowercased tool name>" with empty
// arguments, passes the resulting execution id and tool call id to rec, and,
// if binder is non-nil, binds the tool call id to the execution id.
func beginEinoADKFilesystemToolMonitor(
|
||||
ag *agent.Agent,
|
||||
rec einomcp.ExecutionRecorder,
|
||||
binder *MCPExecutionBinder,
|
||||
toolCallID, toolName string,
|
||||
) {
|
||||
if ag == nil || rec == nil {
|
||||
return
|
||||
}
|
||||
name := strings.TrimSpace(toolName)
|
||||
// "execute" is excluded here.
// NOTE(review): presumably because the streaming shell wrapper records it separately — confirm.
if name == "" || strings.EqualFold(name, "execute") {
|
||||
return
|
||||
}
|
||||
if !isBuiltinEinoADKFilesystemToolName(name) {
|
||||
return
|
||||
}
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
if tid == "" {
|
||||
return
|
||||
}
|
||||
storedName := "eino_fs::" + strings.ToLower(name)
|
||||
id := ag.BeginLocalToolExecution(storedName, map[string]interface{}{})
|
||||
// An empty id means nothing was started, so there is nothing to record or bind.
if id == "" {
|
||||
return
|
||||
}
|
||||
rec(id, tid)
|
||||
if binder != nil {
|
||||
binder.Bind(tid, id)
|
||||
}
|
||||
}
|
||||
|
||||
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
||||
func recordEinoADKFilesystemToolMonitor(
|
||||
ag *agent.Agent,
|
||||
rec einomcp.ExecutionRecorder,
|
||||
binder *MCPExecutionBinder,
|
||||
toolName string,
|
||||
toolCallID string,
|
||||
msgs []adk.Message,
|
||||
@@ -94,8 +127,12 @@ func recordEinoADKFilesystemToolMonitor(
|
||||
invErr = errors.New(t)
|
||||
}
|
||||
}
|
||||
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
||||
if id != "" {
|
||||
execID := ""
|
||||
if binder != nil {
|
||||
execID = binder.ExecutionID(toolCallID)
|
||||
}
|
||||
id := ag.FinishLocalToolExecution(execID, storedName, args, resultText, invErr)
|
||||
if id != "" && execID == "" {
|
||||
rec(id, toolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,17 +243,14 @@ func prependEinoMiddlewares(
|
||||
return outTools, extraHandlers, toolSearchActive, nil
|
||||
}
|
||||
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
if ma == nil {
|
||||
return "", nil, nil
|
||||
return "", nil
|
||||
}
|
||||
mw := ma.EinoMiddleware
|
||||
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
|
||||
outputKey = k
|
||||
}
|
||||
if mw.DeepModelRetryMaxRetries > 0 {
|
||||
retry = &adk.ModelRetryConfig{MaxRetries: mw.DeepModelRetryMaxRetries}
|
||||
}
|
||||
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
|
||||
if prefix != "" {
|
||||
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
|
||||
@@ -274,5 +271,5 @@ func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry
|
||||
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
|
||||
}
|
||||
}
|
||||
return outputKey, retry, taskDesc
|
||||
return outputKey, taskDesc
|
||||
}
|
||||
|
||||
@@ -94,24 +94,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
if a.SkillMiddleware != nil {
|
||||
execHandlers = append(execHandlers, a.SkillMiddleware)
|
||||
}
|
||||
// 4. summarization(最后,与 Deep/Supervisor 一致)
|
||||
// 4. pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致)
|
||||
if a.AppCfg != nil {
|
||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
|
||||
if sumErr != nil {
|
||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||
}
|
||||
execHandlers = append(execHandlers, sumMw)
|
||||
}
|
||||
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、
|
||||
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。
|
||||
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||
execHandlers = append(execHandlers, teleMw)
|
||||
}
|
||||
if a.ModelFacingTrace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
|
||||
execHandlers = append(execHandlers, capMw)
|
||||
}
|
||||
execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
|
||||
logger: a.Logger,
|
||||
phase: "plan_execute_executor",
|
||||
summarization: sumMw,
|
||||
modelName: a.ModelName,
|
||||
conversationID: a.ConversationID,
|
||||
trace: a.ModelFacingTrace,
|
||||
})
|
||||
}
|
||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||
Model: a.ExecModel,
|
||||
|
||||
@@ -81,7 +81,7 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
|
||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||
einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
|
||||
mainDefs := ag.ToolsForRole(roleTools)
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
|
||||
if err != nil {
|
||||
@@ -136,7 +136,7 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
if einoSkillMW != nil {
|
||||
if einoFSTools && einoLoc != nil {
|
||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||
if fsErr != nil {
|
||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||
}
|
||||
@@ -144,13 +144,14 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
handlers = append(handlers, einoSkillMW)
|
||||
}
|
||||
handlers = append(handlers, mainSumMw)
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
handlers = appendEinoChatModelTailMiddlewares(handlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "eino_single",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
maxIter := agentMaxIterations(appCfg)
|
||||
|
||||
@@ -183,18 +184,16 @@ func RunEinoSingleChatModelAgent(
|
||||
Name: einoSingleAgentName,
|
||||
Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.",
|
||||
Instruction: ins,
|
||||
GenModelInput: literalInstructionGenModelInput,
|
||||
Model: mainModel,
|
||||
ToolsConfig: mainToolsCfg,
|
||||
MaxIterations: maxIter,
|
||||
Handlers: handlers,
|
||||
}
|
||||
outKey, modelRetry, _ := deepExtrasFromConfig(ma)
|
||||
outKey, _ := deepExtrasFromConfig(ma)
|
||||
if outKey != "" {
|
||||
chatCfg.OutputKey = outKey
|
||||
}
|
||||
if modelRetry != nil {
|
||||
chatCfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
|
||||
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -81,8 +82,10 @@ func subAgentFilesystemMiddleware(
|
||||
loc *localbk.Local,
|
||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||
einoAgentName string,
|
||||
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error),
|
||||
beginMonitor func(toolCallID, command string) string,
|
||||
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
|
||||
toolTimeoutMinutes int,
|
||||
shellNoOutputTimeoutSec int,
|
||||
outputChunk func(toolName, toolCallID, chunk string),
|
||||
) (adk.ChatModelAgentMiddleware, error) {
|
||||
if loc == nil {
|
||||
@@ -91,12 +94,14 @@ func subAgentFilesystemMiddleware(
|
||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||
Backend: loc,
|
||||
StreamingShell: &einoStreamingShellWrap{
|
||||
inner: loc,
|
||||
invokeNotify: invokeNotify,
|
||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||
outputChunk: outputChunk,
|
||||
recordMonitor: recordMonitor,
|
||||
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||
inner: security.NewEinoStreamingShell(),
|
||||
invokeNotify: invokeNotify,
|
||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||
outputChunk: outputChunk,
|
||||
beginMonitor: beginMonitor,
|
||||
finishMonitor: finishMonitor,
|
||||
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||
shellNoOutputTimeoutSec: shellNoOutputTimeoutSec,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -108,3 +113,18 @@ func agentToolTimeoutMinutes(cfg *config.Config) int {
|
||||
}
|
||||
return cfg.Agent.ToolTimeoutMinutes
|
||||
}
|
||||
|
||||
// agentShellNoOutputTimeoutSeconds:0=默认 300s(5 分钟);-1=关闭;>0=自定义秒数。
|
||||
func agentShellNoOutputTimeoutSeconds(cfg *config.Config) int {
|
||||
if cfg == nil {
|
||||
return 300
|
||||
}
|
||||
v := cfg.Agent.ShellNoOutputTimeoutSeconds
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
if v == 0 {
|
||||
return 300
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const defaultSummarizationRetryMax = 3
|
||||
|
||||
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
|
||||
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
||||
|
||||
@@ -97,10 +95,8 @@ func newEinoSummarizationMiddleware(
|
||||
}
|
||||
}
|
||||
|
||||
retryMax := defaultSummarizationRetryMax
|
||||
if mwCfg != nil && mwCfg.SummarizationRetryMaxAttempts > 0 {
|
||||
retryMax = mwCfg.SummarizationRetryMaxAttempts
|
||||
}
|
||||
retryPolicy := einoTransientRunRetryPolicyFromMW(mwCfg)
|
||||
retryMax := retryPolicy.maxAttempts
|
||||
|
||||
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
|
||||
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
|
||||
@@ -137,13 +133,14 @@ func newEinoSummarizationMiddleware(
|
||||
Retry: &summarization.RetryConfig{
|
||||
MaxRetries: &retryMax,
|
||||
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
|
||||
if err != nil && logger != nil {
|
||||
logger.Warn("eino summarization generate attempt failed, will retry if attempts remain",
|
||||
retry := isEinoTransientRunError(err)
|
||||
if retry && logger != nil {
|
||||
logger.Warn("eino summarization generate transient error, will retry if attempts remain",
|
||||
zap.Error(err),
|
||||
zap.Int("max_retries", retryMax),
|
||||
)
|
||||
}
|
||||
return err != nil
|
||||
return retry
|
||||
},
|
||||
},
|
||||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||||
@@ -153,6 +150,7 @@ func newEinoSummarizationMiddleware(
|
||||
}
|
||||
if appCfg != nil {
|
||||
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
|
||||
out = refreshUserVerbatimAnchorInMessages(out, db, conversationID, appCfg.MultiAgent.UserVerbatimAnchorMaxRunesEffective(), logger)
|
||||
}
|
||||
return out, nil
|
||||
},
|
||||
@@ -260,17 +258,19 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
nonSystem = append(nonSystem, msg)
|
||||
}
|
||||
|
||||
mergedSystem := mergeCollectedSystemMessages(systemMsgs)
|
||||
|
||||
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rounds := splitMessagesIntoRounds(nonSystem)
|
||||
if len(rounds) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
@@ -322,8 +322,8 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||
}
|
||||
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1+len(selectedMsgs))
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
out = append(out, selectedMsgs...)
|
||||
return out, nil
|
||||
@@ -414,6 +414,36 @@ func writeSummarizationTranscript(path string, msgs []adk.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshUserVerbatimAnchorInMessages 压缩后从 messages 表刷新 system 中的用户原文锚点。
|
||||
func refreshUserVerbatimAnchorInMessages(msgs []adk.Message, db *database.DB, conversationID string, maxRunes int, logger *zap.Logger) []adk.Message {
|
||||
if maxRunes < 0 || db == nil {
|
||||
return msgs
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return msgs
|
||||
}
|
||||
rows, err := db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("summarization: 刷新用户原文锚点失败",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
block := project.BuildUserVerbatimAnchorBlockFromMessages(rows, maxRunes)
|
||||
if block == "" {
|
||||
return msgs
|
||||
}
|
||||
out := project.RefreshUserVerbatimAnchorInMessages(msgs, block)
|
||||
if logger != nil {
|
||||
logger.Info("summarization: 已刷新用户原文锚点", zap.String("conversationId", conversationID))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
|
||||
@@ -192,8 +192,8 @@ func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
|
||||
if len(out) < 2 {
|
||||
t.Fatalf("output too short: %d", len(out))
|
||||
}
|
||||
if out[0] != sys {
|
||||
t.Fatalf("first message must be system")
|
||||
if out[0].Role != schema.System || out[0].Content != "sys" {
|
||||
t.Fatalf("first message must be system sys, got %s: %q", out[0].Role, out[0].Content)
|
||||
}
|
||||
if out[1] != summary {
|
||||
t.Fatalf("second message must be summary")
|
||||
@@ -293,12 +293,12 @@ func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(out) != 2 || out[0] != sys || out[1] != summary {
|
||||
if len(out) != 2 || out[0].Role != schema.System || out[0].Content != "sys" || out[1] != summary {
|
||||
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
func TestSummarizeFinalize_MergesSystemMessages(t *testing.T) {
|
||||
sys1 := schema.SystemMessage("sys1")
|
||||
sys2 := schema.SystemMessage("sys2")
|
||||
summary := schema.AssistantMessage("s", nil)
|
||||
@@ -321,10 +321,13 @@ func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
for _, m := range out {
|
||||
if m != nil && m.Role == schema.System {
|
||||
systemCount++
|
||||
if got := m.Content; got != "sys1\n\nsys2" {
|
||||
t.Fatalf("unexpected merged system content: %q", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
if systemCount != 2 {
|
||||
t.Fatalf("want 2 system messages retained, got %d", systemCount)
|
||||
if systemCount != 1 {
|
||||
t.Fatalf("want 1 merged system message, got %d", systemCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -378,6 +381,12 @@ func TestWriteSummarizationTranscript(t *testing.T) {
|
||||
if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") {
|
||||
t.Fatalf("missing tool round: %q", text)
|
||||
}
|
||||
if !strings.Contains(text, `"name":"stub_tool"`) || !strings.Contains(text, `"arguments":"{}"`) {
|
||||
t.Fatalf("missing tool name/arguments: %q", text)
|
||||
}
|
||||
if strings.Contains(text, "tool_call_id") || strings.Contains(text, `"id":"tc1"`) {
|
||||
t.Fatalf("transcript should omit tool_call_id: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
||||
@@ -400,9 +409,9 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
||||
"需要写入请使用 upsert_project_fact。",
|
||||
project.FactIndexSectionEndMarker,
|
||||
"",
|
||||
"# Skills System",
|
||||
"**How to Use Skills**",
|
||||
"Remember: Skills make you more capable",
|
||||
transcriptSkillsSystemMarker,
|
||||
"**如何使用 Skill(技能)(渐进式展示):**",
|
||||
"记住:Skill 让你更加强大和稳定",
|
||||
}, "\n")
|
||||
|
||||
out := sanitizeSystemContentForTranscript(system)
|
||||
@@ -412,7 +421,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
||||
if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") {
|
||||
t.Fatalf("static persona should be stripped: %q", out)
|
||||
}
|
||||
if strings.Contains(out, "# Skills System") || strings.Contains(out, "How to Use Skills") {
|
||||
if strings.Contains(out, transcriptSkillsSystemMarker) || strings.Contains(out, "如何使用 Skill") {
|
||||
t.Fatalf("skills boilerplate should be stripped: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, transcriptStaticSystemOmitNote) {
|
||||
@@ -426,7 +435,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
||||
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n# Skills System\nboiler"),
|
||||
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n" + transcriptSkillsSystemMarker + "\nboiler"),
|
||||
schema.UserMessage("hello"),
|
||||
schema.AssistantMessage("reply", nil),
|
||||
}
|
||||
|
||||
@@ -20,9 +20,16 @@ const (
|
||||
transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]"
|
||||
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
|
||||
transcriptPersonaStartMarker = "你是CyberStrikeAI"
|
||||
transcriptSkillsSystemMarker = "# Skills System"
|
||||
// ADK LanguageChinese injects skill middleware prompt with this header (see eino adk/middlewares/skill/prompt.go).
|
||||
transcriptSkillsSystemMarker = "# Skill 系统"
|
||||
transcriptSkillsSystemMarkerEnglish = "# Skills System"
|
||||
)
|
||||
|
||||
type transcriptToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
|
||||
// Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only.
|
||||
func formatSummarizationTranscript(msgs []adk.Message) string {
|
||||
@@ -81,13 +88,23 @@ func stripToolNamesIndexFromSystem(s string) string {
|
||||
}
|
||||
|
||||
func stripSkillsSystemBoilerplate(s string) string {
|
||||
idx := strings.Index(s, transcriptSkillsSystemMarker)
|
||||
idx := indexFirstSubstring(s, transcriptSkillsSystemMarker, transcriptSkillsSystemMarkerEnglish)
|
||||
if idx < 0 {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
return strings.TrimSpace(s[:idx])
|
||||
}
|
||||
|
||||
// indexFirstSubstring returns the smallest byte index at which any of markers occurs
// in s, or -1 when none of them occurs (including when no markers are given).
func indexFirstSubstring(s string, markers ...string) int {
	earliest := -1
	for k := 0; k < len(markers); k++ {
		pos := strings.Index(s, markers[k])
		if pos < 0 {
			continue
		}
		if earliest == -1 || pos < earliest {
			earliest = pos
		}
	}
	return earliest
}
|
||||
|
||||
func extractProjectBlackboardSection(s string) string {
|
||||
start := strings.Index(s, project.FactIndexSectionStartMarker)
|
||||
if start < 0 {
|
||||
@@ -138,15 +155,21 @@ func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) {
|
||||
}
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
|
||||
if b, err := sonic.Marshal(formatTranscriptToolCalls(msg.ToolCalls)); err == nil {
|
||||
sb.WriteString("tool_calls: ")
|
||||
sb.Write(b)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
if msg.ToolCallID != "" {
|
||||
sb.WriteString("tool_call_id: ")
|
||||
sb.WriteString(msg.ToolCallID)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
func formatTranscriptToolCalls(calls []schema.ToolCall) []transcriptToolCall {
|
||||
out := make([]transcriptToolCall, 0, len(calls))
|
||||
for _, tc := range calls {
|
||||
out = append(out, transcriptToolCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -46,6 +46,10 @@ func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, too
|
||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||
}
|
||||
if s := strings.TrimSpace(injectShellToolGuidance("", names)); s != "" {
|
||||
sb.WriteString(s)
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
if s := strings.TrimSpace(instruction); s != "" {
|
||||
sb.WriteString(s)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package multiagent
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,8 +18,9 @@ const (
|
||||
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||
)
|
||||
|
||||
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
||||
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
||||
// isEinoTransientRunError 是 Eino 运行期「可退避重试 vs 直接失败」的唯一判据。
|
||||
// 429/5xx/网络抖动等返回 true;用户取消、超时、迭代上限、鉴权失败等返回 false。
|
||||
// 其它模块(run loop、summarization 等)只调用本函数,不在别处维护平行规则。
|
||||
func isEinoTransientRunError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -60,6 +62,7 @@ func isEinoTransientRunError(err error) bool {
|
||||
"dial tcp",
|
||||
"tls handshake timeout",
|
||||
"stream error",
|
||||
"goaway", // http2: server sent GOAWAY and closed the connection
|
||||
"unexpected eof",
|
||||
`": eof`, // net/http: Post "url": EOF (often wraps io.EOF)
|
||||
"unexpected end of json",
|
||||
@@ -78,6 +81,71 @@ func isEinoTransientRunError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type einoTransientRunRetryPolicy struct {
|
||||
maxAttempts int
|
||||
maxBackoff time.Duration
|
||||
}
|
||||
|
||||
func einoTransientRunRetryPolicyFromArgs(args *einoADKRunLoopArgs) einoTransientRunRetryPolicy {
|
||||
return einoTransientRunRetryPolicy{
|
||||
maxAttempts: einoRunRetryMaxAttempts(args),
|
||||
maxBackoff: einoRunRetryMaxBackoff(args),
|
||||
}
|
||||
}
|
||||
|
||||
func einoTransientRunRetryPolicyFromMW(mw *config.MultiAgentEinoMiddlewareConfig) einoTransientRunRetryPolicy {
|
||||
maxBackoff := defaultEinoRunRetryMaxBackoff
|
||||
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
|
||||
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRunRetryPolicy{
|
||||
maxAttempts: RunRetryMaxAttemptsFromConfig(mw),
|
||||
maxBackoff: maxBackoff,
|
||||
}
|
||||
}
|
||||
|
||||
// einoTransientRunRetrier 在 run loop 内对临时错误做指数退避并重启 Runner(唯一重试执行层)。
|
||||
type einoTransientRunRetrier struct {
|
||||
policy einoTransientRunRetryPolicy
|
||||
attempts int
|
||||
}
|
||||
|
||||
func newEinoTransientRunRetrier(policy einoTransientRunRetryPolicy) *einoTransientRunRetrier {
|
||||
return &einoTransientRunRetrier{policy: policy}
|
||||
}
|
||||
|
||||
// tryRetry 对临时错误退避后返回重启消息;次数用尽返回 exhausted 错误。
|
||||
func (r *einoTransientRunRetrier) tryRetry(
|
||||
ctx context.Context,
|
||||
runErr error,
|
||||
args *einoADKRunLoopArgs,
|
||||
baseMsgs, accumulated []adk.Message,
|
||||
baseCount int,
|
||||
) (restarted bool, restartMsgs []adk.Message, ctxSource einoRunRestartContextSource, backoff time.Duration, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
return false, nil, "", 0, runErr
|
||||
}
|
||||
r.attempts++
|
||||
if r.attempts > r.policy.maxAttempts {
|
||||
return false, nil, "", 0, fmt.Errorf("transient retry exhausted after %d attempts: %w", r.policy.maxAttempts, runErr)
|
||||
}
|
||||
backoff = einoTransientRetryBackoff(r.attempts-1, r.policy.maxBackoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, nil, "", 0, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
restartMsgs, ctxSource = einoMessagesForRunRestart(args, baseMsgs, accumulated, baseCount)
|
||||
return true, restartMsgs, ctxSource, backoff, nil
|
||||
}
|
||||
|
||||
func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
|
||||
|
||||
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
|
||||
|
||||
// reset 在退避重试后成功推进(流/消息完整接收)时清零计数,使后续临时错误从第 1 次退避重新开始。
|
||||
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
|
||||
|
||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||
return args.RunRetryMaxAttempts
|
||||
@@ -85,7 +153,7 @@ func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
||||
// RunRetryMaxAttemptsFromConfig 与 eino_middleware.run_retry_max_attempts 一致。
|
||||
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||
return mw.RunRetryMaxAttempts
|
||||
@@ -93,15 +161,6 @@ func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) in
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
||||
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
||||
max := defaultEinoRunRetryMaxBackoff
|
||||
if maxBackoffSec > 0 {
|
||||
max = time.Duration(maxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRetryBackoff(attempt, max)
|
||||
}
|
||||
|
||||
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||
@@ -122,37 +181,35 @@ const (
|
||||
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
|
||||
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
|
||||
if trace := persistTraceSource(args, nil); len(trace) > 0 {
|
||||
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
|
||||
// modelFacingTrace includes prior Instruction system message(s); genModelInput will prepend again.
|
||||
return stripADKSystemMessages(trace), einoRestartContextModelTrace
|
||||
}
|
||||
if len(accumulated) > baseCount {
|
||||
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
|
||||
return stripADKSystemMessages(accumulated), einoRestartContextAccumulated
|
||||
}
|
||||
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||
}
|
||||
|
||||
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
|
||||
// adkMessagesHasUserContent reports whether the conversation tail is already a user turn
|
||||
// with the given content. Only the last message counts: matching text in an earlier round
|
||||
// (e.g. user repeats the same prompt after an assistant reply) must not suppress appending
|
||||
// the new user turn — Claude 4.6+ rejects requests whose final message is assistant.
|
||||
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
|
||||
want = strings.TrimSpace(want)
|
||||
if want == "" {
|
||||
return true
|
||||
}
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.User {
|
||||
return strings.TrimSpace(m.Content) == want
|
||||
}
|
||||
if m.Role == schema.Assistant || m.Role == schema.Tool {
|
||||
continue
|
||||
}
|
||||
break
|
||||
if len(msgs) == 0 {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
last := msgs[len(msgs)-1]
|
||||
if last == nil || last.Role != schema.User {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(last.Content) == want
|
||||
}
|
||||
|
||||
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
|
||||
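// appendUserMessageIfNeeded appends this turn's user message after the history trail, unless userMessage is blank or the tail is already a user message with the same content.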
|
||||
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
|
||||
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
|
||||
return msgs
|
||||
|
||||
@@ -27,6 +27,7 @@ func TestIsEinoTransientRunError(t *testing.T) {
|
||||
{"429", errors.New("HTTP 429 Too Many Requests"), true},
|
||||
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
|
||||
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||
{"http2 goaway", errors.New("failed to receive stream chunk: error, http2: server sent GOAWAY and closed the connection; LastStreamID=791, ErrCode=NO_ERROR"), true},
|
||||
{"unexpected eof", errors.New("unexpected EOF"), true},
|
||||
{"503", errors.New("upstream returned 503"), true},
|
||||
{"iteration limit", errors.New("max iteration reached"), false},
|
||||
@@ -90,6 +91,46 @@ func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoTransientRunRetrierReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
|
||||
r.attempts = 3
|
||||
r.reset()
|
||||
if r.attempt() != 0 {
|
||||
t.Fatalf("after reset: attempt=%d, want 0", r.attempt())
|
||||
}
|
||||
// 重置后下一次退避应从 2s 起算(attempt index 0)。
|
||||
if got := einoTransientRetryBackoff(r.attempt(), r.policy.maxBackoff); got != 2*time.Second {
|
||||
t.Fatalf("backoff after reset: got %v, want 2s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoTransientRunRetrierConsecutiveFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
|
||||
ctx := context.Background()
|
||||
runErr := errors.New("internal server error")
|
||||
args := &einoADKRunLoopArgs{}
|
||||
base := []adk.Message{schema.UserMessage("hi")}
|
||||
|
||||
for want := 1; want <= 3; want++ {
|
||||
restarted, _, _, _, err := r.tryRetry(ctx, runErr, args, base, nil, len(base))
|
||||
if err != nil {
|
||||
t.Fatalf("tryRetry attempt %d: %v", want, err)
|
||||
}
|
||||
if !restarted {
|
||||
t.Fatalf("tryRetry attempt %d: want restarted", want)
|
||||
}
|
||||
if got := r.attempt(); got != want {
|
||||
t.Fatalf("after failure %d: attempt=%d, want %d", want, got, want)
|
||||
}
|
||||
}
|
||||
r.reset()
|
||||
if r.attempt() != 0 {
|
||||
t.Fatalf("after successful recovery reset: attempt=%d, want 0", r.attempt())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||
@@ -103,9 +144,17 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrTransientRetryContinue(t *testing.T) {
|
||||
func TestAppendUserMessageIfNeeded_repeatPromptAfterAssistant(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
||||
t.Fatal("sentinel should match")
|
||||
msgs := []adk.Message{
|
||||
schema.UserMessage("扫描 example.com"),
|
||||
schema.AssistantMessage("开始扫描...", nil),
|
||||
}
|
||||
out := appendUserMessageIfNeeded(msgs, "扫描 example.com")
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("should append new user turn after assistant reply: len=%d", len(out))
|
||||
}
|
||||
if out[2].Role != schema.User || out[2].Content != "扫描 example.com" {
|
||||
t.Fatalf("tail should be repeated user prompt, got role=%s content=%q", out[2].Role, out[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package multiagent
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ExecuteExitError reports that an execute command finished with a non-zero exit
// status: an expected failure, as opposed to a timeout, interrupt or stream error.
type ExecuteExitError struct {
	Code int // exit status of the command
}

// Error implements the error interface. It is safe to call on a nil receiver, which
// yields "exit status unknown".
func (e *ExecuteExitError) Error() string {
	if e != nil {
		return fmt.Sprintf("exit status %d", e.Code)
	}
	return "exit status unknown"
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// literalInstructionGenModelInput passes Instruction through as a system message without
|
||||
// FString template formatting. Eino defaultGenModelInput formats instruction whenever
|
||||
// SessionValues exist; prompts with literal curly braces (project blackboard "{关系边: ...}",
|
||||
// JSON examples, link syntax) then fail with "could not find key".
|
||||
//
|
||||
// Matches eino/adk/prebuilt/deep genModelInput — the supported fix per Eino docs.
|
||||
func literalInstructionGenModelInput(ctx context.Context, instruction string, input *adk.AgentInput) ([]adk.Message, error) {
|
||||
msgs := make([]adk.Message, 0, len(input.Messages)+1)
|
||||
if instruction != "" {
|
||||
msgs = append(msgs, schema.SystemMessage(instruction))
|
||||
}
|
||||
msgs = append(msgs, input.Messages...)
|
||||
return msgs, nil
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestLiteralInstructionGenModelInput_PreservesLiteralCurlyBraces(t *testing.T) {
|
||||
t.Parallel()
|
||||
instruction := "- [finding/x] summary {关系边: discovered_on←target/dev}\n" +
|
||||
"如 finding 上 {from:target/*, type:discovered_on}"
|
||||
msgs, err := literalInstructionGenModelInput(context.Background(), instruction, &adk.AgentInput{
|
||||
Messages: []adk.Message{schema.UserMessage("继续")},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(msgs))
|
||||
}
|
||||
if msgs[0].Role != schema.System {
|
||||
t.Fatalf("first message must be system, got %s", msgs[0].Role)
|
||||
}
|
||||
for _, want := range []string{"{关系边:", "{from:target/*, type:discovered_on}"} {
|
||||
if !strings.Contains(msgs[0].Content, want) {
|
||||
t.Fatalf("system content missing %q: %q", want, msgs[0].Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,11 +5,3 @@ import "errors"
|
||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||
|
||||
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
||||
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
||||
|
||||
// ErrEmptyResponseContinue 表示 Eino ADK 会话正常结束但未捕获到助手正文,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue / ErrTransientRetryContinue 同级)。
|
||||
var ErrEmptyResponseContinue = errors.New("agent empty response: continue after persisting trace")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/project"
|
||||
"cyberstrike-ai/internal/projectprompt"
|
||||
)
|
||||
|
||||
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
||||
@@ -122,7 +123,9 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
|
||||
|
||||
## 表达
|
||||
|
||||
在调用工具或给出计划变更前,用 2~5 句中文说明当前决策依据与期望证据形态;最终对用户交付结构化结论(发现摘要、证据、风险、下一步)。`
|
||||
在调用工具或给出计划变更前,用 2~5 句中文说明当前决策依据与期望证据形态;最终对用户交付结构化结论(发现摘要、证据、风险、下一步)。
|
||||
|
||||
` + projectprompt.ShellExecExecuteGuidanceSection()
|
||||
}
|
||||
|
||||
// DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
// 本中间件与之互补,专职兜底正向孤儿。
|
||||
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
|
||||
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
|
||||
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask /
|
||||
// - 位置建议:挂在 summarization / reduction / skill / plantask / system 合并 / 续聊 dedup 之后,
|
||||
// tool_search)之后,靠近 ChatModel 调用的那一端。
|
||||
type orphanToolPrunerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/project"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -120,7 +121,7 @@ func RunDeepAgent(
|
||||
mcpIDs = append(mcpIDs, id)
|
||||
mcpIDsMu.Unlock()
|
||||
}
|
||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||
einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
|
||||
|
||||
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
||||
snapshotMCPIDs := func() []string {
|
||||
@@ -223,7 +224,7 @@ func RunDeepAgent(
|
||||
}
|
||||
if einoSkillMW != nil {
|
||||
if einoFSTools && einoLoc != nil {
|
||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||
if fsErr != nil {
|
||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||
}
|
||||
@@ -231,13 +232,13 @@ func RunDeepAgent(
|
||||
}
|
||||
subHandlers = append(subHandlers, einoSkillMW)
|
||||
}
|
||||
subHandlers = append(subHandlers, subSumMw)
|
||||
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前,
|
||||
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。
|
||||
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
|
||||
subHandlers = append(subHandlers, teleMw)
|
||||
}
|
||||
subHandlers = appendEinoChatModelTailMiddlewares(subHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "sub_agent:" + id,
|
||||
summarization: subSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
})
|
||||
|
||||
subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready())
|
||||
subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive)
|
||||
@@ -253,10 +254,11 @@ func RunDeepAgent(
|
||||
)
|
||||
}
|
||||
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||
Name: id,
|
||||
Description: desc,
|
||||
Instruction: subInstrFinal,
|
||||
Model: subModel,
|
||||
Name: id,
|
||||
Description: desc,
|
||||
Instruction: subInstrFinal,
|
||||
GenModelInput: literalInstructionGenModelInput,
|
||||
Model: subModel,
|
||||
ToolsConfig: adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
Tools: subToolsForCfg,
|
||||
@@ -358,19 +360,28 @@ func RunDeepAgent(
|
||||
if einoLoc != nil && einoFSTools {
|
||||
deepBackend = einoLoc
|
||||
deepShell = &einoStreamingShellWrap{
|
||||
inner: einoLoc,
|
||||
invokeNotify: toolInvokeNotify,
|
||||
einoAgentName: orchestratorName,
|
||||
outputChunk: nil,
|
||||
recordMonitor: einoExecMonitor,
|
||||
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||
inner: security.NewEinoStreamingShell(),
|
||||
invokeNotify: toolInvokeNotify,
|
||||
einoAgentName: orchestratorName,
|
||||
outputChunk: nil,
|
||||
beginMonitor: einoExecBegin,
|
||||
finishMonitor: einoExecFinish,
|
||||
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||
shellNoOutputTimeoutSec: agentShellNoOutputTimeoutSeconds(appCfg),
|
||||
}
|
||||
}
|
||||
|
||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
|
||||
taskEnrichExtra := systemPromptExtra
|
||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil {
|
||||
var taskBlackboardSupplement string
|
||||
if appCfg.Project.Enabled && db != nil {
|
||||
if pid := strings.TrimSpace(projectID); pid != "" {
|
||||
if block, err := project.BuildFactIndexBlock(db, pid, appCfg.Project); err == nil {
|
||||
taskBlackboardSupplement = strings.TrimSpace(block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunesEffective(), taskBlackboardSupplement); mw != nil {
|
||||
deepHandlers = append(deepHandlers, mw)
|
||||
}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -379,14 +390,14 @@ func RunDeepAgent(
|
||||
if einoSkillMW != nil {
|
||||
deepHandlers = append(deepHandlers, einoSkillMW)
|
||||
}
|
||||
deepHandlers = append(deepHandlers, mainSumMw)
|
||||
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||
deepHandlers = append(deepHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
deepHandlers = append(deepHandlers, capMw)
|
||||
}
|
||||
deepHandlers = appendEinoChatModelTailMiddlewares(deepHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "deep_orchestrator",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
supHandlers := []adk.ChatModelAgentMiddleware{}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -395,14 +406,14 @@ func RunDeepAgent(
|
||||
if einoSkillMW != nil {
|
||||
supHandlers = append(supHandlers, einoSkillMW)
|
||||
}
|
||||
supHandlers = append(supHandlers, mainSumMw)
|
||||
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||
supHandlers = append(supHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
supHandlers = append(supHandlers, capMw)
|
||||
}
|
||||
supHandlers = appendEinoChatModelTailMiddlewares(supHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "supervisor_orchestrator",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
mainToolsCfg := adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
@@ -416,7 +427,7 @@ func RunDeepAgent(
|
||||
EmitInternalEvents: true,
|
||||
}
|
||||
|
||||
deepOutKey, modelRetry, taskGen := deepExtrasFromConfig(ma)
|
||||
deepOutKey, taskGen := deepExtrasFromConfig(ma)
|
||||
|
||||
var da adk.Agent
|
||||
switch orchMode {
|
||||
@@ -428,7 +439,7 @@ func RunDeepAgent(
|
||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||
var peFsMw adk.ChatModelAgentMiddleware
|
||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||
}
|
||||
@@ -451,12 +462,14 @@ func RunDeepAgent(
|
||||
SkillMiddleware: einoSkillMW,
|
||||
FilesystemMiddleware: peFsMw,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||
mainSumMw,
|
||||
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
||||
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"),
|
||||
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
|
||||
},
|
||||
PlannerReplannerRewriteHandlers: appendEinoChatModelTailMiddlewares(nil, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "plan_execute_planner_replanner",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
skipTrace: true,
|
||||
}),
|
||||
})
|
||||
if perr != nil {
|
||||
return nil, perr
|
||||
@@ -467,15 +480,13 @@ func RunDeepAgent(
|
||||
Name: orchestratorName,
|
||||
Description: orchDescription,
|
||||
Instruction: supInstr,
|
||||
GenModelInput: literalInstructionGenModelInput,
|
||||
Model: mainModel,
|
||||
ToolsConfig: mainToolsCfg,
|
||||
MaxIterations: deepMaxIter,
|
||||
Handlers: supHandlers,
|
||||
Exit: &adk.ExitTool{},
|
||||
}
|
||||
if modelRetry != nil {
|
||||
supCfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
if deepOutKey != "" {
|
||||
supCfg.OutputKey = deepOutKey
|
||||
}
|
||||
@@ -509,9 +520,6 @@ func RunDeepAgent(
|
||||
if deepOutKey != "" {
|
||||
dcfg.OutputKey = deepOutKey
|
||||
}
|
||||
if modelRetry != nil {
|
||||
dcfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
if taskGen != nil {
|
||||
dcfg.TaskToolDescriptionGenerator = taskGen
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/projectprompt"
|
||||
)
|
||||
|
||||
// shellToolsPresent reports whether toolNames contains a tool called "exec" or
// "execute". Names are compared after trimming surrounding whitespace and
// lower-casing, so " EXEC " matches while "executor" does not.
func shellToolsPresent(toolNames []string) bool {
	for i := range toolNames {
		name := strings.ToLower(strings.TrimSpace(toolNames[i]))
		if name == "exec" || name == "execute" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// injectShellToolGuidance 在系统提示末尾追加 exec/execute 分工(仅当工具列表含 exec 或 execute)。
|
||||
func injectShellToolGuidance(instruction string, toolNames []string) string {
|
||||
if !shellToolsPresent(toolNames) {
|
||||
return instruction
|
||||
}
|
||||
block := strings.TrimSpace(projectprompt.ShellExecExecuteGuidanceSection())
|
||||
if block == "" {
|
||||
return instruction
|
||||
}
|
||||
s := strings.TrimSpace(instruction)
|
||||
if s == "" {
|
||||
return block
|
||||
}
|
||||
return s + "\n\n" + block
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInjectShellToolGuidance(t *testing.T) {
|
||||
got := injectShellToolGuidance("base", []string{"nmap"})
|
||||
if got != "base" {
|
||||
t.Fatalf("expected unchanged, got %q", got)
|
||||
}
|
||||
got = injectShellToolGuidance("base", []string{"exec", "nmap"})
|
||||
if !strings.Contains(got, "exec/execute") || !strings.Contains(got, "base") {
|
||||
t.Fatalf("expected shell guidance appended, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package multiagent
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
@@ -11,7 +12,7 @@ import (
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
)
|
||||
|
||||
const defaultSubAgentUserContextMaxRunes = 2000
|
||||
const userContextSupplementHeader = "\n\n## 用户历史输入(原文,子代理必读)\n"
|
||||
|
||||
// taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator
|
||||
// and appends the user's original conversation messages to the task description.
|
||||
@@ -30,13 +31,14 @@ type taskContextEnrichMiddleware struct {
|
||||
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
||||
// descriptions with user conversation context. Returns nil if disabled
|
||||
// (maxRunes < 0) or no user messages exist.
|
||||
// projectBlackboard 仅传项目黑板索引块(BuildFactIndexBlock);勿传完整 systemPromptExtra。
|
||||
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
|
||||
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
||||
if bb := strings.TrimSpace(projectBlackboard); bb != "" {
|
||||
if supplement != "" {
|
||||
supplement += "\n\n## 项目黑板索引\n" + bb
|
||||
supplement += "\n\n" + bb
|
||||
} else {
|
||||
supplement = "\n\n## 项目黑板索引\n" + bb
|
||||
supplement = "\n\n" + bb
|
||||
}
|
||||
}
|
||||
if supplement == "" {
|
||||
@@ -86,9 +88,6 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
|
||||
if maxRunes < 0 {
|
||||
return ""
|
||||
}
|
||||
if maxRunes == 0 {
|
||||
maxRunes = defaultSubAgentUserContextMaxRunes
|
||||
}
|
||||
|
||||
var userMsgs []string
|
||||
for _, h := range history {
|
||||
@@ -107,12 +106,16 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
|
||||
return ""
|
||||
}
|
||||
|
||||
joined := strings.Join(userMsgs, "\n---\n")
|
||||
if len([]rune(joined)) > maxRunes {
|
||||
lines := make([]string, 0, len(userMsgs))
|
||||
for i, msg := range userMsgs {
|
||||
lines = append(lines, fmt.Sprintf("[第%d轮] %s", i+1, msg))
|
||||
}
|
||||
joined := strings.Join(lines, "\n")
|
||||
if maxRunes > 0 && len([]rune(joined)) > maxRunes {
|
||||
joined = truncateKeepFirstLast(userMsgs, maxRunes)
|
||||
}
|
||||
|
||||
return "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" + joined
|
||||
return userContextSupplementHeader + joined
|
||||
}
|
||||
|
||||
// truncateKeepFirstLast keeps the first and last user messages, giving each
|
||||
|
||||
@@ -74,7 +74,7 @@ func TestBuildUserContextSupplement_DisabledByNegative(t *testing.T) {
|
||||
func TestBuildUserContextSupplement_CustomMaxRunes(t *testing.T) {
|
||||
msg := strings.Repeat("A", 200)
|
||||
result := buildUserContextSupplement(msg, nil, 50)
|
||||
header := "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n"
|
||||
header := userContextSupplementHeader
|
||||
body := strings.TrimPrefix(result, header)
|
||||
if len([]rune(body)) > 50 {
|
||||
t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body)))
|
||||
@@ -89,7 +89,7 @@ func TestBuildUserContextSupplement_TruncateKeepsFirstAndLast(t *testing.T) {
|
||||
history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)})
|
||||
}
|
||||
last := "最后一条指令"
|
||||
result := buildUserContextSupplement(last, history, 0)
|
||||
result := buildUserContextSupplement(last, history, 800)
|
||||
if !strings.Contains(result, "http://target.com") {
|
||||
t.Error("first message (target URL) should survive truncation")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// systemMessageNormalizerMiddleware merges duplicate role=system messages into a single
|
||||
// leading system message before summarization and each ChatModel call.
|
||||
type systemMessageNormalizerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
func newSystemMessageNormalizerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &systemMessageNormalizerMiddleware{logger: logger, phase: phase}
|
||||
}
|
||||
|
||||
func (m *systemMessageNormalizerMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
before := countADKSystemMessages(state.Messages)
|
||||
if before <= 1 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
normalized := normalizeSingleLeadingSystemMessage(state.Messages, "")
|
||||
if len(normalized) == len(state.Messages) && countADKSystemMessages(normalized) >= before {
|
||||
return ctx, state, nil
|
||||
}
|
||||
if m.logger != nil {
|
||||
m.logger.Info("eino system messages merged",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("system_before", before),
|
||||
zap.Int("system_after", countADKSystemMessages(normalized)),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(normalized)),
|
||||
)
|
||||
}
|
||||
out := *state
|
||||
out.Messages = normalized
|
||||
return ctx, &out, nil
|
||||
}
|
||||
|
||||
func countADKSystemMessages(msgs []adk.Message) int {
|
||||
n := 0
|
||||
for _, msg := range msgs {
|
||||
if msg != nil && msg.Role == schema.System {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// stripADKSystemMessages removes all system messages. Use before runner.Run restart when
|
||||
// genModelInput will prepend a fresh Instruction.
|
||||
func stripADKSystemMessages(msgs []adk.Message) []adk.Message {
|
||||
if len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
out := make([]adk.Message, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if msg == nil || msg.Role == schema.System {
|
||||
continue
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// mergeCollectedSystemMessages collapses multiple system messages into one (or none).
|
||||
func mergeCollectedSystemMessages(systemMsgs []adk.Message) []adk.Message {
|
||||
if len(systemMsgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalizeSingleLeadingSystemMessage(systemMsgs, "")
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestStripADKSystemMessages(t *testing.T) {
|
||||
in := []adk.Message{
|
||||
schema.SystemMessage("a"),
|
||||
schema.UserMessage("u"),
|
||||
schema.SystemMessage("b"),
|
||||
schema.AssistantMessage("x", nil),
|
||||
}
|
||||
out := stripADKSystemMessages(in)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("got %d messages, want 2", len(out))
|
||||
}
|
||||
if out[0].Role != schema.User || out[1].Role != schema.Assistant {
|
||||
t.Fatalf("unexpected roles: %s, %s", out[0].Role, out[1].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoMessagesForRunRestart_StripsSystemFromTrace(t *testing.T) {
|
||||
holder := newModelFacingTraceHolder()
|
||||
holder.storeFromState(&adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("sys-1"),
|
||||
schema.SystemMessage("sys-2"),
|
||||
schema.UserMessage("task"),
|
||||
}})
|
||||
msgs, src := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, nil, nil, 0)
|
||||
if src != einoRestartContextModelTrace {
|
||||
t.Fatalf("source: got %q want model_trace", src)
|
||||
}
|
||||
if len(msgs) != 1 || msgs[0].Role != schema.User {
|
||||
t.Fatalf("expected user-only restart msgs, got %+v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMessageNormalizerMiddleware_MergesDuplicates(t *testing.T) {
|
||||
mw := newSystemMessageNormalizerMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("a"),
|
||||
schema.SystemMessage("b"),
|
||||
schema.UserMessage("u"),
|
||||
}}
|
||||
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if countADKSystemMessages(out.Messages) != 1 {
|
||||
t.Fatalf("want 1 system, got %d", countADKSystemMessages(out.Messages))
|
||||
}
|
||||
if out.Messages[0].Content != "a\n\nb" {
|
||||
t.Fatalf("merged content: %q", out.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMessageNormalizerMiddleware_NoOpSingleSystem(t *testing.T) {
|
||||
mw := newSystemMessageNormalizerMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("only"),
|
||||
schema.UserMessage("u"),
|
||||
}}
|
||||
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != state {
|
||||
t.Fatalf("expected same state pointer for no-op")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// UserVerbatimSectionHeading is the human-readable title of the user-verbatim
// anchor; it is kept inside the block so the agent can read it.
const UserVerbatimSectionHeading = "## 用户历史输入(原文保留,勿省略或改写)"

// HTML-comment boundaries of the user-verbatim anchor block. They exist so the
// block can be located and replaced programmatically and carry no instruction
// meaning for the model.
const (
	UserVerbatimSectionStartMarker = "<!-- user-verbatim-start -->"
	UserVerbatimSectionEndMarker   = "<!-- user-verbatim-end -->"
)
|
||||
|
||||
// ExtractUserContentsFromMessages 按时间顺序提取 user 角色消息的原文(跳过空白)。
|
||||
func ExtractUserContentsFromMessages(msgs []database.Message) []string {
|
||||
out := make([]string, 0, len(msgs))
|
||||
for i := range msgs {
|
||||
if !strings.EqualFold(strings.TrimSpace(msgs[i].Role), "user") {
|
||||
continue
|
||||
}
|
||||
content := strings.TrimSpace(msgs[i].Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, content)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BuildUserVerbatimAnchorBlockFromMessages 从 messages 表行构建用户原文锚点块。
|
||||
// maxRunes: 0 = 不截断;>0 = 总 rune 上限(仍保留每一轮,仅对超长单条做尾部截断提示)。
|
||||
func BuildUserVerbatimAnchorBlockFromMessages(msgs []database.Message, maxRunes int) string {
|
||||
return BuildUserVerbatimAnchorBlock(ExtractUserContentsFromMessages(msgs), maxRunes)
|
||||
}
|
||||
|
||||
// BuildUserVerbatimAnchorBlock 将各轮用户原文格式化为 system prompt 锚点块。
|
||||
func BuildUserVerbatimAnchorBlock(userContents []string, maxRunes int) string {
|
||||
if len(userContents) == 0 {
|
||||
return ""
|
||||
}
|
||||
lines := make([]string, 0, len(userContents))
|
||||
for _, content := range userContents {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("[第%d轮] %s", len(lines)+1, content))
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return ""
|
||||
}
|
||||
body := strings.Join(lines, "\n")
|
||||
if maxRunes > 0 {
|
||||
body = capUserVerbatimBody(body, maxRunes)
|
||||
}
|
||||
return wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n" + body)
|
||||
}
|
||||
|
||||
// capUserVerbatimBody limits body to at most maxRunes runes. Counting is done
// in runes, not bytes, so multi-byte characters are never split.
//
// A non-positive maxRunes means "no cap" and returns body unchanged. This
// matches the "0 = no truncation" contract of the exported builders and avoids
// slicing with a negative bound, which would panic.
//
// When truncation is needed, the head of body is kept and a fixed notice is
// appended so that the total is exactly maxRunes runes. If maxRunes is too
// small to fit the notice, the first maxRunes runes are returned without it.
func capUserVerbatimBody(body string, maxRunes int) string {
	if maxRunes <= 0 {
		return body
	}
	runes := []rune(body)
	if len(runes) <= maxRunes {
		return body
	}

	// Runtime text that ends up in the prompt; kept byte-for-byte.
	// NOTE(review): the notice speaks of earlier rounds possibly being cut, yet
	// this function keeps the head and drops the tail — confirm which is intended.
	const suffix = "\n\n...(用户原文锚点已达配置上限,更早轮次可能被截断;完整原文见 messages 表)..."

	keep := maxRunes - len([]rune(suffix))
	if keep <= 0 {
		return string(runes[:maxRunes])
	}
	return string(runes[:keep]) + suffix
}
|
||||
|
||||
func wrapUserVerbatimBlock(content string) string {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
return UserVerbatimSectionStartMarker + "\n" + content + "\n" + UserVerbatimSectionEndMarker + "\n"
|
||||
}
|
||||
|
||||
// ReplaceUserVerbatimAnchorSection 用 freshBlock 替换 content 中已有的用户原文锚点段。
|
||||
func ReplaceUserVerbatimAnchorSection(content, freshBlock string) (string, bool) {
|
||||
content = strings.TrimSpace(content)
|
||||
freshBlock = strings.TrimSpace(freshBlock)
|
||||
if freshBlock == "" {
|
||||
return content, false
|
||||
}
|
||||
start, ok := userVerbatimSectionStart(content)
|
||||
if !ok {
|
||||
return content, false
|
||||
}
|
||||
end, ok := userVerbatimSectionEnd(content, start)
|
||||
if !ok {
|
||||
return content, false
|
||||
}
|
||||
return strings.TrimSpace(content[:start] + freshBlock + content[end:]), true
|
||||
}
|
||||
|
||||
func userVerbatimSectionStart(content string) (int, bool) {
|
||||
idx := strings.Index(content, UserVerbatimSectionStartMarker)
|
||||
if idx < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return idx, true
|
||||
}
|
||||
|
||||
func userVerbatimSectionEnd(content string, start int) (int, bool) {
|
||||
if start < 0 || start >= len(content) {
|
||||
return 0, false
|
||||
}
|
||||
tail := content[start:]
|
||||
idx := strings.LastIndex(tail, UserVerbatimSectionEndMarker)
|
||||
if idx < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return start + idx + len(UserVerbatimSectionEndMarker), true
|
||||
}
|
||||
|
||||
// RefreshUserVerbatimAnchorInMessages 在 summarization 等压缩后,用 freshBlock 刷新 system 中的用户原文锚点。
|
||||
// 若尚无锚点段,则追加到首条 system 消息;若无 system 消息则在开头插入一条。
|
||||
func RefreshUserVerbatimAnchorInMessages(msgs []adk.Message, freshBlock string) []adk.Message {
|
||||
freshBlock = strings.TrimSpace(freshBlock)
|
||||
if freshBlock == "" || len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
|
||||
out := make([]adk.Message, len(msgs))
|
||||
changed := false
|
||||
for i, msg := range msgs {
|
||||
if msg == nil || msg.Role != schema.System {
|
||||
out[i] = msg
|
||||
continue
|
||||
}
|
||||
newContent, ok := ReplaceUserVerbatimAnchorSection(msg.Content, freshBlock)
|
||||
if !ok {
|
||||
out[i] = msg
|
||||
continue
|
||||
}
|
||||
cloned := *msg
|
||||
cloned.Content = newContent
|
||||
out[i] = &cloned
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
return out
|
||||
}
|
||||
|
||||
for i, msg := range msgs {
|
||||
if msg == nil || msg.Role != schema.System {
|
||||
continue
|
||||
}
|
||||
cloned := *msg
|
||||
cloned.Content = AppendSystemPromptBlock(cloned.Content, freshBlock)
|
||||
out[i] = &cloned
|
||||
return out
|
||||
}
|
||||
|
||||
prefix := make([]adk.Message, 0, len(msgs)+1)
|
||||
prefix = append(prefix, schema.SystemMessage(freshBlock))
|
||||
return append(prefix, msgs...)
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestBuildUserVerbatimAnchorBlock_MultiTurn(t *testing.T) {
|
||||
msgs := []database.Message{
|
||||
{Role: "user", Content: "目标 https://a.com 仅测 /api"},
|
||||
{Role: "assistant", Content: "好的"},
|
||||
{Role: "user", Content: "用 admin:test 登录"},
|
||||
}
|
||||
block := BuildUserVerbatimAnchorBlockFromMessages(msgs, 0)
|
||||
if block == "" {
|
||||
t.Fatal("expected non-empty block")
|
||||
}
|
||||
if !strings.Contains(block, UserVerbatimSectionStartMarker) {
|
||||
t.Error("missing start marker")
|
||||
}
|
||||
if !strings.Contains(block, "[第1轮]") || !strings.Contains(block, "https://a.com") {
|
||||
t.Error("missing first user turn")
|
||||
}
|
||||
if !strings.Contains(block, "[第2轮]") || !strings.Contains(block, "admin:test") {
|
||||
t.Error("missing second user turn")
|
||||
}
|
||||
if strings.Contains(block, "好的") {
|
||||
t.Error("assistant content should not appear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceUserVerbatimAnchorSection(t *testing.T) {
|
||||
old := "prefix\n\n" + wrapUserVerbatimBlock("## old\n\n[第1轮] a") + "\nsuffix"
|
||||
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] b\n[第2轮] c")
|
||||
out, ok := ReplaceUserVerbatimAnchorSection(old, newBlock)
|
||||
if !ok {
|
||||
t.Fatal("expected replace ok")
|
||||
}
|
||||
if !strings.Contains(out, "[第2轮] c") {
|
||||
t.Errorf("expected new block, got %q", out)
|
||||
}
|
||||
if !strings.HasPrefix(strings.TrimSpace(out), "prefix") {
|
||||
t.Error("prefix should remain")
|
||||
}
|
||||
if !strings.Contains(out, "suffix") {
|
||||
t.Error("suffix should remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshUserVerbatimAnchorInMessages_ReplaceExisting(t *testing.T) {
|
||||
oldBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] old")
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("instr\n\n" + oldBlock),
|
||||
schema.UserMessage("hi"),
|
||||
}
|
||||
newBlock := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] new")
|
||||
out := RefreshUserVerbatimAnchorInMessages(msgs, newBlock)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("message count: got %d", len(out))
|
||||
}
|
||||
if !strings.Contains(out[0].Content, "[第1轮] new") {
|
||||
t.Errorf("system content: %q", out[0].Content)
|
||||
}
|
||||
if strings.Contains(out[0].Content, "[第1轮] old") {
|
||||
t.Error("old anchor should be replaced")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshUserVerbatimAnchorInMessages_InsertWhenMissing(t *testing.T) {
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("base instruction"),
|
||||
schema.UserMessage("hi"),
|
||||
}
|
||||
block := wrapUserVerbatimBlock(UserVerbatimSectionHeading + "\n\n[第1轮] anchor")
|
||||
out := RefreshUserVerbatimAnchorInMessages(msgs, block)
|
||||
if !strings.Contains(out[0].Content, "[第1轮] anchor") {
|
||||
t.Errorf("expected appended anchor, got %q", out[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserVerbatimAnchorBlock_MaxRunes(t *testing.T) {
|
||||
long := strings.Repeat("字", 200)
|
||||
block := BuildUserVerbatimAnchorBlock([]string{long}, 50)
|
||||
body := block
|
||||
if idx := strings.Index(body, UserVerbatimSectionStartMarker); idx >= 0 {
|
||||
body = strings.TrimPrefix(body[idx+len(UserVerbatimSectionStartMarker):], "\n")
|
||||
}
|
||||
if len([]rune(body)) > 120 {
|
||||
t.Errorf("expected capped body, got %d runes", len([]rune(body)))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func sanitizeWorkspacePathSegment(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "default"
|
||||
}
|
||||
s = strings.ReplaceAll(s, string(filepath.Separator), "-")
|
||||
s = strings.ReplaceAll(s, "/", "-")
|
||||
s = strings.ReplaceAll(s, "\\", "-")
|
||||
s = strings.ReplaceAll(s, "..", "__")
|
||||
if len(s) > 180 {
|
||||
s = s[:180]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// WorkspaceRootDir returns the relative workspace root for downloads and local analysis.
|
||||
// Project-bound sessions share projects/<id>/; otherwise conversations/<id>/.
|
||||
func WorkspaceRootDir(configuredBase, projectID, conversationID string) string {
|
||||
base := strings.TrimSpace(configuredBase)
|
||||
if base == "" {
|
||||
base = filepath.Join("tmp", "workspace")
|
||||
}
|
||||
if pid := strings.TrimSpace(projectID); pid != "" {
|
||||
return filepath.Join(base, "projects", sanitizeWorkspacePathSegment(pid))
|
||||
}
|
||||
conv := strings.TrimSpace(conversationID)
|
||||
if conv == "" {
|
||||
conv = "default"
|
||||
}
|
||||
return filepath.Join(base, "conversations", sanitizeWorkspacePathSegment(conv))
|
||||
}
|
||||
|
||||
// EnsureWorkspace creates the workspace directory and returns its absolute path.
|
||||
func EnsureWorkspace(root string) (string, error) {
|
||||
abs, err := filepath.Abs(strings.TrimSpace(root))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("workspace abs: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(abs, 0o755); err != nil {
|
||||
return "", fmt.Errorf("workspace mkdir: %w", err)
|
||||
}
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
// BuildWorkspaceBlock instructs the agent to use the session workspace instead of /tmp.
|
||||
func BuildWorkspaceBlock(absPath string) string {
|
||||
absPath = strings.TrimSpace(absPath)
|
||||
if absPath == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`## 会话工作目录(下载与本地分析)
|
||||
|
||||
**必须使用以下目录**保存 curl/wget 下载的文件、临时 HTML/JS,以及 read_file/glob/grep 的检索范围:
|
||||
`+"`%s`"+`
|
||||
|
||||
- **禁止**使用系统 `+"`/tmp`"+` 或其它全局临时目录(多项目/多会话会互窜遗留文件)。
|
||||
- 下载示例:`+"`curl -o '%s/page.html' 'https://target/'`"+`;exec 时可将 `+"`workdir`"+` 设为该目录。
|
||||
- 读取前用 glob/grep/read_file **限定在该目录**下搜索,勿在 `+"`/tmp`"+` 盲目检索。`, absPath, absPath)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWorkspaceRootDirProjectScoped(t *testing.T) {
|
||||
got := WorkspaceRootDir("", "proj-1", "conv-1")
|
||||
want := filepath.Join("tmp", "workspace", "projects", "proj-1")
|
||||
if got != want {
|
||||
t.Fatalf("got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkspaceRootDirConversationScoped(t *testing.T) {
|
||||
got := WorkspaceRootDir("/data/ws", "", "conv-abc")
|
||||
want := filepath.Join("/data/ws", "conversations", "conv-abc")
|
||||
if got != want {
|
||||
t.Fatalf("got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureWorkspaceCreatesDir(t *testing.T) {
|
||||
root := filepath.Join(t.TempDir(), "nested", "workspace")
|
||||
abs, err := EnsureWorkspace(root)
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureWorkspace: %v", err)
|
||||
}
|
||||
st, err := os.Stat(abs)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat: %v", err)
|
||||
}
|
||||
if !st.IsDir() {
|
||||
t.Fatal("expected directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWorkspaceBlockMentionsPath(t *testing.T) {
|
||||
block := BuildWorkspaceBlock("/opt/csai/tmp/workspace/projects/p1")
|
||||
if block == "" {
|
||||
t.Fatal("expected non-empty block")
|
||||
}
|
||||
if !strings.Contains(block, "/opt/csai/tmp/workspace/projects/p1") {
|
||||
t.Fatalf("block missing path: %s", block)
|
||||
}
|
||||
if !strings.Contains(block, "/tmp") {
|
||||
t.Fatalf("block should warn about /tmp: %s", block)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package projectprompt
|
||||
|
||||
// ShellExecExecuteGuidanceSection returns the short guidance appended to the
// single-agent and multi-agent system prompts, describing how exec and execute
// divide the work.
func ShellExecExecuteGuidanceSection() string {
	const guidance = "Shell(exec/execute):有专用 MCP 工具时优先专用工具;" +
		"系统命令(管道、workdir、后台 &)用 exec;" +
		"skills/ 内脚本(配合 read_file、skill)用 execute;" +
		"多步扫描分拆调用,禁止一条 shell 串多个扫描器。" +
		"下载/临时文件须写入系统提示中的「会话工作目录」,禁止用 /tmp。"
	return guidance
}
|
||||
|
||||
// ShellExecExecuteGuidanceReconSuffix returns an optional one-line addition for
// the reconnaissance sub-agent prompt.
func ShellExecExecuteGuidanceReconSuffix() string {
	const suffix = "枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。"
	return suffix
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FormatCommandFailureResult renders the failure text for a command that exited
// with exitCode, using the same wording as the exec tool's ToolResult (without
// ToolErrorPrefix).
//
// output is trimmed first. Empty output produces only the status line; output
// that already starts with the failure prefix is returned as is, so a message
// is never wrapped twice; anything else is appended after the status line.
func FormatCommandFailureResult(exitCode int, output string) string {
	body := strings.TrimSpace(output)
	switch {
	case body == "":
		return fmt.Sprintf("命令执行失败: exit status %d", exitCode)
	case strings.HasPrefix(body, "命令执行失败:"):
		return body
	default:
		return fmt.Sprintf("命令执行失败: exit status %d\n输出: %s", exitCode, body)
	}
}
|
||||
|
||||
// FormatCommandFailureFromErr 根据 exec/execute 返回的 error 生成统一失败文案(IsError 正文)。
|
||||
func FormatCommandFailureFromErr(err error, output string) string {
|
||||
if err == nil {
|
||||
return strings.TrimSpace(output)
|
||||
}
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
return FormatCommandFailureResult(exitError.ExitCode(), output)
|
||||
}
|
||||
output = strings.TrimSpace(output)
|
||||
if output == "" {
|
||||
return fmt.Sprintf("命令执行失败: %v", err)
|
||||
}
|
||||
if strings.HasPrefix(output, "命令执行失败:") {
|
||||
return output
|
||||
}
|
||||
return fmt.Sprintf("命令执行失败: %v\n输出: %s", err, output)
|
||||
}
|
||||
|
||||
// ExecuteFailureStatusLine returns the single status line appended when a
// streamed execute ends with a failure. The command output itself has already
// been pushed through the stream, so only the exit status is reported here.
func ExecuteFailureStatusLine(exitCode int) string {
	// Sprint inserts no separator when one of two adjacent operands is a string.
	return fmt.Sprint("\n命令执行失败: exit status ", exitCode)
}
|
||||
|
||||
// IsCommandFailureResult reports whether a tool result body carries the command
// failure marker anywhere in its text. It is used to keep the isError flag of
// execute and exec consistent.
func IsCommandFailureResult(content string) bool {
	const failureMarker = "命令执行失败:"
	return strings.Index(content, failureMarker) >= 0
}
|
||||
|
||||
// IsLegacyShellExitNoise reports whether s is one of the redundant exit-code
// lines emitted by the legacy shell stream, so callers can filter it out.
// Surrounding whitespace is ignored.
func IsLegacyShellExitNoise(s string) bool {
	const legacyPrefix = "command exited with non-zero code "
	return strings.HasPrefix(strings.TrimSpace(s), legacyPrefix)
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFormatCommandFailureResult(t *testing.T) {
|
||||
got := FormatCommandFailureResult(1, "sudo: password required")
|
||||
want := "命令执行失败: exit status 1\n输出: sudo: password required"
|
||||
if got != want {
|
||||
t.Fatalf("got %q want %q", got, want)
|
||||
}
|
||||
if FormatCommandFailureResult(2, "") != "命令执行失败: exit status 2" {
|
||||
t.Fatal("empty output format")
|
||||
}
|
||||
if FormatCommandFailureResult(1, "命令执行失败: exit status 1") != "命令执行失败: exit status 1" {
|
||||
t.Fatal("should not double-wrap")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCommandFailureResult(t *testing.T) {
|
||||
if !IsCommandFailureResult("sudo: err\n命令执行失败: exit status 1") {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if IsCommandFailureResult("sudo: err only") {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCommandFailureFromErr(t *testing.T) {
|
||||
cmd := exec.Command("sh", "-c", "exit 42")
|
||||
err := cmd.Run()
|
||||
got := FormatCommandFailureFromErr(err, "oops")
|
||||
if got != "命令执行失败: exit status 42\n输出: oops" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
timeoutErr := errors.New("shell inactivity timeout (300s)")
|
||||
got2 := FormatCommandFailureFromErr(timeoutErr, "already timed out")
|
||||
if !strings.Contains(got2, "shell inactivity timeout") || !strings.Contains(got2, "already timed out") {
|
||||
t.Fatalf("got %q", got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLegacyShellExitNoise(t *testing.T) {
|
||||
if !IsLegacyShellExitNoise("command exited with non-zero code 1\n") {
|
||||
t.Fatal("expected legacy noise")
|
||||
}
|
||||
if IsLegacyShellExitNoise("sudo: failed") {
|
||||
t.Fatal("unexpected noise")
|
||||
}
|
||||
}
|
||||
+139
-110
@@ -32,10 +32,11 @@ var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
|
||||
|
||||
// Executor 安全工具执行器
|
||||
type Executor struct {
|
||||
config *config.SecurityConfig
|
||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
config *config.SecurityConfig
|
||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300;-1=关闭(见 SetShellNoOutputTimeoutSeconds)
|
||||
}
|
||||
|
||||
// NewExecutor 创建新的执行器
|
||||
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
|
||||
return executor
|
||||
}
|
||||
|
||||
// SetShellNoOutputTimeoutSeconds 配置 exec 工具无输出空闲终止(与 agent.shell_no_output_timeout_seconds 一致)。
|
||||
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
|
||||
e.shellNoOutputTimeoutSec = sec
|
||||
}
|
||||
|
||||
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
||||
func (e *Executor) buildToolIndex() {
|
||||
e.toolIndex = make(map[string]*config.ToolConfig)
|
||||
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
// 执行命令
|
||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
|
||||
e.logger.Info("执行安全工具",
|
||||
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
var err error
|
||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||
zap.String("tool", toolName),
|
||||
@@ -155,9 +162,8 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||
}
|
||||
} else {
|
||||
outputBytes, err2 := cmd.CombinedOutput()
|
||||
output = string(outputBytes)
|
||||
err = err2
|
||||
// 非流式:内存缓冲 + ctx 取消杀进程组;行为对齐原 CombinedOutput,避免双流管道 fan-in 死锁。
|
||||
output, err = combinedOutputCancellable(ctx, cmd)
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||
zap.String("tool", toolName),
|
||||
@@ -685,83 +691,21 @@ func (e *Executor) formatParamValue(param config.ParameterConfig, value interfac
|
||||
// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。
|
||||
// command1 & command2 不算完全后台(command2 仍在前台执行)。
|
||||
func IsBackgroundShellCommand(command string) bool {
|
||||
// 移除首尾空格
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查命令中所有不在引号内的 & 符号
|
||||
// 找到最后一个 & 符号,检查它是否在命令末尾
|
||||
inSingleQuote := false
|
||||
inDoubleQuote := false
|
||||
escaped := false
|
||||
lastAmpersandPos := -1
|
||||
|
||||
for i, r := range command {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if r == '\'' && !inDoubleQuote {
|
||||
inSingleQuote = !inSingleQuote
|
||||
continue
|
||||
}
|
||||
if r == '"' && !inSingleQuote {
|
||||
inDoubleQuote = !inDoubleQuote
|
||||
continue
|
||||
}
|
||||
if r == '&' && !inSingleQuote && !inDoubleQuote {
|
||||
// 检查 & 前后是否有空格或换行(确保是独立的 &,而不是变量名的一部分)
|
||||
isStandalone := false
|
||||
|
||||
// 检查前面:空格、制表符、换行符,或者是命令开头
|
||||
if i == 0 {
|
||||
isStandalone = true
|
||||
} else {
|
||||
prev := command[i-1]
|
||||
if prev == ' ' || prev == '\t' || prev == '\n' || prev == '\r' {
|
||||
isStandalone = true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查后面:空格、制表符、换行符,或者是命令末尾
|
||||
if isStandalone {
|
||||
if i == len(command)-1 {
|
||||
// 在末尾,肯定是独立的 &
|
||||
lastAmpersandPos = i
|
||||
} else {
|
||||
next := command[i+1]
|
||||
if next == ' ' || next == '\t' || next == '\n' || next == '\r' {
|
||||
// 后面有空格,是独立的 &
|
||||
lastAmpersandPos = i
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到 & 符号,不是后台命令
|
||||
if lastAmpersandPos == -1 {
|
||||
positions := findStandaloneAmpersandPositions(command)
|
||||
if len(positions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查最后一个 & 后面是否还有非空内容
|
||||
afterAmpersand := strings.TrimSpace(command[lastAmpersandPos+1:])
|
||||
if afterAmpersand == "" {
|
||||
// & 在末尾或后面只有空白字符,这是完全后台命令
|
||||
// 检查 & 前面是否有内容
|
||||
beforeAmpersand := strings.TrimSpace(command[:lastAmpersandPos])
|
||||
return beforeAmpersand != ""
|
||||
last := positions[len(positions)-1]
|
||||
afterAmpersand := strings.TrimSpace(command[last+1:])
|
||||
if afterAmpersand != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果 & 后面还有非空内容,说明是 command1 & command2 的情况
|
||||
// 这种情况下,command2会在前台执行,所以不算完全后台命令
|
||||
return false
|
||||
beforeAmpersand := strings.TrimSpace(command[:last])
|
||||
return beforeAmpersand != ""
|
||||
}
|
||||
|
||||
// executeSystemCommand 执行系统命令
|
||||
@@ -797,6 +741,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
zap.String("command", command),
|
||||
)
|
||||
|
||||
command = PrepareShellCommandForExecute(command)
|
||||
|
||||
// 获取shell类型(可选,默认为sh)
|
||||
shell := "sh"
|
||||
if s, ok := args["shell"].(string); ok && s != "" {
|
||||
@@ -820,8 +766,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
|
||||
// 执行命令
|
||||
e.logger.Info("执行系统命令",
|
||||
@@ -837,10 +782,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&")
|
||||
commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand)
|
||||
|
||||
// 构建新命令:将用户命令置于独立重定向的后台作业,再 echo $pid。
|
||||
// 若子进程与 echo 共享同一 stdout 管道,且长时间不向 stdout 写入换行,
|
||||
// bufio.ReadString('\n') 会永久阻塞(例如 beacon 持续写二进制/单行日志)。
|
||||
pidCommand := fmt.Sprintf("%s </dev/null >/dev/null 2>&1 & pid=$!; echo $pid", commandWithoutAmpersand)
|
||||
// 构建新命令:后台作业重定向标准流后 echo $pid(与 RedirectBackgroundJobStdio 一致)。
|
||||
pidCommand := RedirectBackgroundJobStdio(commandWithoutAmpersand+" &") + " pid=$!; echo $pid"
|
||||
|
||||
// 创建新命令来获取PID
|
||||
var pidCmd *exec.Cmd
|
||||
@@ -850,8 +793,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
} else {
|
||||
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
||||
}
|
||||
applyDefaultTerminalEnv(pidCmd)
|
||||
_ = prepareShellCmdSession(pidCmd)
|
||||
ConfigureShellCmdForAgentExecute(pidCmd)
|
||||
|
||||
// 获取stdout管道
|
||||
stdout, err := pidCmd.StdoutPipe()
|
||||
@@ -963,29 +905,25 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
var err error
|
||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||
if workDir != "" {
|
||||
cmd2.Dir = workDir
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd2)
|
||||
_ = prepareShellCmdSession(cmd2)
|
||||
ConfigureShellCmdForAgentExecute(cmd2)
|
||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||
}
|
||||
} else {
|
||||
outputBytes, err2 := cmd.CombinedOutput()
|
||||
output = string(outputBytes)
|
||||
err = err2
|
||||
output, err = combinedOutputCancellable(ctx, cmd)
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||
if workDir != "" {
|
||||
cmd2.Dir = workDir
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd2)
|
||||
_ = prepareShellCmdSession(cmd2)
|
||||
ConfigureShellCmdForAgentExecute(cmd2)
|
||||
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
||||
}
|
||||
}
|
||||
@@ -999,7 +937,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
|
||||
Text: FormatCommandFailureFromErr(err, output),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
@@ -1022,12 +960,58 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
}, nil
|
||||
}
|
||||
|
||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
// combinedOutputCancellable 行为对齐 cmd.CombinedOutput(stdout/stderr 写入内存缓冲),
|
||||
// 但在 ctx 取消时 terminateCmdTree 终止整棵进程树。
|
||||
// 非流式路径不使用双流管道 fan-in,避免 stderr 撑满管道缓冲区时与 stdout 互相阻塞导致死锁。
|
||||
// 无输出空闲检测由上层 agent.tool_timeout_minutes 兜底,不改变原 CombinedOutput 语义。
|
||||
func combinedOutputCancellable(ctx context.Context, cmd *exec.Cmd) (string, error) {
|
||||
var stdoutBuf, stderrBuf strings.Builder
|
||||
cmd.Stdout = &stdoutBuf
|
||||
cmd.Stderr = &stderrBuf
|
||||
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
stopWatch := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
TerminateShellCmdSession(session)
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
defer close(stopWatch)
|
||||
|
||||
var waitErr error
|
||||
select {
|
||||
case waitErr = <-done:
|
||||
case <-ctx.Done():
|
||||
waitErr = <-done
|
||||
return joinCommandOutput(stdoutBuf.String(), stderrBuf.String()), ctx.Err()
|
||||
}
|
||||
return joinCommandOutput(stdoutBuf.String(), stderrBuf.String()), waitErr
|
||||
}
|
||||
|
||||
// joinCommandOutput returns stdout immediately followed by stderr. When either
// side is empty the result is simply the other side.
func joinCommandOutput(stdout, stderr string) string {
	return stdout + stderr
}
|
||||
|
||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1037,7 +1021,8 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
_ = stdoutPipe.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
_ = stdoutPipe.Close()
|
||||
_ = stderrPipe.Close()
|
||||
return "", err
|
||||
@@ -1047,7 +1032,7 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
terminateCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
@@ -1086,23 +1071,61 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
if deltaBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
cb(deltaBuilder.String())
|
||||
if cb != nil {
|
||||
cb(deltaBuilder.String())
|
||||
}
|
||||
deltaBuilder.Reset()
|
||||
lastFlush = time.Now()
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
outBuilder.WriteString(chunk)
|
||||
deltaBuilder.WriteString(chunk)
|
||||
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
|
||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||
idleWatch := NewShellInactivityWatch(noOutputSec)
|
||||
if idleWatch != nil {
|
||||
defer idleWatch.Stop()
|
||||
}
|
||||
|
||||
fireInactivity := func() {
|
||||
TerminateShellCmdSession(session)
|
||||
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||
outBuilder.WriteString(msg)
|
||||
if cb != nil {
|
||||
cb(msg)
|
||||
}
|
||||
_ = session.Wait()
|
||||
}
|
||||
|
||||
chunksLoop:
|
||||
for {
|
||||
var idleCh <-chan struct{}
|
||||
if idleWatch != nil {
|
||||
idleCh = idleWatch.Expired
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
TerminateShellCmdSession(session)
|
||||
flush()
|
||||
_ = session.Wait()
|
||||
return outBuilder.String(), ctx.Err()
|
||||
case <-idleCh:
|
||||
fireInactivity()
|
||||
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||
case chunk, ok := <-chunks:
|
||||
if !ok {
|
||||
break chunksLoop
|
||||
}
|
||||
if chunk != "" && idleWatch != nil {
|
||||
idleWatch.Bump()
|
||||
}
|
||||
outBuilder.WriteString(chunk)
|
||||
deltaBuilder.WriteString(chunk)
|
||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||
flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
flush()
|
||||
|
||||
// 等待命令结束,返回最终退出状态
|
||||
waitErr := cmd.Wait()
|
||||
waitErr := session.Wait()
|
||||
return outBuilder.String(), waitErr
|
||||
}
|
||||
|
||||
@@ -1116,6 +1139,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
|
||||
if cmd.Env == nil {
|
||||
cmd.Env = os.Environ()
|
||||
}
|
||||
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
|
||||
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
|
||||
has := func(k string) bool {
|
||||
prefix := k + "="
|
||||
@@ -1159,7 +1183,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
||||
if runtime.GOOS == "windows" {
|
||||
// PTY 方案为类 Unix;Windows 走原逻辑
|
||||
if cb != nil {
|
||||
return streamCommandOutput(ctx, cmd, cb)
|
||||
return streamCommandOutput(ctx, cmd, cb, 0)
|
||||
}
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
@@ -1173,13 +1197,18 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
||||
}
|
||||
defer func() { _ = ptmx.Close() }()
|
||||
|
||||
rootPID := 0
|
||||
if cmd.Process != nil {
|
||||
rootPID = cmd.Process.Pid
|
||||
}
|
||||
|
||||
// ctx 取消时尽快终止子进程
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = ptmx.Close() // 触发读退出
|
||||
terminateCmdTree(cmd)
|
||||
terminateProcessGroup(rootPID, cmd)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -2,6 +2,8 @@ package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -71,6 +73,27 @@ func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSystemCommand_FailureFormat(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
res, err := executor.executeSystemCommand(context.Background(), map[string]interface{}{
|
||||
"command": "echo fail-msg >&2; exit 7",
|
||||
"shell": "sh",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("executeSystemCommand: %v", err)
|
||||
}
|
||||
if res == nil || !res.IsError {
|
||||
t.Fatalf("expected IsError, got %+v", res)
|
||||
}
|
||||
text := res.Content[0].Text
|
||||
if text != FormatCommandFailureResult(7, "fail-msg\n") && text != FormatCommandFailureResult(7, "fail-msg") {
|
||||
t.Fatalf("unexpected failure text: %q", text)
|
||||
}
|
||||
if !strings.Contains(text, "exit status 7") || !strings.Contains(text, "fail-msg") {
|
||||
t.Fatalf("unexpected failure text: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
|
||||
pos1 := 1
|
||||
executor, _ := setupTestExecutor(t)
|
||||
@@ -126,3 +149,33 @@ func indexOf(slice []string, s string) int {
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// TestCombinedOutputCancellable_ContextCancelKillsTree 验证 ctx 取消时能在数秒内结束(杀进程组,非挂死)。
|
||||
func TestCombinedOutputCancellable_ContextCancelKillsTree(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix process group kill")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", "sleep 300")
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := combinedOutputCancellable(ctx, cmd)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
t.Fatal("expected context cancel error")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("combinedOutputCancellable did not return within 5s after context cancel")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,13 +19,23 @@ func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
// terminateProcessGroup 对 rootPID 对应进程组发 SIGKILL;rootPID 为 0 时回退到 cmd.Process.Pid。
|
||||
func terminateProcessGroup(rootPID int, cmd *exec.Cmd) {
|
||||
pid := rootPID
|
||||
if pid <= 0 && cmd != nil && cmd.Process != nil {
|
||||
pid = cmd.Process.Pid
|
||||
}
|
||||
if pid <= 0 {
|
||||
return
|
||||
}
|
||||
pid := cmd.Process.Pid
|
||||
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
terminateProcessGroup(0, cmd)
|
||||
}
|
||||
|
||||
@@ -2,16 +2,42 @@
|
||||
|
||||
package security
|
||||
|
||||
import "os/exec"
|
||||
import (
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
_ = cmd
|
||||
if cmd == nil {
|
||||
return nil
|
||||
}
|
||||
// 独立进程组,便于 taskkill /T 终止整棵子进程树。
|
||||
if cmd.SysProcAttr == nil {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
}
|
||||
cmd.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP
|
||||
return nil
|
||||
}
|
||||
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
// terminateProcessGroup 使用 taskkill /F /T 终止进程及其子进程;rootPID 为 0 时回退到 cmd.Process.Pid。
|
||||
func terminateProcessGroup(rootPID int, cmd *exec.Cmd) {
|
||||
pid := rootPID
|
||||
if pid <= 0 && cmd != nil && cmd.Process != nil {
|
||||
pid = cmd.Process.Pid
|
||||
}
|
||||
if pid <= 0 {
|
||||
return
|
||||
}
|
||||
_ = cmd.Process.Kill()
|
||||
tk := exec.Command("taskkill", "/F", "/T", "/PID", strconv.Itoa(pid))
|
||||
if err := tk.Run(); err != nil {
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// terminateCmdTree 使用 taskkill /F /T 终止进程及其子进程(Windows 上 Process.Kill 无法保证杀掉 python 等孙进程)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
terminateProcessGroup(0, cmd)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
package security
|
||||
|
||||
import "strings"
|
||||
|
||||
const backgroundJobStdioRedirect = " </dev/null >/dev/null 2>&1"
|
||||
|
||||
// findStandaloneAmpersandPositions returns the byte indexes, within command as
// passed in, of every standalone background operator "&".
//
// An "&" counts only when it is:
//   - outside single and double quotes and not preceded by a backslash;
//   - not part of "&&";
//   - at the start of the trimmed command or preceded by space, tab, CR or LF;
//   - at the end of the trimmed command or followed by space, tab, CR or LF.
//
// Leading and trailing whitespace of command is ignored while scanning, but the
// returned indexes are offsets into the original, untrimmed string, so callers
// can slice the string they passed in. (Previously the indexes were relative to
// the trimmed copy, which misplaced edits for commands with leading whitespace.)
// A blank command yields nil.
func findStandaloneAmpersandPositions(command string) []int {
	trimmed := strings.TrimSpace(command)
	if trimmed == "" {
		return nil
	}
	// trimmed starts with a non-space byte, so its first occurrence in command
	// is exactly where the leading whitespace ends.
	offset := strings.Index(command, trimmed)

	isBlank := func(b byte) bool {
		return b == ' ' || b == '\t' || b == '\n' || b == '\r'
	}

	var positions []int
	inSingleQuote, inDoubleQuote, escaped := false, false, false

	for i := 0; i < len(trimmed); i++ {
		c := trimmed[i]

		// Quote and escape tracking; each of these bytes is consumed here.
		switch {
		case escaped:
			escaped = false
			continue
		case c == '\\':
			escaped = true
			continue
		case c == '\'' && !inDoubleQuote:
			inSingleQuote = !inSingleQuote
			continue
		case c == '"' && !inSingleQuote:
			inDoubleQuote = !inDoubleQuote
			continue
		}

		if c != '&' || inSingleQuote || inDoubleQuote {
			continue
		}
		// Either half of "&&" is a logical AND, not a background operator.
		if i+1 < len(trimmed) && trimmed[i+1] == '&' {
			continue
		}
		if i > 0 && trimmed[i-1] == '&' {
			continue
		}
		// Must be delimited on the left ...
		if i > 0 && !isBlank(trimmed[i-1]) {
			continue
		}
		// ... and on the right.
		if i == len(trimmed)-1 || isBlank(trimmed[i+1]) {
			positions = append(positions, offset+i)
		}
	}
	return positions
}
|
||||
|
||||
// segmentHasStdioRedirect reports whether segment already appears to redirect
// its output. The match is a case-insensitive substring heuristic that
// recognizes `>/dev/null` (which also covers `2>/dev/null`), `&>` (which also
// covers `&>>`), and the combination of `2>&1` with `/dev/null` anywhere in the
// text. A blank segment yields false.
func segmentHasStdioRedirect(segment string) bool {
	s := strings.ToLower(strings.TrimSpace(segment))
	switch {
	case s == "":
		return false
	case strings.Contains(s, ">/dev/null"):
		return true
	case strings.Contains(s, "&>"):
		return true
	default:
		return strings.Contains(s, "2>&1") && strings.Contains(s, "/dev/null")
	}
}
|
||||
|
||||
// RedirectBackgroundJobStdio 为每个独立 & 前的后台段注入 </dev/null >/dev/null 2>&1,
|
||||
// 避免后台子进程占用 execute/exec 管道导致挂死。
|
||||
func RedirectBackgroundJobStdio(command string) string {
|
||||
positions := findStandaloneAmpersandPositions(command)
|
||||
if len(positions) == 0 {
|
||||
return command
|
||||
}
|
||||
|
||||
out := command
|
||||
for j := len(positions) - 1; j >= 0; j-- {
|
||||
i := positions[j]
|
||||
before := out[:i]
|
||||
after := out[i:]
|
||||
trimmed := strings.TrimRight(before, " \t\r\n")
|
||||
if segmentHasStdioRedirect(trimmed) {
|
||||
continue
|
||||
}
|
||||
trailing := before[len(trimmed):]
|
||||
out = trimmed + backgroundJobStdioRedirect + trailing + after
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PrepareShellCommandForExecute 组合 execute/exec 用的非交互包装与后台 IO 重定向。
|
||||
// 须先注入 exec </dev/null,再改写 & 后台段,否则段内 </dev/null 会使 stdin 重定向被误判为已存在。
|
||||
func PrepareShellCommandForExecute(shellCommand string) string {
|
||||
return RedirectBackgroundJobStdio(PrepareNonInteractiveShellCommand(shellCommand))
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRedirectBackgroundJobStdio_mixedCommand(t *testing.T) {
|
||||
in := "java -jar app.jar & JRMP_PID=$!; echo started"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
if !strings.Contains(out, "java -jar app.jar </dev/null >/dev/null 2>&1 &") {
|
||||
t.Fatalf("expected redirect before &: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "echo started") {
|
||||
t.Fatalf("foreground tail preserved: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectBackgroundJobStdio_trailingOnly(t *testing.T) {
|
||||
in := "sleep 120 &"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
want := "sleep 120 </dev/null >/dev/null 2>&1 &"
|
||||
if strings.TrimSpace(out) != want {
|
||||
t.Fatalf("got %q want %q", out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectBackgroundJobStdio_skipsAlreadyRedirected(t *testing.T) {
|
||||
in := "sleep 1 >/dev/null 2>&1 & echo ok"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
if out != in {
|
||||
t.Fatalf("should not double-redirect: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectBackgroundJobStdio_skipsAndAnd(t *testing.T) {
|
||||
in := "test -f /etc/passwd && echo ok"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
if out != in {
|
||||
t.Fatalf("&& must not be treated as background &: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareShellCommandForExecute(t *testing.T) {
|
||||
out := PrepareShellCommandForExecute("java -jar x & echo hi")
|
||||
if !strings.Contains(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") {
|
||||
t.Fatalf("missing pager export: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "java -jar x </dev/null >/dev/null 2>&1 &") {
|
||||
t.Fatalf("missing background redirect: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBackgroundShellCommand_usesSharedParser(t *testing.T) {
|
||||
if !IsBackgroundShellCommand("sleep 1 &") {
|
||||
t.Fatal("trailing & should be background")
|
||||
}
|
||||
if IsBackgroundShellCommand("sleep 1 & echo hi") {
|
||||
t.Fatal("mixed should not be fully background")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// ConfigureShellCmdForAgentExecute 与 exec 工具一致:非交互 stdin、pager/TERM 环境、独立进程组。
|
||||
func ConfigureShellCmdForAgentExecute(cmd *exec.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
}
|
||||
|
||||
// TerminateShellCmdTree makes a best-effort attempt to terminate the shell
// behind cmd together with its child process group, matching what exec/execute
// do on timeout or cancellation. It is the exported entry point for
// terminateCmdTree.
|
||||
func TerminateShellCmdTree(cmd *exec.Cmd) {
|
||||
terminateCmdTree(cmd)
|
||||
}
|
||||
|
||||
// TerminateShellCmdSession 使用 Start 时缓存的进程组 ID 终止(shell 已退出时仍有效)。
|
||||
func TerminateShellCmdSession(session *ShellSession) {
|
||||
TerminateShellSession(session)
|
||||
}
|
||||
|
||||
// EinoStreamingShell provides a streaming shell for the Eino ADK execute tool,
// aligned with the behavior of exec:
|
||||
// stdout and stderr are read concurrently (in fixed-size chunks, not line by
// line), avoiding the upstream local.ExecuteStreaming behavior of draining
// stdout first,
|
||||
// which left stderr errors (such as a sudo password prompt) invisible for a
// long time while the UI kept showing "running".
|
||||
type EinoStreamingShell struct{}
|
||||
|
||||
// NewEinoStreamingShell 创建 execute 流式 shell 实现。
|
||||
func NewEinoStreamingShell() *EinoStreamingShell {
|
||||
return &EinoStreamingShell{}
|
||||
}
|
||||
|
||||
// ExecuteStreaming 实现 filesystem.StreamingShell。
|
||||
func (s *EinoStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
if input == nil || input.Command == "" {
|
||||
return nil, fmt.Errorf("command is required")
|
||||
}
|
||||
|
||||
sr, w := schema.Pipe[*filesystem.ExecuteResponse](100)
|
||||
if input.RunInBackendGround {
|
||||
go runShellInBackground(ctx, input.Command, w)
|
||||
return sr, nil
|
||||
}
|
||||
go streamShellForeground(ctx, input.Command, w)
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func runShellInBackground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
|
||||
defer w.Close()
|
||||
|
||||
command = PrepareShellCommandForExecute(command)
|
||||
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err))
|
||||
return
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
_ = stdout.Close()
|
||||
_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
|
||||
return
|
||||
}
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
_ = stdout.Close()
|
||||
_ = stderr.Close()
|
||||
_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
drainShellPipes(stdout, stderr)
|
||||
_ = session.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
TerminateShellCmdSession(session)
|
||||
}
|
||||
|
||||
exitCode := 0
|
||||
_ = w.Send(&filesystem.ExecuteResponse{
|
||||
Output: "command started in background\n",
|
||||
ExitCode: &exitCode,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func drainShellPipes(stdout, stderr io.Reader) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(io.Discard, stdout)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(io.Discard, stderr)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func streamShellForeground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
|
||||
defer w.Close()
|
||||
|
||||
command = PrepareShellCommandForExecute(command)
|
||||
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err))
|
||||
return
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
_ = stdoutPipe.Close()
|
||||
_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
|
||||
return
|
||||
}
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
_ = stdoutPipe.Close()
|
||||
_ = stderrPipe.Close()
|
||||
_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
stopWatch := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
TerminateShellCmdSession(session)
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
defer close(stopWatch)
|
||||
|
||||
chunks := make(chan string, 64)
|
||||
var wg sync.WaitGroup
|
||||
readFn := func(r io.Reader) {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 8192)
|
||||
for {
|
||||
n, readErr := r.Read(buf)
|
||||
if n > 0 {
|
||||
chunks <- string(buf[:n])
|
||||
}
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go readFn(stdoutPipe)
|
||||
go readFn(stderrPipe)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(chunks)
|
||||
}()
|
||||
|
||||
hadOutput := false
|
||||
for chunk := range chunks {
|
||||
if chunk == "" {
|
||||
continue
|
||||
}
|
||||
hadOutput = true
|
||||
if w.Send(&filesystem.ExecuteResponse{Output: chunk}, nil) {
|
||||
TerminateShellCmdSession(session)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
waitErr := session.Wait()
|
||||
if waitErr == nil {
|
||||
exitCode := 0
|
||||
_ = w.Send(&filesystem.ExecuteResponse{ExitCode: &exitCode}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(waitErr, &exitError) {
|
||||
exitCode := exitError.ExitCode()
|
||||
resp := &filesystem.ExecuteResponse{ExitCode: &exitCode}
|
||||
if !hadOutput {
|
||||
resp.Output = FormatCommandFailureResult(exitCode, "")
|
||||
}
|
||||
_ = w.Send(resp, nil)
|
||||
return
|
||||
}
|
||||
_ = w.Send(nil, fmt.Errorf("command failed: %w", waitErr))
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
)
|
||||
|
||||
func TestEinoStreamingShell_StreamsStderrBeforeStdoutEOF(t *testing.T) {
|
||||
shell := NewEinoStreamingShell()
|
||||
cmd := PrepareNonInteractiveShellCommand("echo err-only >&2; exit 1")
|
||||
sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
start := time.Now()
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if time.Since(start) > 3*time.Second {
|
||||
t.Fatalf("expected fast completion, took %v", time.Since(start))
|
||||
}
|
||||
if !strings.Contains(got.String(), "err-only") {
|
||||
t.Fatalf("expected stderr in output, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShell_SudoFailsFast(t *testing.T) {
|
||||
shell := NewEinoStreamingShell()
|
||||
cmd := PrepareNonInteractiveShellCommand("sudo whoami && sudo cat /etc/os-release")
|
||||
sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
start := time.Now()
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
if time.Since(start) > 5*time.Second {
|
||||
t.Fatalf("sudo should fail quickly, took %v output=%q", time.Since(start), got.String())
|
||||
}
|
||||
out := got.String()
|
||||
if strings.Contains(out, "command exited with non-zero code") {
|
||||
t.Fatalf("legacy exit line present: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "sudo") && !strings.Contains(out, "password") && !strings.Contains(out, "terminal") {
|
||||
t.Fatalf("expected sudo error text, got: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShell_StderrWhileStdoutBlocks(t *testing.T) {
|
||||
shell := NewEinoStreamingShell()
|
||||
// 模拟 sudo:stderr 先有输出,stdout 侧进程仍挂起;旧 eino local 在首包 stderr 前不会向流写任何内容。
|
||||
cmd := PrepareNonInteractiveShellCommand(`echo "password prompt" >&2; sleep 30`)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sr, err := shell.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: cmd})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
start := time.Now()
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
break
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
if strings.Contains(got.String(), "password prompt") {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if time.Since(start) > 1500*time.Millisecond {
|
||||
t.Fatalf("expected stderr promptly, took %v output=%q", time.Since(start), got.String())
|
||||
}
|
||||
if !strings.Contains(got.String(), "password prompt") {
|
||||
t.Fatalf("expected early stderr, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestEinoStreamingShell_BackgroundJobDoesNotHoldPipe 模拟 cmd & 后继续前台逻辑:重定向后应快速结束。
|
||||
func TestEinoStreamingShell_BackgroundJobDoesNotHoldPipe(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
shell := NewEinoStreamingShell()
|
||||
cmd := `(sh -c 'printf x; sleep 120') & echo started; sleep 0`
|
||||
sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
start := time.Now()
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if time.Since(start) > 3*time.Second {
|
||||
t.Fatalf("expected fast completion, took %v output=%q", time.Since(start), got.String())
|
||||
}
|
||||
if !strings.Contains(got.String(), "started") {
|
||||
t.Fatalf("expected foreground echo, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
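// ShellNoOutputTimeoutMessage returns the notice shown when a command has produced no new stdout/stderr for idleSec seconds (a soft failure that is visible to the model).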
|
||||
func ShellNoOutputTimeoutMessage(idleSec int) string {
|
||||
return fmt.Sprintf(`命令已终止:超过 %d 秒没有新的输出,疑似在等待交互输入或已挂起。
|
||||
|
||||
长时静默任务请使用末尾 & 后台运行,或增大 agent.shell_no_output_timeout_seconds(-1=关闭此检测)。
|
||||
|
||||
Command terminated: no new output for %d seconds (possible interactive wait or hung process).`, idleSec, idleSec)
|
||||
}
|
||||
|
||||
// ShellInactivityWatch signals on Expired when no new output has arrived within
// the configured number of seconds; every Bump call resets the countdown.
|
||||
// Unlike a scheme that cancels the timer for good after the first output, this
// also catches cases such as sudo printing a Password prompt and then hanging.
|
||||
type ShellInactivityWatch struct {
|
||||
Sec int // inactivity window in seconds
|
||||
mu sync.Mutex // held by Bump and Stop while they touch timer
|
||||
timer *time.Timer // current countdown; nil before the first Bump and after Stop
|
||||
Expired chan struct{} // receives a value (non-blocking send) when the countdown fires
|
||||
}
|
||||
|
||||
func NewShellInactivityWatch(noOutputSec int) *ShellInactivityWatch {
|
||||
sec := ResolveShellNoOutputTimeoutSeconds(noOutputSec)
|
||||
if sec <= 0 {
|
||||
return nil
|
||||
}
|
||||
w := &ShellInactivityWatch{
|
||||
Sec: sec,
|
||||
Expired: make(chan struct{}, 1),
|
||||
}
|
||||
w.Bump()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *ShellInactivityWatch) Bump() {
|
||||
if w == nil || w.Sec <= 0 {
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
}
|
||||
w.timer = time.AfterFunc(time.Duration(w.Sec)*time.Second, func() {
|
||||
select {
|
||||
case w.Expired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (w *ShellInactivityWatch) Stop() {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
w.timer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveShellNoOutputTimeoutSeconds normalizes the configured no-output
// timeout: 0 means the default of 300 seconds (5 minutes), any negative value
// (conventionally -1) disables the check and yields 0, and a positive value is
// returned as is.
func ResolveShellNoOutputTimeoutSeconds(sec int) int {
	switch {
	case sec < 0:
		return 0
	case sec == 0:
		return 300
	default:
		return sec
	}
}
|
||||
|
||||
// PrependNonInteractiveShellExports prefixes a `sh -c` command with an `export`
// line carrying generic non-interactive settings (pagers set to cat,
// DEBIAN_FRONTEND=noninteractive). No per-command blacklist is maintained.
//
// A variable is left out when the command itself already mentions it, so a
// value chosen by the caller is not overridden. The mention check is
// case-insensitive and matches whole identifiers only: `GIT_PAGER=less` no
// longer counts as a mention of `PAGER`, which a plain substring test would
// wrongly report. A blank command, or one that mentions every variable, is
// returned unchanged.
func PrependNonInteractiveShellExports(shellCommand string) string {
	if strings.TrimSpace(shellCommand) == "" {
		return shellCommand
	}
	upper := strings.ToUpper(shellCommand)

	// After upper-casing, identifier bytes are A-Z, 0-9 and underscore.
	isIdent := func(b byte) bool {
		return b == '_' || (b >= '0' && b <= '9') || (b >= 'A' && b <= 'Z')
	}
	// mentions reports whether key occurs in the command as a whole identifier.
	mentions := func(key string) bool {
		for from := 0; ; {
			rel := strings.Index(upper[from:], key)
			if rel < 0 {
				return false
			}
			start := from + rel
			end := start + len(key)
			leftOK := start == 0 || !isIdent(upper[start-1])
			rightOK := end == len(upper) || !isIdent(upper[end])
			if leftOK && rightOK {
				return true
			}
			from = start + 1
		}
	}

	var pairs []string
	add := func(key, val string) {
		if mentions(strings.ToUpper(key)) {
			return
		}
		pairs = append(pairs, key+"="+val)
	}
	add("GIT_PAGER", "cat")
	add("PAGER", "cat")
	add("SYSTEMD_PAGER", "cat")
	add("DEBIAN_FRONTEND", "noninteractive")

	if len(pairs) == 0 {
		return shellCommand
	}
	return "export " + strings.Join(pairs, " ") + "\n" + shellCommand
}
|
||||
|
||||
// PrependNonInteractiveStdinRedirect prefixes a `sh -c` command with
// `exec </dev/null` so that programs reading stdin (read, input(), sudo -S, ...)
// fail fast instead of hanging; this mirrors attachNonInteractiveStdin at the
// shell level. A blank command, or one that already contains `</dev/null`
// (case-insensitive, which also covers `0</dev/null`), is returned unchanged.
func PrependNonInteractiveStdinRedirect(shellCommand string) string {
	const devNullStdin = "</dev/null"

	blank := strings.TrimSpace(shellCommand) == ""
	if blank || strings.Contains(strings.ToLower(shellCommand), devNullStdin) {
		return shellCommand
	}
	return "exec </dev/null\n" + shellCommand
}
|
||||
|
||||
// PrepareNonInteractiveShellCommand 组合非交互包装:stdin 关闭 + pager 等环境变量(零名单)。
|
||||
func PrepareNonInteractiveShellCommand(shellCommand string) string {
|
||||
return PrependNonInteractiveStdinRedirect(PrependNonInteractiveShellExports(shellCommand))
|
||||
}
|
||||
|
||||
// ApplyNonInteractivePagerEnv 为 exec.Cmd 补齐与 PrependNonInteractiveShellExports 一致的环境变量。
|
||||
func ApplyNonInteractivePagerEnv(cmdEnv []string) []string {
|
||||
if cmdEnv == nil {
|
||||
cmdEnv = []string{}
|
||||
}
|
||||
has := func(k string) bool {
|
||||
prefix := k + "="
|
||||
for _, e := range cmdEnv {
|
||||
if strings.HasPrefix(e, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if !has("GIT_PAGER") {
|
||||
cmdEnv = append(cmdEnv, "GIT_PAGER=cat")
|
||||
}
|
||||
if !has("PAGER") {
|
||||
cmdEnv = append(cmdEnv, "PAGER=cat")
|
||||
}
|
||||
if !has("SYSTEMD_PAGER") {
|
||||
cmdEnv = append(cmdEnv, "SYSTEMD_PAGER=cat")
|
||||
}
|
||||
if !has("DEBIAN_FRONTEND") {
|
||||
cmdEnv = append(cmdEnv, "DEBIAN_FRONTEND=noninteractive")
|
||||
}
|
||||
return cmdEnv
|
||||
}
|
||||
|
||||
// attachNonInteractiveStdin 关闭交互式 stdin,使部分命令快速失败而非等待输入。
|
||||
func attachNonInteractiveStdin(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Stdin != nil {
|
||||
return
|
||||
}
|
||||
f, err := os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cmd.Stdin = f
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPrependNonInteractiveShellExports(t *testing.T) {
|
||||
out := PrependNonInteractiveShellExports("echo hi")
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") || !strings.Contains(out, "PAGER=cat") {
|
||||
t.Fatalf("missing pager exports: %q", out)
|
||||
}
|
||||
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||
t.Fatalf("command suffix lost: %q", out)
|
||||
}
|
||||
skip := PrependNonInteractiveShellExports("GIT_PAGER=less echo hi")
|
||||
if strings.Contains(skip, "export GIT_PAGER=cat") {
|
||||
t.Fatalf("should not override existing GIT_PAGER: %q", skip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrependNonInteractiveStdinRedirect(t *testing.T) {
|
||||
out := PrependNonInteractiveStdinRedirect("echo hi")
|
||||
if !strings.HasPrefix(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||
t.Fatalf("command suffix lost: %q", out)
|
||||
}
|
||||
skip := PrependNonInteractiveStdinRedirect("cmd </dev/null")
|
||||
if strings.HasPrefix(skip, "exec </dev/null") {
|
||||
t.Fatalf("should not double redirect: %q", skip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareNonInteractiveShellCommand(t *testing.T) {
|
||||
out := PrepareNonInteractiveShellCommand("echo hi")
|
||||
if !strings.Contains(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") {
|
||||
t.Fatalf("missing pager export: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewShellInactivityWatch(t *testing.T) {
|
||||
w := NewShellInactivityWatch(1)
|
||||
if w == nil {
|
||||
t.Fatal("expected watch")
|
||||
}
|
||||
w.Bump()
|
||||
select {
|
||||
case <-w.Expired:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("expected inactivity fire within 3s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveShellNoOutputTimeoutSeconds(t *testing.T) {
|
||||
if ResolveShellNoOutputTimeoutSeconds(0) != 300 {
|
||||
t.Fatal("zero should default to 300")
|
||||
}
|
||||
if ResolveShellNoOutputTimeoutSeconds(-1) != 0 {
|
||||
t.Fatal("-1 should disable")
|
||||
}
|
||||
if ResolveShellNoOutputTimeoutSeconds(30) != 30 {
|
||||
t.Fatal("explicit value")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonInteractiveStdinReadExitsQuickly 验证 exec </dev/null + attachNonInteractiveStdin 时 read 立即 EOF,不挂起。
|
||||
func TestNonInteractiveStdinReadExitsQuickly(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", PrepareNonInteractiveShellCommand(`read x; echo "x=<$x>"`))
|
||||
attachNonInteractiveStdin(cmd)
|
||||
|
||||
start := time.Now()
|
||||
out, err := cmd.CombinedOutput()
|
||||
elapsed := time.Since(start)
|
||||
if elapsed > 2*time.Second {
|
||||
t.Fatalf("read with closed stdin took %v, want <2s", elapsed)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v output=%q", err, out)
|
||||
}
|
||||
if !strings.Contains(string(out), "x=<>") {
|
||||
t.Fatalf("unexpected output: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonInteractiveStdinReadBlocksWithoutRedirect 对照:stdin 为永不写入的管道时 read 会挂起。
|
||||
func TestNonInteractiveStdinReadBlocksWithoutRedirect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
// 保持 w 打开且不写数据,模拟「等待用户输入」
|
||||
|
||||
cmd := exec.Command("sh", "-c", `read x; echo done`)
|
||||
cmd.Stdin = r
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- cmd.Run() }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
t.Fatalf("expected hang, but command finished: %v", err)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = w.Close()
|
||||
<-done // 等待 goroutine 退出
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package security
|
||||
|
||||
import "os/exec"
|
||||
|
||||
// ShellSession records the root shell's process-group ID at start time, so the
// whole group can be killed on cancellation or timeout even when cmd.Process is
// no longer usable.
|
||||
type ShellSession struct {
|
||||
Cmd *exec.Cmd // the started shell command
|
||||
rootPID int // PID captured right after Start; 0 when no process was available
|
||||
}
|
||||
|
||||
// StartShellSession 配置独立进程组并启动 shell,缓存 rootPID(Unix 下即 PGID)。
|
||||
func StartShellSession(cmd *exec.Cmd) (*ShellSession, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pid := 0
|
||||
if cmd.Process != nil {
|
||||
pid = cmd.Process.Pid
|
||||
}
|
||||
return &ShellSession{Cmd: cmd, rootPID: pid}, nil
|
||||
}
|
||||
|
||||
// Wait 等待 shell 退出。
|
||||
func (s *ShellSession) Wait() error {
|
||||
if s == nil || s.Cmd == nil {
|
||||
return nil
|
||||
}
|
||||
return s.Cmd.Wait()
|
||||
}
|
||||
|
||||
// Terminate 终止 shell 及其进程组。
|
||||
func (s *ShellSession) Terminate() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
terminateProcessGroup(s.rootPID, s.Cmd)
|
||||
}
|
||||
|
||||
// TerminateShellSession 终止由 StartShellSession 启动的会话。
|
||||
func TerminateShellSession(session *ShellSession) {
|
||||
if session != nil {
|
||||
session.Terminate()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestShellSession_TerminateUsesCachedRootPID(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix process group kill")
|
||||
}
|
||||
|
||||
cmd := exec.Command("sh", "-c", "sleep 300")
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("StartShellSession: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
session.Terminate()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- session.Wait() }()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("session did not finish within 5s after Terminate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellSession_TerminateAfterContextCancel(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix process group kill")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", "sleep 300")
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("StartShellSession: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
TerminateShellCmdSession(session)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- session.Wait() }()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("session did not finish within 5s after cancel+terminate")
|
||||
}
|
||||
}
|
||||
+63
-28
@@ -1,60 +1,95 @@
|
||||
name: "hydra"
|
||||
command: "hydra"
|
||||
args: ["-I"]
|
||||
enabled: true
|
||||
short_description: "密码暴力破解工具,支持多种协议和服务"
|
||||
description: |
|
||||
Hydra是一个快速的网络登录破解工具,支持多种协议和服务的密码暴力破解。
|
||||
Hydra 是网络登录口令爆破工具,支持 SSH、FTP、HTTP、SMB 等多种协议。
|
||||
|
||||
**主要功能:**
|
||||
- 支持多种协议(SSH, FTP, HTTP, SMB等)
|
||||
- 快速并行破解
|
||||
- 支持用户名和密码字典
|
||||
- 可恢复的会话
|
||||
**调用约定(必读):**
|
||||
- 必须提供 **用户名**:`username`(-l)或 `username_file`(-L)至少其一
|
||||
- 必须提供 **口令**:`password`(-p)、`password_file`(-P)或 `-C`(经 `additional_args`)至少其一
|
||||
- **先用小字典试跑**(几十~几百条),确认目标可达再扩大;禁止默认使用 rockyou 等超大字典
|
||||
- 默认已启用:找到即停(-f)、并行 4(-t)、忽略 restore(-I);长任务请设 `output_file`
|
||||
|
||||
**使用场景:**
|
||||
- 密码强度测试
|
||||
- 渗透测试
|
||||
- 安全评估
|
||||
- 弱密码检测
|
||||
**CLI 顺序:** `hydra [选项] <target> <service>`(本工具已按此顺序组参,勿把 target 写在选项前)
|
||||
|
||||
**使用场景:** 授权环境下的弱口令检测、密码强度评估
|
||||
|
||||
**注意:** 仅用于已授权目标;对无响应目标请减小 `wait_time` 或缩小字典,避免长时间挂起。
|
||||
parameters:
|
||||
- name: "target"
|
||||
type: "string"
|
||||
description: "目标IP或主机名"
|
||||
required: true
|
||||
position: 0
|
||||
format: "positional"
|
||||
- name: "service"
|
||||
type: "string"
|
||||
description: "服务类型(ssh, ftp, http等)"
|
||||
required: true
|
||||
position: 1
|
||||
format: "positional"
|
||||
- name: "username"
|
||||
type: "string"
|
||||
description: "单个用户名"
|
||||
description: "单个用户名(-l);与 username_file 二选一至少填一个"
|
||||
required: false
|
||||
flag: "-l"
|
||||
format: "flag"
|
||||
- name: "username_file"
|
||||
type: "string"
|
||||
description: "用户名字典文件"
|
||||
description: "用户名字典文件(-L)"
|
||||
required: false
|
||||
flag: "-L"
|
||||
format: "flag"
|
||||
- name: "password"
|
||||
type: "string"
|
||||
description: "单个密码"
|
||||
description: "单个密码(-p)"
|
||||
required: false
|
||||
flag: "-p"
|
||||
format: "flag"
|
||||
- name: "password_file"
|
||||
type: "string"
|
||||
description: "密码字典文件"
|
||||
description: "密码字典文件(-P);优先使用小字典试跑"
|
||||
required: false
|
||||
flag: "-P"
|
||||
format: "flag"
|
||||
- name: "stop_on_first"
|
||||
type: "bool"
|
||||
description: "找到一对有效账密后立即退出(-f,默认 true)"
|
||||
required: false
|
||||
flag: "-f"
|
||||
format: "flag"
|
||||
default: true
|
||||
- name: "tasks"
|
||||
type: "int"
|
||||
description: "每目标并行连接数(-t);SSH 等建议 4,默认 4"
|
||||
required: false
|
||||
flag: "-t"
|
||||
format: "flag"
|
||||
default: 4
|
||||
- name: "wait_time"
|
||||
type: "int"
|
||||
description: "单次连接等待响应秒数(-w),默认 16(低于 Hydra 默认 32,减少挂起感)"
|
||||
required: false
|
||||
flag: "-w"
|
||||
format: "flag"
|
||||
default: 16
|
||||
- name: "wait_between"
|
||||
type: "int"
|
||||
description: "每线程连接间隔秒数(-W),默认 1"
|
||||
required: false
|
||||
flag: "-W"
|
||||
format: "flag"
|
||||
default: 1
|
||||
- name: "output_file"
|
||||
type: "string"
|
||||
description: "将结果写入文件(-o),长任务建议指定"
|
||||
required: false
|
||||
flag: "-o"
|
||||
format: "flag"
|
||||
- name: "target"
|
||||
type: "string"
|
||||
description: "目标 IP、主机名或 CIDR(须在选项之后)"
|
||||
required: true
|
||||
position: 1
|
||||
format: "positional"
|
||||
- name: "service"
|
||||
type: "string"
|
||||
description: "服务类型(ssh、ftp、http-get、http-post-form、smb 等,见 hydra -h)"
|
||||
required: true
|
||||
position: 2
|
||||
format: "positional"
|
||||
- name: "additional_args"
|
||||
type: "string"
|
||||
description: "额外的Hydra参数"
|
||||
description: "额外参数(如 -s 端口、-S SSL、-m 模块选项、-C login:pass 文件),追加在命令末尾"
|
||||
required: false
|
||||
format: "positional"
|
||||
|
||||
+627
-15
@@ -1615,9 +1615,34 @@ header {
|
||||
|
||||
.conversation-search-box {
|
||||
position: relative;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.conversation-sidebar .sidebar-content {
|
||||
padding: 10px 16px 16px;
|
||||
}
|
||||
|
||||
.conversation-sidebar .conversation-search-box {
|
||||
margin-top: 8px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.conversation-sidebar .conversation-project-filter {
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.conversation-sidebar .conversation-groups-section {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.conversation-sidebar .recent-conversations-section {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.conversation-sidebar .section-header {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.conversation-search-box input {
|
||||
width: 100%;
|
||||
padding: 8px 32px 8px 12px;
|
||||
@@ -1668,6 +1693,170 @@ header {
|
||||
height: 14px;
|
||||
}
|
||||
|
||||
.conversation-project-filter {
|
||||
margin-bottom: 12px;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.conversation-project-filter-label {
|
||||
display: block;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 500;
|
||||
color: var(--text-muted);
|
||||
margin-bottom: 4px;
|
||||
padding: 0 2px;
|
||||
}
|
||||
|
||||
.conversation-project-filter-native {
|
||||
position: absolute;
|
||||
width: 1px;
|
||||
height: 1px;
|
||||
padding: 0;
|
||||
margin: -1px;
|
||||
overflow: hidden;
|
||||
clip: rect(0, 0, 0, 0);
|
||||
white-space: nowrap;
|
||||
border: 0;
|
||||
}
|
||||
|
||||
.conversation-project-filter-ui {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.conversation-project-filter-trigger {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
padding: 8px 10px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 6px;
|
||||
background: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.25;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
.conversation-project-filter-trigger:hover:not(:disabled) {
|
||||
border-color: var(--accent-color);
|
||||
}
|
||||
|
||||
.conversation-project-filter-ui.open .conversation-project-filter-trigger {
|
||||
outline: none;
|
||||
border-color: var(--accent-color);
|
||||
box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.1);
|
||||
}
|
||||
|
||||
.conversation-project-filter-ui.open {
|
||||
z-index: 120;
|
||||
}
|
||||
|
||||
.conversation-project-filter-value {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.conversation-project-filter-caret {
|
||||
flex-shrink: 0;
|
||||
color: var(--text-secondary);
|
||||
transition: transform 0.15s ease;
|
||||
}
|
||||
|
||||
.conversation-project-filter-ui.open .conversation-project-filter-caret {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
|
||||
.conversation-project-filter-dropdown {
|
||||
display: none;
|
||||
position: absolute;
|
||||
top: calc(100% + 4px);
|
||||
left: 0;
|
||||
right: 0;
|
||||
z-index: 200;
|
||||
max-height: 280px;
|
||||
overflow-x: hidden;
|
||||
overflow-y: auto;
|
||||
padding: 4px;
|
||||
background: var(--bg-primary);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 8px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
}
|
||||
|
||||
.conversation-project-filter-ui.open .conversation-project-filter-dropdown {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.conversation-project-filter-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
padding: 8px 10px;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
background: transparent;
|
||||
color: var(--text-primary);
|
||||
font-size: 0.8125rem;
|
||||
font-family: inherit;
|
||||
cursor: pointer;
|
||||
text-align: left;
|
||||
transition: background 0.12s ease, color 0.12s ease;
|
||||
}
|
||||
|
||||
.conversation-project-filter-option:hover {
|
||||
background: var(--bg-secondary);
|
||||
}
|
||||
|
||||
.conversation-project-filter-option.is-selected {
|
||||
background: rgba(0, 102, 255, 0.08);
|
||||
color: var(--accent-color);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.conversation-project-filter-check {
|
||||
width: 14px;
|
||||
flex-shrink: 0;
|
||||
opacity: 0;
|
||||
font-size: 0.75rem;
|
||||
line-height: 1;
|
||||
color: var(--accent-color);
|
||||
}
|
||||
|
||||
.conversation-project-filter-option.is-selected .conversation-project-filter-check {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.conversation-project-filter-option-label {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.conversation-item-project-badge {
|
||||
font-size: 0.6875rem;
|
||||
color: var(--text-muted);
|
||||
margin-top: 2px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
line-height: 1.3;
|
||||
}
|
||||
|
||||
.conversations-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
@@ -2456,16 +2645,68 @@ header {
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.mcp-call-buttons {
|
||||
.mcp-call-buttons,
|
||||
.mcp-call-toolbar {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.mcp-tool-list {
|
||||
display: none;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
margin-top: 8px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.mcp-tool-list.expanded {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.mcp-tools-toggle-btn {
|
||||
background: rgba(25, 118, 210, 0.1) !important;
|
||||
border-color: rgba(25, 118, 210, 0.35) !important;
|
||||
color: #1976d2 !important;
|
||||
}
|
||||
|
||||
.mcp-call-toolbar .process-detail-btn,
|
||||
.mcp-call-toolbar .mcp-tools-toggle-btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 6px;
|
||||
min-height: 32px;
|
||||
padding: 6px 12px;
|
||||
font-size: 0.8125rem;
|
||||
font-weight: 500;
|
||||
line-height: 1.25;
|
||||
box-sizing: border-box;
|
||||
white-space: nowrap;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.mcp-call-toolbar .process-detail-btn span,
|
||||
.mcp-call-toolbar .mcp-tools-toggle-btn span {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
line-height: 1.25;
|
||||
}
|
||||
|
||||
.mcp-tools-toggle-btn:hover {
|
||||
background: rgba(25, 118, 210, 0.18) !important;
|
||||
border-color: #1976d2 !important;
|
||||
color: #1565c0 !important;
|
||||
}
|
||||
|
||||
.process-detail-btn {
|
||||
background: rgba(156, 39, 176, 0.1) !important;
|
||||
border-color: rgba(156, 39, 176, 0.3) !important;
|
||||
color: #9c27b0 !important;
|
||||
}
|
||||
|
||||
.mcp-call-toolbar .process-detail-btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
@@ -11144,6 +11385,7 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
.conversation-groups-section,
|
||||
.recent-conversations-section {
|
||||
margin-bottom: 24px;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.conversation-groups-section:last-child,
|
||||
@@ -11157,6 +11399,8 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
justify-content: space-between;
|
||||
margin-bottom: 12px;
|
||||
padding: 0 8px;
|
||||
min-width: 0;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.section-header-actions {
|
||||
@@ -11285,6 +11529,21 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.recent-conversations-section .section-title {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.recent-conversations-section .section-title.section-title--filtered {
|
||||
text-transform: none;
|
||||
letter-spacing: normal;
|
||||
font-size: 0.875rem;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.add-group-btn,
|
||||
.batch-manage-btn {
|
||||
width: 24px;
|
||||
@@ -11677,7 +11936,7 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
|
||||
/* 批量管理模态框 */
|
||||
.batch-manage-modal-content {
|
||||
max-width: 800px;
|
||||
max-width: 920px;
|
||||
width: 90vw;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
@@ -11687,7 +11946,23 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
.batch-manage-header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
gap: 10px;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.batch-manage-header-actions .conversation-project-filter-ui {
|
||||
width: 148px;
|
||||
min-width: 108px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.batch-manage-header-actions .conversation-project-filter-trigger {
|
||||
font-size: 0.8125rem;
|
||||
padding: 8px 10px;
|
||||
}
|
||||
|
||||
.batch-manage-modal-content .conversation-project-filter-ui.open {
|
||||
z-index: 400;
|
||||
}
|
||||
|
||||
.batch-search-box {
|
||||
@@ -11731,8 +12006,8 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
|
||||
.batch-table-header {
|
||||
display: grid;
|
||||
grid-template-columns: 40px 1fr 180px 80px;
|
||||
gap: 16px;
|
||||
grid-template-columns: 40px minmax(0, 1.2fr) minmax(0, 0.9fr) 160px 72px;
|
||||
gap: 12px;
|
||||
padding: 12px 16px;
|
||||
background: var(--bg-secondary);
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
@@ -11750,8 +12025,8 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
|
||||
.batch-conversation-row {
|
||||
display: grid;
|
||||
grid-template-columns: 40px 1fr 180px 80px;
|
||||
gap: 16px;
|
||||
grid-template-columns: 40px minmax(0, 1.2fr) minmax(0, 0.9fr) 160px 72px;
|
||||
gap: 12px;
|
||||
padding: 12px 16px;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
align-items: center;
|
||||
@@ -11778,6 +12053,20 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
/* 完全依赖JavaScript截断,禁用CSS的ellipsis以避免在UTF-8多字节字符中间截断 */
|
||||
}
|
||||
|
||||
.batch-table-col-project {
|
||||
font-size: 0.8125rem;
|
||||
color: var(--text-muted);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.batch-table-col-project.is-unbound {
|
||||
color: var(--text-muted);
|
||||
font-style: italic;
|
||||
opacity: 0.85;
|
||||
}
|
||||
|
||||
.batch-table-col-time {
|
||||
font-size: 0.875rem;
|
||||
color: var(--text-muted);
|
||||
@@ -11797,34 +12086,44 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
background: transparent;
|
||||
color: var(--text-muted);
|
||||
cursor: pointer;
|
||||
border-radius: 4px;
|
||||
border-radius: 6px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.batch-delete-btn svg {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.batch-delete-btn:hover {
|
||||
background: rgba(220, 53, 69, 0.1);
|
||||
color: var(--error-color);
|
||||
}
|
||||
|
||||
.batch-delete-btn:hover svg {
|
||||
transform: scale(1.08);
|
||||
}
|
||||
|
||||
.batch-delete-btn:active {
|
||||
background: rgba(220, 53, 69, 0.2);
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.batch-manage-footer {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
justify-content: flex-end;
|
||||
padding: 16px 24px;
|
||||
border-top: 1px solid var(--border-color);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.select-all-checkbox {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
.batch-table-col-checkbox input[type="checkbox"] {
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.batch-footer-actions {
|
||||
@@ -19682,6 +19981,158 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.12);
|
||||
}
|
||||
|
||||
.vuln-filter-native-select {
|
||||
position: absolute;
|
||||
width: 1px;
|
||||
height: 1px;
|
||||
padding: 0;
|
||||
margin: -1px;
|
||||
overflow: hidden;
|
||||
clip: rect(0, 0, 0, 0);
|
||||
white-space: nowrap;
|
||||
border: 0;
|
||||
}
|
||||
|
||||
.vulnerability-filter-field--project .vuln-filter-select,
|
||||
.vulnerability-filter-field--status .vuln-filter-select {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.vulnerability-filter-field--project .vuln-filter-select {
|
||||
min-width: 132px;
|
||||
max-width: 180px;
|
||||
}
|
||||
|
||||
.vulnerability-filter-field--status .vuln-filter-select {
|
||||
min-width: 112px;
|
||||
}
|
||||
|
||||
.vuln-filter-select-trigger {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
padding: 8px 10px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 8px;
|
||||
background: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.25;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
transition: border-color 0.15s ease, box-shadow 0.15s ease;
|
||||
}
|
||||
|
||||
.vuln-filter-select-trigger:hover:not(:disabled) {
|
||||
border-color: rgba(59, 130, 246, 0.45);
|
||||
}
|
||||
|
||||
.vuln-filter-select.open .vuln-filter-select-trigger {
|
||||
border-color: #3b82f6;
|
||||
box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.12);
|
||||
}
|
||||
|
||||
.vuln-filter-select.open {
|
||||
z-index: 120;
|
||||
}
|
||||
|
||||
.vuln-filter-select-value {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.vuln-filter-select-caret {
|
||||
flex-shrink: 0;
|
||||
color: var(--text-secondary);
|
||||
transition: transform 0.15s ease;
|
||||
}
|
||||
|
||||
.vuln-filter-select.open .vuln-filter-select-caret {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
|
||||
.vuln-filter-select-dropdown {
|
||||
display: none;
|
||||
position: absolute;
|
||||
top: calc(100% + 4px);
|
||||
left: 0;
|
||||
right: 0;
|
||||
z-index: 200;
|
||||
max-height: 280px;
|
||||
overflow-y: auto;
|
||||
padding: 4px;
|
||||
background: var(--bg-primary);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 8px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
}
|
||||
|
||||
.vuln-filter-select.open .vuln-filter-select-dropdown {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.vuln-filter-select-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
padding: 8px 10px;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
background: transparent;
|
||||
color: var(--text-primary);
|
||||
font-size: 0.8125rem;
|
||||
font-family: inherit;
|
||||
cursor: pointer;
|
||||
text-align: left;
|
||||
transition: background 0.12s ease, color 0.12s ease;
|
||||
}
|
||||
|
||||
.vuln-filter-select-option:hover {
|
||||
background: var(--bg-secondary);
|
||||
}
|
||||
|
||||
.vuln-filter-select-option.is-selected {
|
||||
background: rgba(59, 130, 246, 0.08);
|
||||
color: #2563eb;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.vuln-filter-select-check {
|
||||
width: 14px;
|
||||
flex-shrink: 0;
|
||||
opacity: 0;
|
||||
font-size: 0.75rem;
|
||||
line-height: 1;
|
||||
color: #2563eb;
|
||||
}
|
||||
|
||||
.vuln-filter-select-option.is-selected .vuln-filter-select-check {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.vuln-filter-select-label {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.vuln-filter-select.is-disabled .vuln-filter-select-trigger {
|
||||
opacity: 0.55;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.vulnerability-filter-clear-btn[hidden] {
|
||||
display: none !important;
|
||||
}
|
||||
@@ -19982,6 +20433,167 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
color: #868e96;
|
||||
}
|
||||
|
||||
.vuln-status-picker {
|
||||
position: relative;
|
||||
display: inline-flex;
|
||||
vertical-align: middle;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.vuln-status-picker.open {
|
||||
z-index: 120;
|
||||
}
|
||||
|
||||
.vuln-status-picker-trigger {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
padding: 4px 8px 4px 10px;
|
||||
border-radius: 12px;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
line-height: 1.3;
|
||||
border: 1px solid transparent;
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
max-width: 148px;
|
||||
transition: opacity 0.15s ease, border-color 0.15s ease, box-shadow 0.15s ease, background 0.15s ease;
|
||||
background: transparent;
|
||||
color: inherit;
|
||||
}
|
||||
|
||||
.vuln-status-picker-value {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.vuln-status-picker-caret {
|
||||
flex-shrink: 0;
|
||||
opacity: 0.8;
|
||||
transition: transform 0.15s ease;
|
||||
}
|
||||
|
||||
.vuln-status-picker.open .vuln-status-picker-caret {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
|
||||
.vuln-status-picker.open .vuln-status-picker-trigger {
|
||||
box-shadow: 0 0 0 2px rgba(0, 102, 255, 0.12);
|
||||
}
|
||||
|
||||
.vuln-status-picker.status-open .vuln-status-picker-trigger {
|
||||
background: rgba(0, 102, 255, 0.1);
|
||||
color: #0066ff;
|
||||
border-color: rgba(0, 102, 255, 0.22);
|
||||
}
|
||||
|
||||
.vuln-status-picker.status-confirmed .vuln-status-picker-trigger {
|
||||
background: rgba(40, 167, 69, 0.1);
|
||||
color: #28a745;
|
||||
border-color: rgba(40, 167, 69, 0.22);
|
||||
}
|
||||
|
||||
.vuln-status-picker.status-fixed .vuln-status-picker-trigger {
|
||||
background: rgba(108, 117, 125, 0.1);
|
||||
color: #6c757d;
|
||||
border-color: rgba(108, 117, 125, 0.22);
|
||||
}
|
||||
|
||||
.vuln-status-picker.status-false_positive .vuln-status-picker-trigger {
|
||||
background: rgba(220, 53, 69, 0.1);
|
||||
color: #dc3545;
|
||||
border-color: rgba(220, 53, 69, 0.22);
|
||||
}
|
||||
|
||||
.vuln-status-picker.status-ignored .vuln-status-picker-trigger {
|
||||
background: rgba(108, 117, 125, 0.12);
|
||||
color: #868e96;
|
||||
border-color: rgba(108, 117, 125, 0.22);
|
||||
}
|
||||
|
||||
.vuln-status-picker-trigger:hover:not(:disabled) {
|
||||
filter: brightness(0.97);
|
||||
}
|
||||
|
||||
.vuln-status-picker.is-disabled .vuln-status-picker-trigger {
|
||||
opacity: 0.65;
|
||||
cursor: wait;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.vuln-status-picker-menu {
|
||||
position: absolute;
|
||||
top: calc(100% + 6px);
|
||||
left: 0;
|
||||
min-width: 136px;
|
||||
z-index: 200;
|
||||
background: var(--bg-primary);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 8px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
padding: 4px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2px;
|
||||
}
|
||||
|
||||
.vuln-status-picker-menu[hidden] {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
.vuln-status-picker-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
width: 100%;
|
||||
padding: 8px 10px;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
background: transparent;
|
||||
color: var(--text-primary);
|
||||
font-size: 0.8125rem;
|
||||
font-family: inherit;
|
||||
cursor: pointer;
|
||||
text-align: left;
|
||||
transition: background 0.12s ease, color 0.12s ease;
|
||||
}
|
||||
|
||||
.vuln-status-picker-option:hover {
|
||||
background: var(--bg-secondary);
|
||||
}
|
||||
|
||||
.vuln-status-picker-option.is-selected {
|
||||
background: rgba(0, 102, 255, 0.08);
|
||||
color: var(--accent-color);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.vuln-status-picker-check {
|
||||
width: 14px;
|
||||
flex-shrink: 0;
|
||||
opacity: 0;
|
||||
font-size: 0.75rem;
|
||||
line-height: 1;
|
||||
color: var(--accent-color);
|
||||
}
|
||||
|
||||
.vuln-status-picker-option.is-selected .vuln-status-picker-check {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.vuln-status-picker-label {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.vulnerability-card--removing {
|
||||
opacity: 0;
|
||||
transform: scale(0.98);
|
||||
pointer-events: none;
|
||||
transition: opacity 0.18s ease, transform 0.18s ease;
|
||||
}
|
||||
|
||||
.vulnerability-date {
|
||||
font-size: 0.75rem;
|
||||
color: var(--text-muted);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user