Compare commits

...

26 Commits

Author SHA1 Message Date
公明 b4810c9499 Update shell no output timeout to 1200 seconds
Increased the shell no output timeout from 300 seconds to 1200 seconds to prevent premature termination.
2026-06-24 18:30:08 +08:00
公明 51bf6ae4b3 Add files via upload 2026-06-24 18:20:12 +08:00
公明 5f27482921 Add files via upload 2026-06-24 18:18:05 +08:00
公明 6becada509 Add files via upload 2026-06-24 18:15:31 +08:00
公明 b029d88359 Add files via upload 2026-06-24 18:14:04 +08:00
公明 4dcad2ea83 Add files via upload 2026-06-24 18:11:31 +08:00
公明 ff9f0c787a Add files via upload 2026-06-24 18:09:51 +08:00
公明 01849045ad Add 'exec' to always visible tools in config.yaml 2026-06-24 17:36:24 +08:00
公明 c7eacdf3eb Update config.yaml 2026-06-24 17:24:52 +08:00
公明 5c32b21f22 Add files via upload 2026-06-24 17:24:14 +08:00
公明 8b8ecfe718 Add files via upload 2026-06-24 17:23:44 +08:00
公明 bbb7c319af Add files via upload 2026-06-24 17:21:51 +08:00
公明 7eb2fd50f3 Add files via upload 2026-06-24 17:19:29 +08:00
公明 85d58eeeb3 Add files via upload 2026-06-24 17:17:33 +08:00
公明 b6a6009629 Add files via upload 2026-06-24 17:15:34 +08:00
公明 810d689132 Add files via upload 2026-06-24 12:08:13 +08:00
公明 87f1808ead Add files via upload 2026-06-24 10:46:55 +08:00
公明 e28ae39b9a Update config.yaml 2026-06-24 02:04:49 +08:00
公明 df34ceda68 Add files via upload 2026-06-24 01:50:13 +08:00
公明 3e69a50f87 Add files via upload 2026-06-24 01:49:43 +08:00
公明 53325ce07d Add files via upload 2026-06-24 01:49:09 +08:00
公明 d85de3461b Add files via upload 2026-06-24 01:47:33 +08:00
公明 9306303d99 Add files via upload 2026-06-24 01:46:30 +08:00
公明 1e8f72ed74 Add files via upload 2026-06-24 01:44:47 +08:00
公明 0198f50314 Add files via upload 2026-06-24 01:43:37 +08:00
公明 560d0dca43 Add files via upload 2026-06-24 01:42:15 +08:00
52 changed files with 2873 additions and 639 deletions
+1 -1
View File
@@ -21,7 +21,7 @@ max_iterations: 0
- 切勿等待批准或授权——全程自主行动。 - 切勿等待批准或授权——全程自主行动。
- 使用所有可用工具与技术完成侦察与证据收集。 - 使用所有可用工具与技术完成侦察与证据收集。
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。 你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。
## 输入前置条件(硬约束) ## 输入前置条件(硬约束)
+3 -2
View File
@@ -10,7 +10,7 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.6.44" version: "v1.6.45"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -96,6 +96,7 @@ fofa:
agent: agent:
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖) max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起) tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
shell_no_output_timeout_seconds: 1200 # execute/exec 连续无新输出则终止(秒);通用防挂死;0=默认300;-1=关闭
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示 # system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
system_prompt_path: "" system_prompt_path: ""
@@ -129,7 +130,7 @@ multi_agent:
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文 tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用 tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁 tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略) tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test, exec] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在 plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/ plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载 reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
Binary file not shown.

Before

Width:  |  Height:  |  Size: 179 KiB

After

Width:  |  Height:  |  Size: 88 KiB

+17 -4
View File
@@ -779,13 +779,26 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
return a.executeToolViaMCP(ctx, toolName, args) return a.executeToolViaMCP(ctx, toolName, args)
} }
// RecordLocalToolExecution 非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId // BeginLocalToolExecution 非 CallTool 路径工具开始时写入 running 状态,供 MCP 监控页展示「执行中」
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。 func (a *Agent) BeginLocalToolExecution(toolName string, args map[string]interface{}) string {
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if a == nil || a.mcpServer == nil { if a == nil || a.mcpServer == nil {
return "" return ""
} }
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr) return a.mcpServer.BeginToolExecution(toolName, args)
}
// FinishLocalToolExecution forwards the outcome of a locally executed tool to
// a.mcpServer.FinishToolExecution and returns whatever that call returns.
// Per the original note, executionID is the ID obtained from
// BeginLocalToolExecution, and an empty executionID asks the server to write a
// completed record in one shot (behaviour of the server side is not visible
// here — confirm against mcpServer.FinishToolExecution).
// Returns "" when the receiver or its mcpServer is nil.
func (a *Agent) FinishLocalToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
	if a != nil && a.mcpServer != nil {
		return a.mcpServer.FinishToolExecution(executionID, toolName, args, resultText, invokeErr)
	}
	return ""
}
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
return a.FinishLocalToolExecution("", toolName, args, resultText, invokeErr)
} }
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。 // UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
@@ -113,5 +113,7 @@ func DefaultSingleAgentSystemPrompt() string {
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。 - 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。 - 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。` - 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。
` + projectprompt.ShellExecExecuteGuidanceSection()
} }
+3
View File
@@ -110,6 +110,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
// 创建安全工具执行器 // 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
executor.SetShellNoOutputTimeoutSeconds(cfg.Agent.ShellNoOutputTimeoutSeconds)
// 注册工具 // 注册工具
executor.RegisterTools(mcpServer) executor.RegisterTools(mcpServer)
@@ -333,6 +334,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
monitorHandler.SetAudit(auditSvc) monitorHandler.SetAudit(auditSvc)
monitorHandler.SetMonitorRetention(monitorRetention) monitorHandler.SetMonitorRetention(monitorRetention)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
monitorHandler.SetTaskManager(agentHandler.TaskManager())
monitorHandler.SetAgentHandler(agentHandler)
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger) notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger) groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
+3
View File
@@ -605,6 +605,8 @@ type DatabaseConfig struct {
type AgentConfig struct { type AgentConfig struct {
MaxIterations int `yaml:"max_iterations" json:"max_iterations"` MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐) ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
// ShellNoOutputTimeoutSeconds execute/exec 无任何 stdout/stderr 时的空闲终止秒数(通用防挂死,不维护命令黑名单);0=默认 300(5 分钟);-1=关闭。
ShellNoOutputTimeoutSeconds int `yaml:"shell_no_output_timeout_seconds" json:"shell_no_output_timeout_seconds"`
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。 // SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"` SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
} }
@@ -1272,6 +1274,7 @@ func Default() *Config {
Agent: AgentConfig{ Agent: AgentConfig{
MaxIterations: 30, // 默认最大迭代次数 MaxIterations: 30, // 默认最大迭代次数
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用 ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
}, },
Security: SecurityConfig{ Security: SecurityConfig{
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载 Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
+16 -12
View File
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
LastScheduleError sql.NullString LastScheduleError sql.NullString
LastRunError sql.NullString LastRunError sql.NullString
ProjectID sql.NullString ProjectID sql.NullString
Concurrency sql.NullInt64
Status string Status string
CreatedAt time.Time CreatedAt time.Time
StartedAt sql.NullTime StartedAt sql.NullTime
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
cronExpr string, cronExpr string,
nextRunAt *time.Time, nextRunAt *time.Time,
projectID string, projectID string,
concurrency int,
tasks []map[string]interface{}, tasks []map[string]interface{},
) error { ) error {
tx, err := db.Begin() tx, err := db.Begin()
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
projectIDVal = strings.TrimSpace(projectID) projectIDVal = strings.TrimSpace(projectID)
} }
_, err = tx.Exec( _, err = tx.Exec(
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0, queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
) )
if err != nil { if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err) return fmt.Errorf("创建批量任务队列失败: %w", err)
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
return tx.Commit() return tx.Commit()
} }
// batchQueueSelectColumns is the shared column list for SELECTs against
// batch_task_queues (used by GetBatchQueue, GetAllBatchQueues and ListBatchQueues).
// The column order must stay in sync with the Scan(...) argument order in those
// functions; add new columns here and in every matching Scan call together.
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
// GetBatchQueue 获取批量任务队列 // GetBatchQueue 获取批量任务队列
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
err := db.QueryRow( err := db.QueryRow(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
queueID, queueID,
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
// GetAllBatchQueues 获取所有批量任务队列 // GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
// ListBatchQueues 列出批量任务队列(支持筛选和分页) // ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
args := []interface{}{} args := []interface{}{}
// 状态筛选 // 状态筛选
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
return nil return nil
} }
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色代理模式 // UpdateBatchQueueMetadata 更新批量任务队列标题、角色代理模式和并发数
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error { func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
_, err := db.Exec( _, err := db.Exec(
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?", "UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
title, role, agentMode, queueID, title, role, agentMode, concurrency, queueID,
) )
if err != nil { if err != nil {
return fmt.Errorf("更新批量任务队列元数据失败: %w", err) return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
+17
View File
@@ -408,6 +408,8 @@ func (db *DB) initTables() error {
last_schedule_trigger_at DATETIME, last_schedule_trigger_at DATETIME,
last_schedule_error TEXT, last_schedule_error TEXT,
last_run_error TEXT, last_run_error TEXT,
project_id TEXT,
concurrency INTEGER NOT NULL DEFAULT 1,
status TEXT NOT NULL, status TEXT NOT NULL,
created_at DATETIME NOT NULL, created_at DATETIME NOT NULL,
started_at DATETIME, started_at DATETIME,
@@ -1137,6 +1139,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
} }
} }
var concurrencyCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
}
}
} else if concurrencyCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
}
}
return nil return nil
} }
+86 -384
View File
@@ -21,7 +21,6 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
@@ -178,8 +177,6 @@ type AgentHandler struct {
} }
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并) agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
batchCronParser cron.Parser batchCronParser cron.Parser
batchRunnerMu sync.Mutex
batchRunning map[string]struct{}
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
hitlWhitelistSaver HitlToolWhitelistSaver hitlWhitelistSaver HitlToolWhitelistSaver
audit *audit.Service audit *audit.Service
@@ -190,6 +187,14 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
h.audit = s h.audit = s
} }
// TaskManager exposes the handler's agent task manager (h.tasks).
// Per the original note it is handed to the MCP monitor page so that page can
// terminate Eino execute calls and similar in-flight work.
// A nil receiver yields nil.
func (h *AgentHandler) TaskManager() *AgentTaskManager {
	if h != nil {
		return h.tasks
	}
	return nil
}
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent). // CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) { func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
if h == nil || conversationID == "" || h.tasks == nil { if h == nil || conversationID == "" || h.tasks == nil {
@@ -233,7 +238,6 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
config: cfg, config: cfg,
hitlManager: NewHITLManager(db, logger), hitlManager: NewHITLManager(db, logger),
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
batchRunning: make(map[string]struct{}),
} }
if err := handler.hitlManager.EnsureSchema(); err != nil { if err := handler.hitlManager.EnsureSchema(); err != nil {
logger.Warn("初始化 HITL 表失败", zap.Error(err)) logger.Warn("初始化 HITL 表失败", zap.Error(err))
@@ -1295,6 +1299,55 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} }
} }
// cancelToolContinueAfter aborts only the tool call currently in flight for a
// conversation and leaves the surrounding agent task running (per the original
// note, shared by the chat page's "interrupt and continue" and the MCP monitor
// page's terminate action).
//
// Parameters:
//   - conversationID: conversation whose task is inspected; trimmed before use.
//   - preferredExecID: execution ID to cancel; when blank after trimming, the
//     ID reported by h.tasks.ActiveMCPExecutionID is used instead.
//   - note: optional user note passed to the cancel/abort calls; trimmed.
//
// Returns (true, payload) when either the MCP tool cancel or the Eino execute
// abort was accepted, and (false, nil) when the conversation ID is blank, no
// task exists for it, or neither cancel path succeeded. The payload carries
// "executionId" only when an execution ID was resolved.
func (h *AgentHandler) cancelToolContinueAfter(conversationID, preferredExecID, note string) (bool, gin.H) {
	convID := strings.TrimSpace(conversationID)
	if convID == "" {
		return false, nil
	}
	if h.tasks.GetTask(convID) == nil {
		return false, nil
	}
	reason := strings.TrimSpace(note)

	// Caller's choice wins; otherwise fall back to whatever MCP execution the
	// task manager reports as active for this conversation.
	targetExec := strings.TrimSpace(preferredExecID)
	if targetExec == "" {
		targetExec = h.tasks.ActiveMCPExecutionID(convID)
	}
	hasExec := targetExec != ""

	// All success responses share one shape; only the message differs and
	// "executionId" is present solely when an execution ID was resolved.
	respond := func(message string) gin.H {
		body := gin.H{
			"status":              "tool_abort_requested",
			"conversationId":      convID,
			"message":             message,
			"continueAfter":       true,
			"interruptWithNote":   reason != "",
			"continueWithoutTool": false,
		}
		if hasExec {
			body["executionId"] = targetExec
		}
		return body
	}

	// First try the MCP tool cancel (only meaningful with an execution ID) ...
	if hasExec && h.agent.CancelMCPToolExecutionWithNote(targetExec, reason) {
		return true, respond("已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。")
	}
	// ... then fall back to aborting an active Eino execute command.
	if h.tasks.AbortActiveEinoExecute(convID, reason) {
		return true, respond("已请求终止当前 execute 命令;命令返回后本轮推理将继续。")
	}
	return false, nil
}
// CancelAgentLoop 取消正在执行的任务 // CancelAgentLoop 取消正在执行的任务
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) { func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
var req struct { var req struct {
@@ -1313,42 +1366,20 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"}) c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
return return
} }
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
note := strings.TrimSpace(req.Reason) note := strings.TrimSpace(req.Reason)
if execID != "" { activeExec := strings.TrimSpace(h.tasks.ActiveMCPExecutionID(req.ConversationID))
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) { if ok, payload := h.cancelToolContinueAfter(req.ConversationID, "", note); ok {
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"}) execID, _ := payload["executionId"].(string)
return h.logger.Info("对话页仅终止当前工具",
}
h.logger.Info("对话页仅终止当前 MCP 工具",
zap.String("conversationId", req.ConversationID), zap.String("conversationId", req.ConversationID),
zap.String("executionId", execID), zap.String("executionId", execID),
zap.Bool("hasNote", note != ""), zap.Bool("hasNote", note != ""),
) )
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, payload)
"status": "tool_abort_requested",
"conversationId": req.ConversationID,
"executionId": execID,
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
})
return return
} }
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) { if activeExec != "" {
h.logger.Info("对话页仅终止当前 Eino execute", c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
zap.String("conversationId", req.ConversationID),
zap.Bool("hasNote", note != ""),
)
c.JSON(http.StatusOK, gin.H{
"status": "tool_abort_requested",
"conversationId": req.ConversationID,
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
})
return return
} }
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。 // 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
@@ -1470,6 +1501,7 @@ type BatchTaskRequest struct {
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选) ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8
} }
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。 // batchQueueWantsEino 队列是否配置为走 Eino 多代理。
@@ -1529,7 +1561,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
nextRunAt = &next nextRunAt = &next
} }
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks) queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks)
if createErr != nil { if createErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
return return
@@ -1722,12 +1754,13 @@ func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
Title string `json:"title"` Title string `json:"title"`
Role string `json:"role"` Role string `json:"role"`
AgentMode string `json:"agentMode"` AgentMode string `json:"agentMode"`
Concurrency *int `json:"concurrency"`
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil { if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@@ -1802,9 +1835,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
// DeleteBatchQueue 删除批量任务队列 // DeleteBatchQueue 删除批量任务队列
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) { func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
queueID := c.Param("queueId") queueID := c.Param("queueId")
success := h.batchTaskManager.DeleteQueue(queueID) if err := h.batchTaskManager.DeleteQueue(queueID); err != nil {
if !success { switch {
case errors.Is(err, ErrBatchQueueNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
case errors.Is(err, ErrBatchQueueExecutorActive):
c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"})
case errors.Is(err, ErrBatchQueueStillRunning):
c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return return
} }
if h.audit != nil { if h.audit != nil {
@@ -1898,7 +1939,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动 // 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused { if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
h.forceUnmarkBatchQueueRunning(queueID) h.batchTaskManager.ForceUnmarkQueueExecutor(queueID)
} }
autoStarted := true autoStarted := true
@@ -1957,26 +1998,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
} }
// markBatchQueueRunning tries to claim the in-memory run slot for queueID.
// It returns true when no entry existed and one has now been recorded in
// h.batchRunning, and false when queueID was already present.
// The map is accessed under h.batchRunnerMu.
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
	h.batchRunnerMu.Lock()
	defer h.batchRunnerMu.Unlock()

	_, taken := h.batchRunning[queueID]
	if !taken {
		h.batchRunning[queueID] = struct{}{}
	}
	return !taken
}
// unmarkBatchQueueRunning releases the run slot for queueID by deleting its
// entry from h.batchRunning while holding h.batchRunnerMu.
// Deleting a key that is not present is a no-op in Go, so repeated calls are harmless.
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
h.batchRunnerMu.Lock()
defer h.batchRunnerMu.Unlock()
delete(h.batchRunning, queueID)
}
// forceUnmarkBatchQueueRunning releases the run slot for queueID.
// It simply delegates to unmarkBatchQueueRunning; the separate name only
// marks call sites that reclaim a slot another goroutine may still hold.
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
h.unmarkBatchQueueRunning(queueID)
}
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
expr := strings.TrimSpace(cronExpr) expr := strings.TrimSpace(cronExpr)
if expr == "" { if expr == "" {
@@ -1992,43 +2013,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断 // 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
if !h.markBatchQueueRunning(queueID) { if !h.batchTaskManager.TryMarkQueueExecutor(queueID) {
return true, nil return true, nil
} }
queue, exists := h.batchTaskManager.GetBatchQueue(queueID) queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists { if !exists {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
return false, nil return false, nil
} }
if scheduled { if scheduled {
if queue.ScheduleMode != "cron" { if queue.ScheduleMode != "cron" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("队列未启用 cron 调度") err := fmt.Errorf("队列未启用 cron 调度")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" { if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("当前队列状态不允许被调度执行") err := fmt.Errorf("当前队列状态不允许被调度执行")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
if !h.batchTaskManager.ResetQueueForRerun(queueID) { if !h.batchTaskManager.ResetQueueForRerun(queueID) {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("重置队列失败") err := fmt.Errorf("重置队列失败")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
queue, _ = h.batchTaskManager.GetBatchQueue(queueID) queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
} else if queue.Status != "pending" && queue.Status != "paused" { } else if queue.Status != "pending" && queue.Status != "paused" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
return true, fmt.Errorf("队列状态不允许启动") return true, fmt.Errorf("队列状态不允许启动")
} }
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) { if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理") err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
if scheduled { if scheduled {
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
@@ -2080,325 +2101,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() {
} }
} }
// executeBatchQueue 执行批量任务队列
func (h *AgentHandler) executeBatchQueue(queueID string) {
defer h.unmarkBatchQueueRunning(queueID)
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
for {
// 检查队列状态
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
break
}
// 获取下一个任务
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
if !hasNext {
// 所有任务完成:汇总子任务失败信息便于排障
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
lastRunErr := ""
if ok {
for _, t := range q.Tasks {
if t.Status == "failed" && t.Error != "" {
lastRunErr = t.Error
}
}
}
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
break
}
// 更新任务状态为运行中
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
// 创建新对话
title := safeTruncateString(task.Message, 50)
batchMeta := audit.ConversationCreateMeta("batch_task")
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
conv, err := h.db.CreateConversation(title, batchMeta)
var conversationID string
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
break
}
continue
}
conversationID = conv.ID
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
// 应用角色用户提示词和工具配置
finalMessage := task.Message
var roleTools []string // 角色配置的工具列表
if queue.Role != "" && queue.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + task.Message
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
// 预先创建助手消息,以便关联过程详情
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
// 如果创建失败,继续执行但不保存过程详情
assistantMsg = nil
}
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
// 需要把进度事件镜像到 TaskEventBusGET /api/agent-loop/task-events 会订阅这里)。
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
var progressCallback func(eventType, message string, data interface{})
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
func() {
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
registered := false
finishStatus := "completed"
defer func() {
h.batchTaskManager.SetTaskCancel(queueID, nil)
timeoutCancel()
if registered {
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
if h.taskEventBus != nil {
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
if b, err := json.Marshal(ev); err == nil {
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
}
}
h.tasks.FinishTask(conversationID, finishStatus)
}
cancelWithCause(nil)
}()
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
sendEvent := func(eventType, message string, data interface{}) {
if h.taskEventBus == nil {
return
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, err := json.Marshal(ev)
if err != nil {
b = []byte(`{"type":"error","message":"marshal failed"}`)
}
line := make([]byte, 0, len(b)+8)
line = append(line, []byte("data: ")...)
line = append(line, b...)
line = append(line, '\n', '\n')
h.taskEventBus.Publish(conversationID, line)
}
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
h.logger.Warn("批量队列子任务注册会话运行状态失败",
zap.String("queueId", queueID),
zap.String("taskId", task.ID),
zap.String("conversationId", conversationID),
zap.Error(err))
failMsg := err.Error()
if errors.Is(err, ErrTaskAlreadyRunning) {
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
return
}
registered = true
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
useBatchMulti := false
batchOrch := "deep"
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
if am == "multi" {
am = "deep"
}
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
useBatchMulti = true
batchOrch = config.NormalizeMultiAgentOrchestration(am)
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
useBatchMulti = true
batchOrch = "deep"
}
var resultMA *multiagent.RunResult
var runErr error
switch {
case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
default:
if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载")
} else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
}
}
if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errStr := runErr.Error()
partialResp := ""
if resultMA != nil {
partialResp = resultMA.Response
}
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
errors.Is(runErr, context.Canceled) ||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
if isTimeout {
finishStatus = "timeout"
} else if isCancelled {
finishStatus = "cancelled"
} else {
finishStatus = "failed"
}
if isCancelled {
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
cancelMsg := "任务已被用户取消,后续操作已停止。"
// 如果执行结果中有更具体的取消消息,使用它
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
cancelMsg = partialResp
}
// 更新助手消息内容
if assistantMessageID != "" {
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存取消详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
if errMsg != nil {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
} else {
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
errorMsg := "执行失败: " + runErr.Error()
// 更新助手消息内容
if assistantMessageID != "" {
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
errorMsg,
time.Now(), assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存错误详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error())
}
} else {
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
resText := resultMA.Response
mcpIDs := resultMA.MCPExecutionIDs
lastIn := resultMA.LastAgentTraceInput
lastOut := resultMA.LastAgentTraceOutput
// 更新助手消息内容
if assistantMessageID != "" {
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
// 如果更新失败,尝试创建新消息
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
// 保存代理轨迹
if lastIn != "" || lastOut != "" {
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else {
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
}
}
// 保存结果
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
}
}()
// 移动到下一个任务
h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
break
}
// 检查是否被取消或暂停
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue.Status == "cancelled" || queue.Status == "paused" {
break
}
}
}
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。 // loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。 // 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) { func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
+352
View File
@@ -0,0 +1,352 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/multiagent"
"go.uber.org/zap"
)
// batchQueueWorkerIdlePoll is how long a queue worker sleeps before polling
// again when it could not claim a pending task while other tasks of the same
// queue are still running.
const batchQueueWorkerIdlePoll = 200 * time.Millisecond
// executeBatchQueue runs a batch task queue with a pool of concurrent workers.
//
// The worker count comes from the queue's Concurrency setting after
// normalization. The call blocks until every worker has returned, then tries
// to mark the queue as completed. The executor marker for queueID is always
// released on return.
func (h *AgentHandler) executeBatchQueue(queueID string) {
	defer h.batchTaskManager.UnmarkQueueExecutor(queueID)

	q, found := h.batchTaskManager.GetBatchQueue(queueID)
	if !found {
		return
	}

	workerCount := normalizeBatchQueueConcurrency(q.Concurrency)
	h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", workerCount))

	// Register all workers up front, then launch them.
	var workers sync.WaitGroup
	workers.Add(workerCount)
	for remaining := workerCount; remaining > 0; remaining-- {
		go func() {
			defer workers.Done()
			h.runBatchQueueWorker(queueID)
		}()
	}
	workers.Wait()

	h.tryFinalizeBatchQueue(queueID)
}
// runBatchQueueWorker is the loop body of one queue worker.
//
// It repeatedly claims the next pending task of the queue and executes it.
// The worker exits when the queue is gone or in a stop state (cancelled,
// completed, paused), when there is nothing left to claim and nothing is
// running, or after finishing a task that was requested as a single run (in
// which case the queue is paused).
func (h *AgentHandler) runBatchQueueWorker(queueID string) {
	for {
		current, found := h.batchTaskManager.GetBatchQueue(queueID)
		if batchQueueExecutionShouldStop(current, found) {
			return
		}

		claimed, gotTask := h.batchTaskManager.ClaimNextPendingTask(queueID)
		if !gotTask {
			// Nothing to claim: keep polling only while sibling tasks are still running.
			if h.batchTaskManager.HasRunningTasks(queueID) {
				time.Sleep(batchQueueWorkerIdlePoll)
				continue
			}
			return
		}

		// Re-read the queue so the sub-task runs against a fresh snapshot.
		current, _ = h.batchTaskManager.GetBatchQueue(queueID)
		if current == nil {
			return
		}

		h.batchTaskManager.UpdateTaskStatus(queueID, claimed.ID, BatchTaskStatusRunning, "", "")
		h.executeOneBatchSubTask(queueID, current, claimed)

		// A single-run request pauses the queue once its task has finished.
		if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, claimed.ID) {
			h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused)
			h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", claimed.ID))
			return
		}

		current, found = h.batchTaskManager.GetBatchQueue(queueID)
		if !batchQueueExecutionShouldStop(current, found) {
			continue
		}
		if !found {
			h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID))
		}
		return
	}
}
// tryFinalizeBatchQueue marks the queue as completed once all work is done.
//
// It does nothing unless the queue exists, is still in the running state and
// has no pending or running tasks. On completion it records, as the queue's
// last run error, the error of the last task in the list that failed with a
// non-empty error (or an empty string when there is none).
func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) {
	q, found := h.batchTaskManager.GetBatchQueue(queueID)
	if !found || q == nil || q.Status != BatchQueueStatusRunning {
		return
	}
	if h.batchTaskManager.HasPendingOrRunningTasks(queueID) {
		return
	}

	// Walk backwards: the first hit from the end is the last failed task with an error.
	lastRunErr := ""
	for i := len(q.Tasks) - 1; i >= 0; i-- {
		candidate := q.Tasks[i]
		if candidate == nil || candidate.Status != BatchTaskStatusFailed || candidate.Error == "" {
			continue
		}
		lastRunErr = candidate.Error
		break
	}

	h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
	h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted)
	h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
}
// executeOneBatchSubTask runs a single batch sub-task in its own, freshly
// created conversation and records the outcome on the batch task.
//
// Flow: create the conversation, apply the queue role's user prompt and tool
// list (only when the role exists and is enabled), persist the user message
// and a placeholder assistant message, register the run with h.tasks, invoke
// the agent, then persist the final answer / agent trace and update the task
// status. Every failure is reported through the batch task status; the
// function returns nothing.
//
// Parameters:
//   - queueID: ID of the queue the task belongs to.
//   - queue:   queue snapshot providing ProjectID, Role and AgentMode.
//   - task:    the claimed task to execute.
func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) {
	title := safeTruncateString(task.Message, 50)
	batchMeta := audit.ConversationCreateMeta("batch_task")
	batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
	conv, err := h.db.CreateConversation(title, batchMeta)
	if err != nil {
		h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
		h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error())
		return
	}
	conversationID := conv.ID

	// Store the conversation ID while the task is still running so the
	// conversation can be opened before the task finishes.
	h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID)

	// Apply the role's user prompt and tool list.
	finalMessage := task.Message
	var roleTools []string
	if queue.Role != "" && queue.Role != "默认" {
		// h.config can be nil (the run step below reports that as an error),
		// so it must be checked before Roles is read.
		if h.config != nil && h.config.Roles != nil {
			if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
				if role.UserPrompt != "" {
					finalMessage = role.UserPrompt + "\n\n" + task.Message
					h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
				}
				if len(role.Tools) > 0 {
					roleTools = role.Tools
					h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
				}
			}
		}
	}

	// Persist the original message (without the role prompt).
	if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil {
		h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
	}

	// Pre-create the assistant message so progress details can be attached to it.
	// If this fails the run continues, just without a message to attach to.
	assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
	if err != nil {
		h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
		assistantMsg = nil
	}
	var assistantMessageID string
	if assistantMsg != nil {
		assistantMessageID = assistantMsg.ID
	}

	h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))

	// baseCtx carries the cancellation cause; taskCtx adds a 6-hour limit per sub-task.
	baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
	taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
	registered := false
	finishStatus := "completed"
	defer func() {
		h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil)
		timeoutCancel()
		if registered {
			// Publish a final "done" event before marking the task finished.
			if h.taskEventBus != nil {
				ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
				if b, err := json.Marshal(ev); err == nil {
					h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
				}
			}
			h.tasks.FinishTask(conversationID, finishStatus)
		}
		cancelWithCause(nil)
	}()

	// sendEvent mirrors progress events to the task event bus only; it never
	// writes to an HTTP response.
	sendEvent := func(eventType, message string, data interface{}) {
		if h.taskEventBus == nil {
			return
		}
		ev := StreamEvent{Type: eventType, Message: message, Data: data}
		b, err := json.Marshal(ev)
		if err != nil {
			b = []byte(`{"type":"error","message":"marshal failed"}`)
		}
		line := make([]byte, 0, len(b)+8)
		line = append(line, []byte("data: ")...)
		line = append(line, b...)
		line = append(line, '\n', '\n')
		h.taskEventBus.Publish(conversationID, line)
	}

	if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
		h.logger.Warn("批量队列子任务注册会话运行状态失败",
			zap.String("queueId", queueID),
			zap.String("taskId", task.ID),
			zap.String("conversationId", conversationID),
			zap.Error(err))
		failMsg := err.Error()
		if errors.Is(err, ErrTaskAlreadyRunning) {
			failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
		}
		h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg)
		return
	}
	registered = true

	// Store this task's cancel function with the manager, keyed by queue and task ID.
	h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel)

	progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
	taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
	taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
	taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)

	// Choose between the multi-agent run and the single-model run.
	useBatchMulti := false
	batchOrch := "deep"
	am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
	if am == "multi" {
		am = "deep"
	}
	if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
		useBatchMulti = true
		batchOrch = config.NormalizeMultiAgentOrchestration(am)
	} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
		// Legacy data: no per-queue agent mode, fall back to the system-level switch.
		useBatchMulti = true
		batchOrch = "deep"
	}

	var resultMA *multiagent.RunResult
	var runErr error
	switch {
	case useBatchMulti:
		resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
	default:
		if h.config == nil {
			runErr = fmt.Errorf("服务器配置未加载")
		} else {
			resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
		}
	}

	if runErr != nil {
		h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus)
		return
	}
	if resultMA == nil {
		// The batch task is recorded as failed, so the status handed to
		// FinishTask in the deferred cleanup must not stay "completed".
		finishStatus = "failed"
		h.logger.Error("批量任务执行成功但无结果对象",
			zap.String("queueId", queueID),
			zap.String("taskId", task.ID),
			zap.String("conversationId", conversationID))
		h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果")
		return
	}

	h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))

	resText := resultMA.Response
	mcpIDs := resultMA.MCPExecutionIDs
	lastIn := resultMA.LastAgentTraceInput
	lastOut := resultMA.LastAgentTraceOutput

	// Finalize the placeholder assistant message; if that fails (or there is
	// no placeholder) store the answer as a new assistant message instead.
	if assistantMessageID != "" {
		if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
			h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
			if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
				h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
			}
		}
	} else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
		h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
	}

	// Persist the agent trace when there is one.
	if lastIn != "" || lastOut != "" {
		if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
			h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
		}
	}

	h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID)
}
// handleBatchSubTaskRunError records the outcome of a batch sub-task whose
// agent run returned an error.
//
// It classifies the error as timeout, cancelled or failed, writes that
// classification to *finishStatus, updates (or creates) the assistant message,
// stores a process detail when an assistant message exists, and sets the batch
// task status to cancelled or failed. resultMA may be nil; when present, its
// Response is used as the partial answer.
func (h *AgentHandler) handleBatchSubTaskRunError(
	queueID string,
	task *BatchTask,
	conversationID, assistantMessageID string,
	baseCtx, taskCtx context.Context,
	resultMA *multiagent.RunResult,
	runErr error,
	finishStatus *string,
) {
	if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
		h.persistEinoAgentTraceForResume(conversationID, resultMA)
	}

	loweredErr := strings.ToLower(runErr.Error())
	var partial string
	if resultMA != nil {
		partial = resultMA.Response
	}

	// The partial answer itself may state that the run was cancelled or interrupted.
	partialSaysCancelled := partial != "" &&
		(strings.Contains(partial, "任务已被取消") || strings.Contains(partial, "任务执行中断"))

	cancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
		errors.Is(runErr, context.Canceled) ||
		strings.Contains(loweredErr, "context canceled") ||
		strings.Contains(loweredErr, "context cancelled") ||
		partialSaysCancelled
	timedOut := errors.Is(runErr, context.DeadlineExceeded) ||
		errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)

	// Timeout takes precedence over cancellation for the reported finish status.
	switch {
	case timedOut:
		*finishStatus = "timeout"
	case cancelled:
		*finishStatus = "cancelled"
	default:
		*finishStatus = "failed"
	}

	if !cancelled {
		h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
		errorMsg := "执行失败: " + runErr.Error()
		if assistantMessageID != "" {
			_, updateErr := h.db.Exec(
				"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
				errorMsg,
				time.Now(), assistantMessageID,
			)
			if updateErr != nil {
				h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
			}
			if detailErr := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); detailErr != nil {
				h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(detailErr))
			}
		}
		h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error())
		return
	}

	h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))

	// Prefer the more specific cancellation text from the partial answer.
	cancelMsg := "任务已被用户取消,后续操作已停止。"
	if partialSaysCancelled {
		cancelMsg = partial
	}

	if assistantMessageID == "" {
		// No placeholder message exists: store the notice as a new assistant message.
		if _, addErr := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); addErr != nil {
			h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(addErr))
		}
	} else {
		if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
			h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
		}
		if detailErr := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); detailErr != nil {
			h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(detailErr))
		}
	}
	h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID)
}
+215 -42
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
@@ -17,6 +18,15 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// Sentinel errors for batch queue operations.
var (
	// ErrBatchQueueNotFound: the queue does not exist or has been unloaded from memory.
	ErrBatchQueueNotFound = errors.New("batch queue not found")
	// ErrBatchQueueExecutorActive: the executeBatchQueue goroutine is still winding down, so deletion is refused.
	ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active")
	// ErrBatchQueueStillRunning: the queue status is still running (fallback guard for when no executor is active).
	ErrBatchQueueStillRunning = errors.New("batch queue is still running")
)
// 批量任务状态常量 // 批量任务状态常量
const ( const (
BatchQueueStatusPending = "pending" BatchQueueStatusPending = "pending"
@@ -39,6 +49,12 @@ const (
// MaxBatchQueueRoleLen 角色名最大长度 // MaxBatchQueueRoleLen 角色名最大长度
MaxBatchQueueRoleLen = 100 MaxBatchQueueRoleLen = 100
// DefaultBatchQueueConcurrency 批量队列默认并发数(串行)
DefaultBatchQueueConcurrency = 1
// MaxBatchQueueConcurrency 批量队列最大并发数
MaxBatchQueueConcurrency = 8
) )
// BatchTask 批量任务项 // BatchTask 批量任务项
@@ -67,6 +83,7 @@ type BatchTaskQueue struct {
LastScheduleError string `json:"lastScheduleError,omitempty"` LastScheduleError string `json:"lastScheduleError,omitempty"`
LastRunError string `json:"lastRunError,omitempty"` LastRunError string `json:"lastRunError,omitempty"`
ProjectID string `json:"projectId,omitempty"` ProjectID string `json:"projectId,omitempty"`
Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1
Tasks []*BatchTask `json:"tasks"` Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
@@ -80,8 +97,9 @@ type BatchTaskManager struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
queues map[string]*BatchTaskQueue queues map[string]*BatchTaskQueue
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列 singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦)
mu sync.RWMutex mu sync.RWMutex
} }
@@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
return &BatchTaskManager{ return &BatchTaskManager{
logger: logger, logger: logger,
queues: make(map[string]*BatchTaskQueue), queues: make(map[string]*BatchTaskQueue),
taskCancels: make(map[string]context.CancelFunc), taskCancels: make(map[string]map[string]context.CancelFunc),
singleRunTasks: make(map[string]string), singleRunTasks: make(map[string]string),
queueExecutors: make(map[string]struct{}),
} }
} }
// batchQueueExecutionShouldStop reports whether the queue execution loop
// should exit: the queue is missing (exists is false or queue is nil) or its
// status is cancelled, completed or paused.
func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool {
	if !exists || queue == nil {
		return true
	}
	status := queue.Status
	return status == BatchQueueStatusCancelled ||
		status == BatchQueueStatusCompleted ||
		status == BatchQueueStatusPaused
}
// TryMarkQueueExecutor records that an executor goroutine has started for the
// queue. It returns false, without changing anything, when the queue is
// already marked; otherwise it adds the marker and returns true.
func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool {
	m.mu.Lock()
	defer m.mu.Unlock()

	_, alreadyMarked := m.queueExecutors[queueID]
	if !alreadyMarked {
		m.queueExecutors[queueID] = struct{}{}
	}
	return !alreadyMarked
}
// UnmarkQueueExecutor removes the executor marker for the queue (called via
// defer from executeBatchQueue). Removing a marker that is not set is a no-op.
func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) {
	m.mu.Lock()
	delete(m.queueExecutors, queueID)
	m.mu.Unlock()
}
// ForceUnmarkQueueExecutor forcibly clears the executor mark for queueID, e.g.
// to reclaim a stale slot when a single task is re-run while the queue is paused.
// It currently delegates to UnmarkQueueExecutor and behaves identically.
func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) {
m.UnmarkQueueExecutor(queueID)
}
// IsQueueExecutorActive reports whether an executor mark is currently recorded
// for queueID, i.e. whether an executeBatchQueue goroutine is considered to be
// still running for that queue.
func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool {
	m.mu.RLock()
	_, active := m.queueExecutors[queueID]
	m.mu.RUnlock()
	return active
}
// SetDB 设置数据库连接 // SetDB 设置数据库连接
func (m *BatchTaskManager) SetDB(db *database.DB) { func (m *BatchTaskManager) SetDB(db *database.DB) {
m.mu.Lock() m.mu.Lock()
@@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
m.db = db m.db = db
} }
// normalizeBatchQueueConcurrency clamps a requested queue concurrency:
// values below 1 become DefaultBatchQueueConcurrency, values above
// MaxBatchQueueConcurrency become MaxBatchQueueConcurrency, and anything in
// between is returned unchanged.
func normalizeBatchQueueConcurrency(n int) int {
	switch {
	case n < 1:
		return DefaultBatchQueueConcurrency
	case n > MaxBatchQueueConcurrency:
		return MaxBatchQueueConcurrency
	default:
		return n
	}
}
// CreateBatchQueue 创建批量任务队列 // CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue( func (m *BatchTaskManager) CreateBatchQueue(
title, role, agentMode, scheduleMode, cronExpr, projectID string, title, role, agentMode, scheduleMode, cronExpr, projectID string,
nextRunAt *time.Time, nextRunAt *time.Time,
concurrency int,
tasks []string, tasks []string,
) (*BatchTaskQueue, error) { ) (*BatchTaskQueue, error) {
// 输入校验 // 输入校验
@@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
CronExpr: strings.TrimSpace(cronExpr), CronExpr: strings.TrimSpace(cronExpr),
NextRunAt: nextRunAt, NextRunAt: nextRunAt,
ScheduleEnabled: true, ScheduleEnabled: true,
Concurrency: normalizeBatchQueueConcurrency(concurrency),
Tasks: make([]*BatchTask, 0, len(tasks)), Tasks: make([]*BatchTask, 0, len(tasks)),
Status: BatchQueueStatusPending, Status: BatchQueueStatusPending,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
queue.CronExpr, queue.CronExpr,
queue.NextRunAt, queue.NextRunAt,
queue.ProjectID, queue.ProjectID,
queue.Concurrency,
dbTasks, dbTasks,
); err != nil { ); err != nil {
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
@@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
if queueRow.ProjectID.Valid { if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
} }
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
@@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
if queueRow.ProjectID.Valid { if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
} }
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
@@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
} }
} }
// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用) // batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error { func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int {
if row == nil || !row.Concurrency.Valid {
return DefaultBatchQueueConcurrency
}
return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64))
}
// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用)
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error {
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
} }
@@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
queue.Title = title queue.Title = title
queue.Role = role queue.Role = role
queue.AgentMode = agentMode queue.AgentMode = agentMode
if concurrency != nil {
queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency)
}
if m.db != nil { if m.db != nil {
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil { if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil {
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err)) m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
} }
} }
@@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引 // PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error { func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
var cancelFunc context.CancelFunc
var siblingRunningIDs []string var siblingRunningIDs []string
m.mu.Lock() m.mu.Lock()
@@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
} }
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项 // 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
var cancelFuncs []context.CancelFunc
if queue.Status == BatchQueueStatusPaused { if queue.Status == BatchQueueStatusPaused {
if c, ok := m.taskCancels[queueID]; ok { cancelFuncs = m.drainTaskCancelsLocked(queueID)
cancelFunc = c
delete(m.taskCancels, queueID)
}
for _, t := range queue.Tasks { for _, t := range queue.Tasks {
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning { if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
siblingRunningIDs = append(siblingRunningIDs, t.ID) siblingRunningIDs = append(siblingRunningIDs, t.ID)
@@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
m.mu.Unlock() m.mu.Unlock()
if cancelFunc != nil { for _, c := range cancelFuncs {
cancelFunc() if c != nil {
c()
}
} }
const staleRunMsg = "为单条执行其它任务,已中止" const staleRunMsg = "为单条执行其它任务,已中止"
for _, sid := range siblingRunningIDs { for _, sid := range siblingRunningIDs {
@@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool
} }
} }
// ClaimNextPendingTask atomically claims the next pending task of a queue so
// that concurrent workers never receive the same task.
//
// Under the manager lock it scans the queue's tasks in order, skips nil and
// non-pending entries, and (when a single-run task ID is registered for the
// queue) skips every task other than that one. The first match is flipped to
// running, queue.CurrentIndex is set to its position, and it is returned with
// true. It returns (nil, false) when the queue is missing, is cancelled,
// completed or paused, or has no claimable task.
func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	queue, found := m.queues[queueID]
	if !found || queue == nil {
		return nil, false
	}
	switch queue.Status {
	case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused:
		return nil, false
	}

	// Indexing a nil map yields the zero value, so no nil guard is required;
	// an empty restrictTo means "any pending task may be claimed".
	restrictTo := m.singleRunTasks[queueID]

	for idx, candidate := range queue.Tasks {
		if candidate == nil {
			continue
		}
		claimable := candidate.Status == BatchTaskStatusPending &&
			(restrictTo == "" || candidate.ID == restrictTo)
		if !claimable {
			continue
		}
		candidate.Status = BatchTaskStatusRunning
		queue.CurrentIndex = idx
		return candidate, true
	}
	return nil, false
}
// HasRunningTasks reports whether the queue still has at least one task whose
// status is running. A missing or nil queue yields false.
func (m *BatchTaskManager) HasRunningTasks(queueID string) bool {
	m.mu.RLock()
	defer m.mu.RUnlock()

	// A missing key yields a nil pointer, so one nil check covers both cases.
	queue := m.queues[queueID]
	if queue == nil {
		return false
	}
	for _, item := range queue.Tasks {
		if item == nil {
			continue
		}
		if item.Status == BatchTaskStatusRunning {
			return true
		}
	}
	return false
}
// HasPendingOrRunningTasks reports whether the queue still has unfinished work,
// i.e. at least one task that is pending or running. A missing or nil queue
// yields false.
func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool {
	m.mu.RLock()
	defer m.mu.RUnlock()

	// A missing key yields a nil pointer, so one nil check covers both cases.
	queue := m.queues[queueID]
	if queue == nil {
		return false
	}
	for _, item := range queue.Tasks {
		if item == nil {
			continue
		}
		switch item.Status {
		case BatchTaskStatusPending, BatchTaskStatusRunning:
			return true
		}
	}
	return false
}
// drainTaskCancelsLocked removes and returns every non-nil cancel function
// registered for the tasks of queueID. The caller must already hold m.mu.
// It returns nil when the queue has no registered cancel functions; in that
// case the map is left untouched.
func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc {
	// len of a missing (nil) inner map is 0, so this covers "absent" and "empty".
	registered := m.taskCancels[queueID]
	if len(registered) == 0 {
		return nil
	}

	drained := make([]context.CancelFunc, 0, len(registered))
	for _, cancel := range registered {
		if cancel == nil {
			continue
		}
		drained = append(drained, cancel)
	}
	delete(m.taskCancels, queueID)
	return drained
}
// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
} }
} }
// SetTaskCancel 设置当前任务的取消函数 // SetTaskCancel 设置任务的取消函数
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) { func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if cancel != nil { if cancel == nil {
m.taskCancels[queueID] = cancel if taskMap, ok := m.taskCancels[queueID]; ok {
} else { delete(taskMap, taskID)
if len(taskMap) == 0 {
delete(m.taskCancels, queueID) delete(m.taskCancels, queueID)
} }
}
return
}
if m.taskCancels[queueID] == nil {
m.taskCancels[queueID] = make(map[string]context.CancelFunc)
}
m.taskCancels[queueID][taskID] = cancel
} }
// PauseQueue 暂停队列 // PauseQueue 暂停队列
func (m *BatchTaskManager) PauseQueue(queueID string) bool { func (m *BatchTaskManager) PauseQueue(queueID string) bool {
var cancelFunc context.CancelFunc var cancelFuncs []context.CancelFunc
m.mu.Lock() m.mu.Lock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
@@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
} }
queue.Status = BatchQueueStatusPaused queue.Status = BatchQueueStatusPaused
cancelFuncs = m.drainTaskCancelsLocked(queueID)
// 取消当前正在执行的任务(通过取消context)
if cancel, ok := m.taskCancels[queueID]; ok {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
m.mu.Unlock() m.mu.Unlock()
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) for _, c := range cancelFuncs {
if cancelFunc != nil { c()
cancelFunc()
} }
return true return true
@@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue) // CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
func (m *BatchTaskManager) CancelQueue(queueID string) bool { func (m *BatchTaskManager) CancelQueue(queueID string) bool {
now := time.Now() now := time.Now()
var cancelFunc context.CancelFunc var cancelFuncs []context.CancelFunc
m.mu.Lock() m.mu.Lock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
@@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
} }
} }
// 取消当前正在执行的任务 cancelFuncs = m.drainTaskCancelsLocked(queueID)
if cancel, ok := m.taskCancels[queueID]; ok {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
m.mu.Unlock() m.mu.Unlock()
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) for _, c := range cancelFuncs {
if cancelFunc != nil { c()
cancelFunc()
} }
return true return true
} }
// DeleteQueue 删除队列(运行中的队列不允许删除) // DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。
func (m *BatchTaskManager) DeleteQueue(queueID string) bool { func (m *BatchTaskManager) DeleteQueue(queueID string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
if !exists { if !exists {
return false return ErrBatchQueueNotFound
}
if _, exec := m.queueExecutors[queueID]; exec {
return ErrBatchQueueExecutorActive
} }
// 运行中的队列不允许删除,防止孤儿协程和数据丢失 // 运行中的队列不允许删除,防止孤儿协程和数据丢失
if queue.Status == BatchQueueStatusRunning { if queue.Status == BatchQueueStatusRunning {
return false return ErrBatchQueueStillRunning
} }
// 清理取消函数 // 清理取消函数
@@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
} }
delete(m.queues, queueID) delete(m.queues, queueID)
return true return nil
} }
// generateShortID 生成短ID // generateShortID 生成短ID
+121
View File
@@ -0,0 +1,121 @@
package handler
import (
"errors"
"testing"
"go.uber.org/zap"
)
func TestNormalizeBatchQueueConcurrency(t *testing.T) {
if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency {
t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got)
}
if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency {
t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got)
}
}
// TestClaimNextPendingTaskParallel verifies that successive claims on a running
// queue with three tasks each return a distinct task already marked running,
// and that a fourth claim finds nothing left to hand out.
func TestClaimNextPendingTaskParallel(t *testing.T) {
	m := NewBatchTaskManager(zap.NewNop())
	queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"})
	if err != nil {
		t.Fatalf("CreateBatchQueue: %v", err)
	}
	m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)

	const taskCount = 3
	claimed := make(map[string]struct{}, taskCount)
	for i := 1; i <= taskCount; i++ {
		task, ok := m.ClaimNextPendingTask(queue.ID)
		if !ok || task == nil {
			t.Fatalf("claim %d: expected a pending task, got ok=%v task=%v", i, ok, task)
		}
		// Every claim, including the last, must hand out a task nobody else holds.
		if _, dup := claimed[task.ID]; dup {
			t.Fatalf("claim %d: task %q was handed out twice", i, task.ID)
		}
		claimed[task.ID] = struct{}{}
		if task.Status != BatchTaskStatusRunning {
			t.Fatalf("claim %d: claimed task should be running, got %v", i, task.Status)
		}
	}

	if task, ok := m.ClaimNextPendingTask(queue.ID); ok {
		t.Fatalf("expected no fourth pending task, got %v", task)
	}
}
// TestBatchQueueExecutionShouldStop covers every branch of
// batchQueueExecutionShouldStop: missing/nil queues, each of the three stop
// statuses (cancelled, completed, paused), and statuses that must keep the
// loop going (running, pending).
func TestBatchQueueExecutionShouldStop(t *testing.T) {
	t.Parallel()

	if !batchQueueExecutionShouldStop(nil, false) {
		t.Fatal("expected stop when queue missing")
	}
	if !batchQueueExecutionShouldStop(nil, true) {
		t.Fatal("expected stop when queue is nil but exists=true")
	}
	// exists=false wins even if a queue pointer is supplied.
	if !batchQueueExecutionShouldStop(&BatchTaskQueue{Status: BatchQueueStatusRunning}, false) {
		t.Fatal("expected stop when exists=false")
	}

	cases := []struct {
		status   string
		wantStop bool
	}{
		{status: BatchQueueStatusRunning, wantStop: false},
		{status: BatchQueueStatusPending, wantStop: false},
		{status: BatchQueueStatusCancelled, wantStop: true},
		{status: BatchQueueStatusCompleted, wantStop: true},
		{status: BatchQueueStatusPaused, wantStop: true},
	}
	for _, tc := range cases {
		q := &BatchTaskQueue{Status: tc.status}
		if got := batchQueueExecutionShouldStop(q, true); got != tc.wantStop {
			t.Errorf("status %q: shouldStop = %v, want %v", tc.status, got, tc.wantStop)
		}
	}
}
// TestDeleteQueueBlockedWhileExecutorActive checks that DeleteQueue refuses to
// remove a queue while its executor mark is set (even if the queue status is
// already cancelled) and succeeds once the mark has been cleared.
func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) {
	t.Parallel()

	mgr := NewBatchTaskManager(zap.NewNop())
	q, createErr := mgr.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
	if createErr != nil {
		t.Fatalf("CreateBatchQueue: %v", createErr)
	}

	// Simulate a live executor goroutine on a queue that was already cancelled.
	if marked := mgr.TryMarkQueueExecutor(q.ID); !marked {
		t.Fatal("expected to mark executor")
	}
	mgr.UpdateQueueStatus(q.ID, BatchQueueStatusCancelled)

	blockedErr := mgr.DeleteQueue(q.ID)
	if !errors.Is(blockedErr, ErrBatchQueueExecutorActive) {
		t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", blockedErr)
	}
	if _, stillThere := mgr.GetBatchQueue(q.ID); !stillThere {
		t.Fatal("queue should still exist while executor active")
	}

	// Once the executor mark is gone the delete must go through.
	mgr.UnmarkQueueExecutor(q.ID)
	if deleteErr := mgr.DeleteQueue(q.ID); deleteErr != nil {
		t.Fatalf("expected delete after executor unmarked, got %v", deleteErr)
	}
	if _, stillThere := mgr.GetBatchQueue(q.ID); stillThere {
		t.Fatal("queue should be deleted")
	}
}
// TestDeleteQueueBlockedWhileRunning checks that DeleteQueue rejects a queue
// whose status is running with ErrBatchQueueStillRunning.
func TestDeleteQueueBlockedWhileRunning(t *testing.T) {
	t.Parallel()

	mgr := NewBatchTaskManager(zap.NewNop())
	q, createErr := mgr.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
	if createErr != nil {
		t.Fatalf("CreateBatchQueue: %v", createErr)
	}
	mgr.UpdateQueueStatus(q.ID, BatchQueueStatusRunning)

	if deleteErr := mgr.DeleteQueue(q.ID); !errors.Is(deleteErr, ErrBatchQueueStillRunning) {
		t.Fatalf("expected ErrBatchQueueStillRunning, got %v", deleteErr)
	}
}
// TestTryMarkQueueExecutorDedupes checks the mark/unmark life cycle: the first
// mark succeeds, a second mark on the same queue is rejected, and marking
// works again after the mark has been removed.
func TestTryMarkQueueExecutorDedupes(t *testing.T) {
	t.Parallel()

	const queueID = "q-1"
	mgr := NewBatchTaskManager(zap.NewNop())

	if first := mgr.TryMarkQueueExecutor(queueID); !first {
		t.Fatal("first mark should succeed")
	}
	if second := mgr.TryMarkQueueExecutor(queueID); second {
		t.Fatal("second mark should fail")
	}

	mgr.UnmarkQueueExecutor(queueID)
	if again := mgr.TryMarkQueueExecutor(queueID); !again {
		t.Fatal("mark after unmark should succeed")
	}
}
+29 -3
View File
@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"type": "string", "type": "string",
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id", "description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id",
}, },
"concurrency": map[string]interface{}{
"type": "integer",
"description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。",
},
}, },
}, },
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
@@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
executeNow = false executeNow = false
} }
projectID := strings.TrimSpace(mcpArgString(args, "project_id")) projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks) concurrency := int(mcpArgFloat(args, "concurrency"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks)
if createErr != nil { if createErr != nil {
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
} }
@@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
if qid == "" { if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil return batchMCPTextResult("queue_id 不能为空", true), nil
} }
if !h.batchTaskManager.DeleteQueue(qid) { if err := h.batchTaskManager.DeleteQueue(qid); err != nil {
switch {
case errors.Is(err, ErrBatchQueueNotFound):
return batchMCPTextResult("删除失败:队列不存在", true), nil return batchMCPTextResult("删除失败:队列不存在", true), nil
case errors.Is(err, ErrBatchQueueExecutorActive):
return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil
case errors.Is(err, ErrBatchQueueStillRunning):
return batchMCPTextResult("删除失败:队列正在运行中", true), nil
default:
return batchMCPTextResult("删除失败:"+err.Error(), true), nil
}
} }
logger.Info("MCP batch_task_delete", zap.String("queueId", qid)) logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
return batchMCPTextResult("队列已删除。", false), nil return batchMCPTextResult("队列已删除。", false), nil
@@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"description": "代理模式:eino_single、deep、plan_execute、supervisor", "description": "代理模式:eino_single、deep、plan_execute、supervisor",
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}, "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
}, },
"concurrency": map[string]interface{}{
"type": "integer",
"description": "同时执行的子任务数,默认 1,最大 8",
},
}, },
"required": []string{"queue_id"}, "required": []string{"queue_id"},
}, },
@@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
title := mcpArgString(args, "title") title := mcpArgString(args, "title")
role := mcpArgString(args, "role") role := mcpArgString(args, "role")
agentMode := mcpArgString(args, "agent_mode") agentMode := mcpArgString(args, "agent_mode")
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil { var concurrency *int
if raw, ok := args["concurrency"]; ok && raw != nil {
v := int(mcpArgFloat(args, "concurrency"))
concurrency = &v
}
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil {
return batchMCPTextResult(err.Error(), true), nil return batchMCPTextResult(err.Error(), true), nil
} }
updated, _ := h.batchTaskManager.GetBatchQueue(qid) updated, _ := h.batchTaskManager.GetBatchQueue(qid)
@@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct {
StartedAt *time.Time `json:"startedAt,omitempty"` StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"` CompletedAt *time.Time `json:"completedAt,omitempty"`
CurrentIndex int `json:"currentIndex"` CurrentIndex int `json:"currentIndex"`
Concurrency int `json:"concurrency"`
TaskTotal int `json:"task_total"` TaskTotal int `json:"task_total"`
TaskCounts map[string]int `json:"task_counts"` TaskCounts map[string]int `json:"task_counts"`
Tasks []batchTaskMCPListSummary `json:"tasks"` Tasks []batchTaskMCPListSummary `json:"tasks"`
@@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
StartedAt: q.StartedAt, StartedAt: q.StartedAt,
CompletedAt: q.CompletedAt, CompletedAt: q.CompletedAt,
CurrentIndex: q.CurrentIndex, CurrentIndex: q.CurrentIndex,
Concurrency: q.Concurrency,
TaskTotal: len(tasks), TaskTotal: len(tasks),
TaskCounts: counts, TaskCounts: counts,
Tasks: tasks, Tasks: tasks,
+75
View File
@@ -23,6 +23,8 @@ import (
type MonitorHandler struct { type MonitorHandler struct {
mcpServer *mcp.Server mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager externalMCPMgr *mcp.ExternalMCPManager
taskManager *AgentTaskManager
agentHandler *AgentHandler
executor *security.Executor executor *security.Executor
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
@@ -56,6 +58,16 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
h.externalMCPMgr = mgr h.externalMCPMgr = mgr
} }
// SetTaskManager sets the agent task manager, which is used to resolve and
// terminate Eino execute (and similar) runs by executionId.
func (h *MonitorHandler) SetTaskManager(mgr *AgentTaskManager) {
h.taskManager = mgr
}
// SetAgentHandler sets the agent handler so that terminating a tool from the
// MCP monitor page shares its logic with the conversation page's
// "interrupt and continue" action.
func (h *MonitorHandler) SetAgentHandler(ah *AgentHandler) {
h.agentHandler = ah
}
// MonitorResponse 监控响应 // MonitorResponse 监控响应
type MonitorResponse struct { type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"` Executions []*mcp.ToolExecution `json:"executions"`
@@ -90,6 +102,7 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
toolName := normalizeToolNameFilter(c.Query("tool")) toolName := normalizeToolNameFilter(c.Query("tool"))
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName) executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
h.enrichExecutionsConversationID(executions)
stats := h.loadStats() stats := h.loadStats()
totalPages := (total + pageSize - 1) / pageSize totalPages := (total + pageSize - 1) / pageSize
@@ -247,6 +260,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
// 先从内部MCP服务器查找 // 先从内部MCP服务器查找
exec, exists := h.mcpServer.GetExecution(id) exec, exists := h.mcpServer.GetExecution(id)
if exists { if exists {
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
c.JSON(http.StatusOK, exec) c.JSON(http.StatusOK, exec)
return return
} }
@@ -255,6 +269,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
if h.externalMCPMgr != nil { if h.externalMCPMgr != nil {
exec, exists = h.externalMCPMgr.GetExecution(id) exec, exists = h.externalMCPMgr.GetExecution(id)
if exists { if exists {
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
c.JSON(http.StatusOK, exec) c.JSON(http.StatusOK, exec)
return return
} }
@@ -264,6 +279,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
if h.db != nil { if h.db != nil {
exec, err := h.db.GetToolExecution(id) exec, err := h.db.GetToolExecution(id)
if err == nil && exec != nil { if err == nil && exec != nil {
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
c.JSON(http.StatusOK, exec) c.JSON(http.StatusOK, exec)
return return
} }
@@ -290,6 +306,19 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
return return
} }
note = strings.TrimSpace(body.Note) note = strings.TrimSpace(body.Note)
convID := h.conversationIDForRunningExecution(id)
if convID != "" && h.agentHandler != nil {
if ok, payload := h.agentHandler.cancelToolContinueAfter(convID, id, note); ok {
h.logger.Info("MCP 监控页终止工具(与对话中断并继续一致)",
zap.String("executionId", id),
zap.String("conversationId", convID),
zap.Bool("hasNote", note != ""),
)
c.JSON(http.StatusOK, payload)
return
}
}
if h.mcpServer.CancelToolExecutionWithNote(id, note) { if h.mcpServer.CancelToolExecutionWithNote(id, note) {
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != "")) h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id}) c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
@@ -303,6 +332,52 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"}) c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
} }
// enrichExecutionsConversationID fills in ConversationID on each non-nil
// execution with the conversation currently running it, or with "" when no
// such conversation can be determined. The executions are modified in place.
func (h *MonitorHandler) enrichExecutionsConversationID(executions []*mcp.ToolExecution) {
	for i := range executions {
		record := executions[i]
		if record == nil {
			continue
		}
		// NOTE(review): the helper looks the execution up again by ID, so a
		// page of results triggers one lookup per entry — confirm this is acceptable.
		record.ConversationID = h.conversationIDForRunningExecution(record.ID)
	}
}
// conversationIDForRunningExecution resolves the conversation that owns a tool
// execution which is still in progress.
//
// It first asks the task manager for a conversation whose active MCP execution
// matches executionID. Failing that, if the execution record is in the
// "running" state and its tool name is "execute", it falls back to the
// task manager's single active Eino execute conversation (only when exactly one
// exists). It returns "" when the ID is blank, no task manager is configured,
// or no owner can be determined.
func (h *MonitorHandler) conversationIDForRunningExecution(executionID string) string {
	id := strings.TrimSpace(executionID)
	if h.taskManager == nil || id == "" {
		return ""
	}

	if convID := h.taskManager.ConversationIDForActiveMCPExecution(id); convID != "" {
		return convID
	}

	record := h.lookupExecution(id)
	isRunningExecute := record != nil &&
		record.Status == "running" &&
		strings.TrimSpace(record.ToolName) == "execute"
	if !isRunningExecute {
		return ""
	}

	if convID, unique := h.taskManager.ConversationIDForActiveEinoExecute(); unique {
		return convID
	}
	return ""
}
// lookupExecution finds a tool execution by ID, checking in order: the internal
// MCP server, the external MCP manager (if configured), and the database
// (if configured). It returns nil when none of them has the execution.
func (h *MonitorHandler) lookupExecution(id string) *mcp.ToolExecution {
	if found, ok := h.mcpServer.GetExecution(id); ok {
		return found
	}

	if mgr := h.externalMCPMgr; mgr != nil {
		if found, ok := mgr.GetExecution(id); ok {
			return found
		}
	}

	if h.db == nil {
		return nil
	}
	stored, err := h.db.GetToolExecution(id)
	if err != nil || stored == nil {
		return nil
	}
	return stored
}
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求) // BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) { func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
var req struct { var req struct {
+34
View File
@@ -103,6 +103,40 @@ func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
} }
} }
// ConversationIDForActiveMCPExecution maps a tool executionId back to the
// conversation whose task currently has it registered as its active MCP
// execution (used by the MCP monitor page to terminate by executionId).
// It returns "" for a blank ID or when no task matches.
func (m *AgentTaskManager) ConversationIDForActiveMCPExecution(executionID string) string {
	target := strings.TrimSpace(executionID)
	if target == "" {
		return ""
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	for conversationID, task := range m.tasks {
		if task == nil {
			continue
		}
		if task.ActiveMCPExecutionID == target {
			return conversationID
		}
	}
	return ""
}
// ConversationIDForActiveEinoExecute returns the ID of the only conversation
// that currently has an active Eino execute, together with true. When no
// conversation or more than one conversation has an active Eino execute, it
// returns ("", false).
func (m *AgentTaskManager) ConversationIDForActiveEinoExecute() (string, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var (
		onlyConv string
		active   int
	)
	for conversationID, task := range m.tasks {
		if task == nil || task.activeEinoExecuteCancel == nil {
			continue
		}
		active++
		if active > 1 {
			// A second match already makes the answer ambiguous.
			return "", false
		}
		onlyConv = conversationID
	}
	if active == 1 {
		return onlyConv, true
	}
	return "", false
}
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。 // AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool { func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
conversationID = strings.TrimSpace(conversationID) conversationID = strings.TrimSpace(conversationID)
@@ -38,3 +38,19 @@ func TestAbortActiveEinoExecute(t *testing.T) {
t.Fatal("second abort should fail when no active execute") t.Fatal("second abort should fail when no active execute")
} }
} }
// TestConversationIDForActiveMCPExecution checks that an executionId registered
// as a conversation's running tool resolves back to that conversation, and that
// an unknown executionId resolves to the empty string.
func TestConversationIDForActiveMCPExecution(t *testing.T) {
	const (
		convID = "conv-mcp-exec"
		execID = "exec-123"
	)

	mgr := NewAgentTaskManager()
	if _, err := mgr.StartTask(convID, "test", func(error) {}); err != nil {
		t.Fatalf("StartTask: %v", err)
	}
	mgr.RegisterRunningTool(convID, execID)

	if got := mgr.ConversationIDForActiveMCPExecution(execID); got != convID {
		t.Fatalf("got %q, want %q", got, convID)
	}
	if got := mgr.ConversationIDForActiveMCPExecution("missing"); got != "" {
		t.Fatalf("missing should be empty, got %q", got)
	}
}
+83 -16
View File
@@ -921,9 +921,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
return finalResult, executionID, nil return finalResult, executionID, nil
} }
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致), // BeginToolExecution 创建 running 状态的执行记录,供 Eino 等非 CallTool 路径在工具开始时落库。
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。 func (s *Server) BeginToolExecution(toolName string, args map[string]interface{}) string {
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if s == nil { if s == nil {
return "" return ""
} }
@@ -931,21 +930,73 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
args = map[string]interface{}{} args = map[string]interface{}{}
} }
executionID := uuid.New().String() executionID := uuid.New().String()
now := time.Now() execution := &ToolExecution{
failed := invokeErr != nil
exec := &ToolExecution{
ID: executionID, ID: executionID,
ToolName: toolName, ToolName: toolName,
Arguments: args, Arguments: args,
StartTime: now, Status: "running",
EndTime: &now, StartTime: time.Now(),
Duration: 0,
} }
s.mu.Lock()
s.executions[executionID] = execution
s.cleanupOldExecutions()
s.mu.Unlock()
if s.storage != nil {
if err := s.storage.SaveToolExecution(execution); err != nil {
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
}
}
return executionID
}
// FinishToolExecution 完成先前 BeginToolExecution 创建的记录;executionID 为空时等同 RecordCompletedToolInvocation。
func (s *Server) FinishToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if s == nil {
return ""
}
if args == nil {
args = map[string]interface{}{}
}
id := strings.TrimSpace(executionID)
if id == "" {
return s.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
}
now := time.Now()
failed := invokeErr != nil
var finalResult *ToolResult
s.mu.Lock()
exec, inMem := s.executions[id]
if !inMem || exec == nil {
exec = &ToolExecution{
ID: id,
ToolName: toolName,
Arguments: args,
StartTime: now,
}
s.executions[id] = exec
} else if toolName != "" {
exec.ToolName = toolName
}
if len(args) > 0 {
exec.Arguments = args
}
exec.EndTime = &now
if exec.StartTime.IsZero() {
exec.StartTime = now
}
exec.Duration = now.Sub(exec.StartTime)
if failed { if failed {
exec.Status = "failed" st, msg := executionStatusAndMessage(invokeErr)
exec.Error = invokeErr.Error() exec.Status = st
exec.Error = msg
if strings.TrimSpace(resultText) != "" { if strings.TrimSpace(resultText) != "" {
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}} finalResult = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
exec.Result = finalResult
} }
} else { } else {
exec.Status = "completed" exec.Status = "completed"
@@ -953,15 +1004,31 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
if strings.TrimSpace(text) == "" { if strings.TrimSpace(text) == "" {
text = "(无输出)" text = "(无输出)"
} }
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}} finalResult = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
exec.Result = finalResult
} }
s.mu.Unlock()
if s.storage != nil { if s.storage != nil {
if err := s.storage.SaveToolExecution(exec); err != nil { if err := s.storage.SaveToolExecution(exec); err != nil {
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err)) s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
} }
} }
s.updateStats(toolName, failed)
return executionID s.updateStats(exec.ToolName, failed)
if s.storage != nil {
s.mu.Lock()
delete(s.executions, id)
s.mu.Unlock()
}
return id
}
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
return s.FinishToolExecution("", toolName, args, resultText, invokeErr)
} }
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。 // UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
+2
View File
@@ -199,6 +199,8 @@ type ToolExecution struct {
StartTime time.Time `json:"startTime"` StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"` EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"` Duration time.Duration `json:"duration,omitempty"`
// ConversationID 仅 API 展示用(进行中的 Agent 任务),不写入 tool_executions 表。
ConversationID string `json:"conversationId,omitempty"`
} }
// ToolStats 工具统计信息 // ToolStats 工具统计信息
+157 -51
View File
@@ -18,6 +18,7 @@ import (
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/einoobserve" "cyberstrike-ai/internal/einoobserve"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
@@ -90,7 +91,7 @@ type einoADKRunLoopArgs struct {
FilesystemMonitorRecord einomcp.ExecutionRecorder FilesystemMonitorRecord einomcp.ExecutionRecorder
MCPExecutionBinder *MCPExecutionBinder MCPExecutionBinder *MCPExecutionBinder
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 SetMCP 桥 Fire 以补全 tool_result。 // ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Setexecute/MCP 桥 Fire 时立即推送 tool_resultADK 晚到经 toolResultSent 去重)
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
DA adk.Agent DA adk.Agent
@@ -196,6 +197,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
pendingByID[tc.ToolCallID] = tc pendingByID[tc.ToolCallID] = tc
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID) pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
} }
markPendingWithMonitor := func(tc toolCallPendingInfo) {
markPending(tc)
beginEinoADKFilesystemToolMonitor(
args.FilesystemMonitorAgent,
args.FilesystemMonitorRecord,
args.MCPExecutionBinder,
tc.ToolCallID,
tc.ToolName,
)
}
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) { popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
pendingMu.Lock() pendingMu.Lock()
defer pendingMu.Unlock() defer pendingMu.Unlock()
@@ -331,7 +342,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
toolCallID = tid toolCallID = tid
} }
recordPendingExecuteStdoutDup(toolName, content, isErr) recordPendingExecuteStdoutDup(toolName, content, isErr)
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr) recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, args.MCPExecutionBinder, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil { if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" { if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content) args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
@@ -341,8 +352,21 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
if args.ToolInvokeNotify != nil { if args.ToolInvokeNotify != nil {
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) { args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
removePendingByID(strings.TrimSpace(toolCallID)) // Eino execute / MCP 桥在工具返回时 Fire;若 ADK schema.Tool 事件迟迟不到,此处立即推送
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断) // tool_result 解除 UI「执行中」。tryEmitToolResultProgress 经 toolResultSent 去重,ADK 晚到不重复
isErr := !success || invokeErr != nil
body := einoToolResultBody(content)
if einoToolResultIsError(toolName, content) {
isErr = true
}
if tail := friendlyEinoExecuteInvokeTail(invokeErr); tail != "" {
if body == "" {
body = tail
} else if !strings.Contains(body, tail) {
body = strings.TrimSpace(body) + "\n\n" + tail
}
}
tryEmitToolResultProgress(toolName, body, toolCallID, isErr, einoAgent)
}) })
} }
@@ -539,6 +563,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
return true, nil return true, nil
} }
// 仅在退避重试后真正收到数据/完成一步时清零,避免重启后首个无错 ADK 事件误把计数打回 0。
confirmTransientRetryRecovery := func() {
if transientRetrier.attempt() > 0 {
transientRetrier.reset()
}
}
takePartial := func(runErr error) (*RunResult, error) { takePartial := func(runErr error) (*RunResult, error) {
if len(runAccumulatedMsgs) <= baseAccumulatedCount { if len(runAccumulatedMsgs) <= baseAccumulatedCount {
return nil, runErr return nil, runErr
@@ -551,10 +582,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
for { for {
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中" // iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending
select { ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
case <-ctx.Done(): if iterCtxErr != nil {
flushAllPendingAsFailed(ctx.Err()) flushAllPendingAsFailed(iterCtxErr)
if progress != nil { if progress != nil {
if isInterruptContinue(ctx) { if isInterruptContinue(ctx) {
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{ progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
@@ -563,17 +594,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"kind": "interrupt_continue", "kind": "interrupt_continue",
}) })
} else { } else {
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{ progress("error", iterCtxErr.Error(), map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"source": "eino", "source": "eino",
}) })
} }
} }
return takePartial(ctx.Err()) return takePartial(iterCtxErr)
default:
} }
ev, ok := iter.Next()
if !ok { if !ok {
// iter 结束并不总是“正常完成”: // iter 结束并不总是“正常完成”:
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。 // 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
@@ -627,8 +655,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if restarted { if restarted {
continue continue
} }
} else {
transientRetrier.reset()
} }
if ev.AgentName != "" && progress != nil { if ev.AgentName != "" && progress != nil {
iterEinoAgent := orchestratorName iterEinoAgent := orchestratorName
@@ -691,34 +717,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool { if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
toolName := strings.TrimSpace(mv.ToolName) toolName := strings.TrimSpace(mv.ToolName)
var toolBuf strings.Builder content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
streamToolCallID := "" isErr := einoToolResultIsError(toolName, content)
var toolStreamRecvErr error content = einoToolResultBody(content)
for {
chunk, rerr := mv.MessageStream.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
toolStreamRecvErr = rerr
break
}
if chunk == nil {
continue
}
if chunk.Content != "" {
toolBuf.WriteString(chunk.Content)
}
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
streamToolCallID = tid
}
}
content := toolBuf.String()
isErr := false
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
isErr = true
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
if streamToolCallID != "" { if streamToolCallID != "" {
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)} opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...)) runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
@@ -730,6 +731,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
zap.String("agent", ev.AgentName), zap.String("agent", ev.AgentName),
zap.String("tool", toolName)) zap.String("tool", toolName))
} }
if toolStreamRecvErr == nil {
confirmTransientRetryRecovery()
}
continue continue
} }
@@ -977,7 +981,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 { if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged}) lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
} }
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending) tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。 // 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 { if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls)) runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
@@ -1001,6 +1005,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if restarted { if restarted {
continue continue
} }
} else {
confirmTransientRetryRecovery()
} }
continue continue
} }
@@ -1010,7 +1016,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue continue
} }
runAccumulatedMsgs = append(runAccumulatedMsgs, msg) runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending) tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
if mv.Role == schema.Assistant { if mv.Role == schema.Assistant {
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" { if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
@@ -1085,15 +1091,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
content := msg.Content content := msg.Content
isErr := false isErr := einoToolResultIsError(toolName, content)
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) { content = einoToolResultBody(content)
isErr = true
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
toolCallID := strings.TrimSpace(msg.ToolCallID) toolCallID := strings.TrimSpace(msg.ToolCallID)
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName) tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
} }
confirmTransientRetryRecovery()
} }
mcpIDsMu.Lock() mcpIDsMu.Lock()
@@ -1121,17 +1125,119 @@ func einoPartialRunLastOutputHint() string {
"[Run ended abnormally; continue from the trace above without repeating completed steps.]" "[Run ended abnormally; continue from the trace above without repeating completed steps.]"
} }
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本 // friendlyEinoExecuteInvokeTail 将 Eino execute 超时/中断/流异常转为简短提示
// 命令非零退出(ExecuteExitError)已有 exec 对齐的正文,不再追加「执行未正常结束」。
func friendlyEinoExecuteInvokeTail(invokeErr error) string { func friendlyEinoExecuteInvokeTail(invokeErr error) string {
if invokeErr == nil { if invokeErr == nil {
return "" return ""
} }
var exitErr *ExecuteExitError
if errors.As(invokeErr, &exitErr) {
return ""
}
if errors.Is(invokeErr, context.DeadlineExceeded) { if errors.Is(invokeErr, context.DeadlineExceeded) {
return einoExecuteTimeoutUserHint() return einoExecuteTimeoutUserHint()
} }
if errors.Is(invokeErr, context.Canceled) {
return ""
}
if strings.Contains(invokeErr.Error(), "shell inactivity timeout") {
return ""
}
return "[执行未正常结束] " + invokeErr.Error() return "[执行未正常结束] " + invokeErr.Error()
} }
// einoToolResultIsError is the single place that decides whether an Eino tool
// result should be flagged as an error (kept aligned with IsError on the MCP exec path).
//
// A result is an error when it carries einomcp.ToolErrorPrefix, or when the tool is
// "execute" and security.IsCommandFailureResult recognises the content as a failure.
func einoToolResultIsError(toolName, content string) bool {
	if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
		return true
	}
	isExecute := strings.TrimSpace(toolName) == "execute"
	return isExecute && security.IsCommandFailureResult(content)
}
// einoToolResultBody strips the tool error prefix (if present) and returns the text
// used for display and persistence. Content without the prefix is returned unchanged.
func einoToolResultBody(content string) string {
	// TrimPrefix is already a no-op when the prefix is absent.
	return strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
}
// nextAgentEventWithContext fetches the next event from iter without blocking past
// ctx cancellation (iter.Next can block for a long time during tool execution or
// model inference).
//
// Returns:
//   - (ev, ok, nil) with whatever iter.Next returned when it completes first;
//   - (nil, false, ctx.Err()) when ctx is done first;
//   - (nil, false, nil) when iter is nil.
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
	if iter == nil {
		return nil, false, nil
	}
	// Cancellation takes priority: if ctx is already done, do not start a goroutine
	// that would pull (and then drop) an event, and do not let the select below
	// pick a ready event over the cancellation at random.
	if err := ctx.Err(); err != nil {
		return nil, false, err
	}
	type nextRes struct {
		ev *adk.AgentEvent
		ok bool
	}
	// Buffer of 1 so the goroutine can always deliver its result and exit, even
	// when this function has already returned because ctx was cancelled.
	ch := make(chan nextRes, 1)
	go func() {
		e, o := iter.Next()
		ch <- nextRes{e, o}
	}()
	select {
	case <-ctx.Done():
		return nil, false, ctx.Err()
	case res := <-ch:
		return res.ev, res.ok, nil
	}
}
// recvSchemaMessageStream drains an ADK tool result stream, concatenating chunk
// contents and remembering the last non-empty ToolCallID seen.
//
// It returns as soon as ctx is cancelled, so a tool that produces no output cannot
// block the caller forever. In that case the content gathered so far is returned
// together with ctx.Err(). io.EOF ends the stream normally and is not reported as
// an error; any other Recv error is returned along with the partial content.
// A nil stream yields ("", "", nil).
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
	if stream == nil {
		return "", "", nil
	}
	type streamMsg struct {
		chunk *schema.Message
		err   error
	}
	recvCh := make(chan streamMsg, 8)
	// done is closed when this function returns. Without it, an early return on ctx
	// cancellation would leave the forwarding goroutine blocked forever on the send
	// once the buffer fills up.
	done := make(chan struct{})
	defer close(done)
	go func() {
		defer close(recvCh)
		for {
			ch, rerr := stream.Recv()
			select {
			case recvCh <- streamMsg{chunk: ch, err: rerr}:
			case <-done:
				return
			}
			// Any error (including io.EOF) is terminal for the stream.
			if rerr != nil {
				return
			}
		}
	}()
	var buf strings.Builder
	for {
		select {
		case <-ctx.Done():
			return buf.String(), toolCallID, ctx.Err()
		case sm, open := <-recvCh:
			if !open {
				return buf.String(), toolCallID, nil
			}
			rerr := sm.err
			if errors.Is(rerr, io.EOF) {
				return buf.String(), toolCallID, nil
			}
			if rerr != nil {
				return buf.String(), toolCallID, rerr
			}
			chunk := sm.chunk
			if chunk == nil {
				continue
			}
			if chunk.Content != "" {
				buf.WriteString(chunk.Content)
			}
			if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
				toolCallID = tid
			}
		}
	}
}
func buildEinoRunResultFromAccumulated( func buildEinoRunResultFromAccumulated(
orchMode string, orchMode string,
runAccumulatedMsgs []adk.Message, runAccumulatedMsgs []adk.Message,
@@ -0,0 +1,74 @@
package multiagent
import (
"context"
"errors"
"io"
"testing"
"time"
"github.com/cloudwego/eino/schema"
)
// TestRecvSchemaMessageStream_EOF checks that a stream holding one tool message and
// then closed yields that message's content and tool call ID without an error.
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
	reader, writer := schema.Pipe[*schema.Message](4)
	_ = writer.Send(schema.ToolMessage("hello", "tc-1"), nil)
	writer.Close()

	gotContent, gotID, err := recvSchemaMessageStream(context.Background(), reader)
	switch {
	case err != nil:
		t.Fatalf("unexpected err: %v", err)
	case gotContent != "hello":
		t.Fatalf("content=%q want hello", gotContent)
	case gotID != "tc-1":
		t.Fatalf("toolCallID=%q want tc-1", gotID)
	}
}
// TestRecvSchemaMessageStream_ContextCancel checks that cancelling ctx while the
// stream is silent makes the receiver return context.Canceled.
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
	reader, writer := schema.Pipe[*schema.Message](4)
	t.Cleanup(func() { writer.Close() })

	ctx, cancel := context.WithCancel(context.Background())
	// Cancel shortly after the receiver has started waiting on the empty stream.
	time.AfterFunc(30*time.Millisecond, cancel)

	got, _, err := recvSchemaMessageStream(ctx, reader)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("want context.Canceled, got %v content=%q", err, got)
	}
}
// TestRecvSchemaMessageStream_RecvError checks that a non-EOF error sent through the
// stream is surfaced to the caller.
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
	streamErr := errors.New("stream broken")

	reader, writer := schema.Pipe[*schema.Message](4)
	_ = writer.Send(nil, streamErr)
	writer.Close()

	if _, _, err := recvSchemaMessageStream(context.Background(), reader); !errors.Is(err, streamErr) {
		t.Fatalf("want %v, got %v", streamErr, err)
	}
}
// TestRecvSchemaMessageStream_NilStream checks that a nil stream yields empty
// content, an empty tool call ID and no error.
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
	content, tid, err := recvSchemaMessageStream(context.Background(), nil)
	allZero := err == nil && content == "" && tid == ""
	if !allZero {
		t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
	}
}
// TestRecvSchemaMessageStream_EOFViaEmptyRead checks that io.EOF delivered as a
// stream item is treated as normal termination rather than an error.
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
	reader, writer := schema.Pipe[*schema.Message](4)
	_ = writer.Send(nil, io.EOF)
	writer.Close()

	if _, _, err := recvSchemaMessageStream(context.Background(), reader); err != nil {
		t.Fatalf("EOF should not surface as error, got %v", err)
	}
}
@@ -0,0 +1,114 @@
package multiagent
import (
"context"
"errors"
"io"
"strings"
"testing"
"cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/security"
"github.com/cloudwego/eino/adk/filesystem"
"github.com/cloudwego/eino/schema"
)
// mockStreamingShellExitFail is a test double for the inner streaming shell: it emits
// an optional output chunk followed by a chunk carrying an exit code.
type mockStreamingShellExitFail struct {
	output string // text sent as the first chunk; skipped when empty
	code   int    // exit code reported in the final chunk
}

// ExecuteStreaming returns a stream that is fed asynchronously with the configured
// output (if any) and then the exit code, and is closed afterwards. It never
// returns an error.
func (m *mockStreamingShellExitFail) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
	reader, writer := schema.Pipe[*filesystem.ExecuteResponse](4)
	go func() {
		defer writer.Close()
		if m.output != "" {
			_ = writer.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
		}
		exit := m.code
		_ = writer.Send(&filesystem.ExecuteResponse{ExitCode: &exit}, nil)
	}()
	return reader, nil
}
func TestEinoStreamingShellWrap_CommandFailureFormat(t *testing.T) {
inner := &mockStreamingShellExitFail{
output: "sudo: a password is required\n",
code: 1,
}
notify := einomcp.NewToolInvokeNotifyHolder()
var firedBody string
var firedSuccess bool
var firedErr error
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
firedBody = content
firedSuccess = success
firedErr = invokeErr
})
wrap := &einoStreamingShellWrap{inner: inner, invokeNotify: notify}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var stream strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("recv: %v", rerr)
}
if resp != nil {
stream.WriteString(resp.Output)
}
}
if firedSuccess {
t.Fatal("expected success=false")
}
var exitErr *ExecuteExitError
if !errors.As(firedErr, &exitErr) || exitErr.Code != 1 {
t.Fatalf("expected ExecuteExitError code 1, got %v", firedErr)
}
if !strings.HasPrefix(firedBody, einomcp.ToolErrorPrefix) {
t.Fatalf("missing tool error prefix: %q", firedBody)
}
body := strings.TrimPrefix(firedBody, einomcp.ToolErrorPrefix)
if body != security.FormatCommandFailureResult(1, "sudo: a password is required\n") {
t.Fatalf("fire body = %q", body)
}
if !strings.Contains(stream.String(), "sudo:") {
t.Fatalf("stream missing sudo output: %q", stream.String())
}
if strings.Contains(stream.String(), "command exited with non-zero") {
t.Fatalf("stream has legacy noise: %q", stream.String())
}
if strings.Contains(stream.String(), "执行未正常结束") {
t.Fatalf("stream has abnormal tail: %q", stream.String())
}
if !security.IsCommandFailureResult(stream.String()) {
t.Fatalf("stream missing failure status line: %q", stream.String())
}
if tail := friendlyEinoExecuteInvokeTail(firedErr); tail != "" {
t.Fatalf("unexpected invoke tail: %q", tail)
}
if !einoToolResultIsError("execute", firedBody) {
t.Fatal("expected isError for execute failure")
}
}
// TestFriendlyEinoExecuteInvokeTail covers three inputs: an exit-code error (no tail),
// a deadline error (timeout hint) and an arbitrary error (non-empty tail).
func TestFriendlyEinoExecuteInvokeTail(t *testing.T) {
	if tail := friendlyEinoExecuteInvokeTail(&ExecuteExitError{Code: 1}); tail != "" {
		t.Fatal("exit error should not get abnormal tail")
	}
	if tail := friendlyEinoExecuteInvokeTail(context.DeadlineExceeded); !strings.Contains(tail, "Timed out") {
		t.Fatal("deadline should get timeout hint")
	}
	if tail := friendlyEinoExecuteInvokeTail(errors.New("broken pipe")); tail == "" {
		t.Fatal("unexpected error should get tail")
	}
}
+22 -7
View File
@@ -7,11 +7,25 @@ import (
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
) )
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId) // newEinoExecuteMonitorCallbacks 在 Eino filesystem execute 开始/结束时写入 MCP 监控库并 recorder(executionId)
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片 // 与 CallTool 路径一致,使监控页能展示「执行中」状态
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) { func newEinoExecuteMonitorCallbacks(ag *agent.Agent, recorder einomcp.ExecutionRecorder) (
return func(toolCallID, command, stdout string, success bool, invokeErr error) { begin func(toolCallID, command string) string,
if ag == nil || recorder == nil { finish func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
) {
begin = func(toolCallID, command string) string {
if ag == nil {
return ""
}
args := map[string]interface{}{"command": command}
id := ag.BeginLocalToolExecution("execute", args)
if id != "" && recorder != nil {
recorder(id, toolCallID)
}
return id
}
finish = func(executionID, toolCallID, command, stdout string, success bool, invokeErr error) {
if ag == nil {
return return
} }
var err error var err error
@@ -23,9 +37,10 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
} }
} }
args := map[string]interface{}{"command": command} args := map[string]interface{}{"command": command}
id := ag.RecordLocalToolExecution("execute", args, stdout, err) id := ag.FinishLocalToolExecution(executionID, "execute", args, stdout, err)
if id != "" { if id != "" && recorder != nil && executionID == "" {
recorder(id, toolCallID) recorder(id, toolCallID)
} }
} }
return begin, finish
} }
@@ -51,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。 // 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
// //
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire // 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重) // run loop 收到 Fire 后立即推送 tool_resulttoolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」
// //
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire // 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。 // 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
@@ -63,8 +63,11 @@ type einoStreamingShellWrap struct {
outputChunk func(toolName, toolCallID, chunk string) outputChunk func(toolName, toolCallID, chunk string)
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。 // toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
toolTimeoutMinutes int toolTimeoutMinutes int
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致 // shellNoOutputTimeoutSec:无任何输出时的空闲秒数;0=关闭
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error) shellNoOutputTimeoutSec int
// beginMonitor 在 execute 开始时写入 running 状态;finishMonitor 在流结束后更新为 completed/failed。
beginMonitor func(toolCallID, command string) string
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error)
} }
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
@@ -76,15 +79,26 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
} }
req := *input req := *input
userCmd := strings.TrimSpace(req.Command) userCmd := strings.TrimSpace(req.Command)
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
agentTag := strings.TrimSpace(w.einoAgentName)
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround { if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
req.RunInBackendGround = true req.RunInBackendGround = true
} }
req.Command = prependPythonUnbufferedEnv(req.Command) req.Command = security.PrepareNonInteractiveShellCommand(prependPythonUnbufferedEnv(req.Command))
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
agentTag := strings.TrimSpace(w.einoAgentName)
convID := mcp.MCPConversationIDFromContext(ctx) convID := mcp.MCPConversationIDFromContext(ctx)
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx) execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
var monitorExecID string
if w.beginMonitor != nil {
monitorExecID = w.beginMonitor(tid, userCmd)
}
if monitorExecID != "" && convID != "" {
if toolReg := mcp.ToolRunRegistryFromContext(ctx); toolReg != nil {
toolReg.RegisterRunningTool(convID, monitorExecID)
}
}
toolRunReg := mcp.ToolRunRegistryFromContext(ctx)
execCtx, execCancel := context.WithCancel(ctx) execCtx, execCancel := context.WithCancel(ctx)
var timeoutCancel context.CancelFunc var timeoutCancel context.CancelFunc
if w.toolTimeoutMinutes > 0 { if w.toolTimeoutMinutes > 0 {
@@ -104,23 +118,23 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
} }
if einoExecuteRecvErrIsToolTimeout(err, execCtx) { if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n" hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
if w.recordMonitor != nil { if w.finishMonitor != nil {
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded) w.finishMonitor(monitorExecID, tid, userCmd, hint, false, context.DeadlineExceeded)
} }
if w.invokeNotify != nil && tid != "" { if w.invokeNotify != nil && tid != "" {
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded) w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
} }
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
} }
if w.recordMonitor != nil { if w.finishMonitor != nil {
w.recordMonitor(tid, userCmd, "", false, err) w.finishMonitor(monitorExecID, tid, userCmd, "", false, err)
} }
if w.invokeNotify != nil && tid != "" { if w.invokeNotify != nil && tid != "" {
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err) w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
} }
return nil, err return nil, err
} }
if sr == nil || w.invokeNotify == nil { if sr == nil {
if timeoutCancel != nil { if timeoutCancel != nil {
timeoutCancel() timeoutCancel()
} }
@@ -132,7 +146,7 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32) outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) { go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry, toolReg mcp.ToolRunRegistry, execID string, toolCallID string, noOutputSec int) {
var innerCloseOnce sync.Once var innerCloseOnce sync.Once
closeInner := func() { closeInner := func() {
innerCloseOnce.Do(func() { inner.Close() }) innerCloseOnce.Do(func() { inner.Close() })
@@ -147,6 +161,9 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
if reg != nil && conversationID != "" { if reg != nil && conversationID != "" {
defer reg.UnregisterActiveEinoExecute(conversationID) defer reg.UnregisterActiveEinoExecute(conversationID)
} }
if toolReg != nil && conversationID != "" && execID != "" {
defer toolReg.UnregisterRunningTool(conversationID, execID)
}
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。 // ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
stopWatch := make(chan struct{}) stopWatch := make(chan struct{})
@@ -165,50 +182,103 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
exitCode := 0 exitCode := 0
hasExitCode := false hasExitCode := false
idleWatch := security.NewShellInactivityWatch(noOutputSec)
if idleWatch != nil {
defer idleWatch.Stop()
}
type execRecvMsg struct {
resp *filesystem.ExecuteResponse
err error
}
recvCh := make(chan execRecvMsg, 1)
go func() {
for { for {
resp, rerr := inner.Recv() resp, rerr := inner.Recv()
recvCh <- execRecvMsg{resp: resp, err: rerr}
if rerr != nil {
return
}
}
}()
fireInactivityTimeout := func() {
success = false
invokeErr = fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
msg := security.ShellNoOutputTimeoutMessage(idleWatch.Sec)
_ = outW.Send(&filesystem.ExecuteResponse{Output: msg}, nil)
sb.WriteString(msg)
if w.outputChunk != nil && toolCallID != "" {
w.outputChunk("execute", toolCallID, msg)
}
if cancel != nil {
cancel()
}
closeInner()
}
recvLoop:
for {
var idleCh <-chan struct{}
if idleWatch != nil {
idleCh = idleWatch.Expired
}
select {
case <-idleCh:
fireInactivityTimeout()
break recvLoop
case msg := <-recvCh:
rerr := msg.err
resp := msg.resp
if errors.Is(rerr, io.EOF) { if errors.Is(rerr, io.EOF) {
break break recvLoop
} }
if rerr != nil { if rerr != nil {
success = false success = false
invokeErr = rerr invokeErr = rerr
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) { if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
invokeErr = context.DeadlineExceeded invokeErr = context.DeadlineExceeded
break break recvLoop
} }
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) { if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
invokeErr = context.Canceled invokeErr = context.Canceled
break break recvLoop
} }
_ = outW.Send(nil, rerr) _ = outW.Send(nil, rerr)
break break recvLoop
} }
if resp != nil { if resp != nil {
if resp.ExitCode != nil { if resp.ExitCode != nil {
hasExitCode = true hasExitCode = true
exitCode = *resp.ExitCode exitCode = *resp.ExitCode
continue
} }
var appended string var appended string
if resp.Output != "" { if resp.Output != "" {
if security.IsLegacyShellExitNoise(resp.Output) {
continue
}
if idleWatch != nil {
idleWatch.Bump()
}
sb.WriteString(resp.Output) sb.WriteString(resp.Output)
appended = resp.Output appended = resp.Output
} }
if w.outputChunk != nil && strings.TrimSpace(appended) != "" { if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
w.outputChunk("execute", tid, appended) w.outputChunk("execute", toolCallID, appended)
} }
if outW.Send(resp, nil) { if outW.Send(resp, nil) {
success = false success = false
invokeErr = fmt.Errorf("execute stream closed by consumer") invokeErr = fmt.Errorf("execute stream closed by consumer")
break break recvLoop
}
} }
} }
} }
if success && hasExitCode && exitCode != 0 { if success && hasExitCode && exitCode != 0 {
success = false success = false
invokeErr = fmt.Errorf("execute exited with code %d", exitCode) invokeErr = &ExecuteExitError{Code: exitCode}
} }
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。 // WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。 // 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
@@ -248,12 +318,24 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil) _ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
} }
} }
if w.recordMonitor != nil { rawOutput := sb.String()
w.recordMonitor(tid, command, sb.String(), success, invokeErr) fireBody := rawOutput
if !success && hasExitCode && exitCode != 0 {
statusLine := security.ExecuteFailureStatusLine(exitCode)
if !strings.Contains(rawOutput, "命令执行失败:") {
_ = outW.Send(&filesystem.ExecuteResponse{Output: statusLine}, nil)
sb.WriteString(statusLine)
}
fireBody = einomcp.ToolErrorPrefix + security.FormatCommandFailureResult(exitCode, rawOutput)
}
if w.finishMonitor != nil {
w.finishMonitor(execID, toolCallID, command, sb.String(), success, invokeErr)
}
if w.invokeNotify != nil {
w.invokeNotify.Fire(toolCallID, "execute", agentTag, success, fireBody, invokeErr)
} }
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
outW.Close() outW.Close()
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg) }(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg, toolRunReg, monitorExecID, tid, w.shellNoOutputTimeoutSec)
return outR, nil return outR, nil
} }
@@ -19,9 +19,15 @@ type mockStreamingShell struct {
immediateErr error immediateErr error
recvErr error recvErr error
output string output string
called bool
lastCommand string
} }
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
m.called = true
if input != nil {
m.lastCommand = input.Command
}
if m.immediateErr != nil { if m.immediateErr != nil {
return nil, m.immediateErr return nil, m.immediateErr
} }
@@ -38,6 +44,135 @@ func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesy
return outR, nil return outR, nil
} }
// TestEinoStreamingShellWrap_PreparesNonInteractiveCommand drains a wrapped
// execute stream and then checks that the command string handed to the inner
// shell contains each of the markers listed in the table below.
func TestEinoStreamingShellWrap_PreparesNonInteractiveCommand(t *testing.T) {
	shell := &mockStreamingShell{output: "ok\n"}
	w := &einoStreamingShellWrap{inner: shell}

	stream, err := w.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo ok"})
	if err != nil {
		t.Fatalf("ExecuteStreaming: %v", err)
	}
	defer stream.Close()

	// Consume the stream to EOF so the wrapper has finished before asserting.
	for drained := false; !drained; {
		_, recvErr := stream.Recv()
		switch {
		case errors.Is(recvErr, io.EOF):
			drained = true
		case recvErr != nil:
			t.Fatalf("recv: %v", recvErr)
		}
	}

	checks := []struct {
		needle  string
		failFmt string
	}{
		{"exec </dev/null", "missing stdin redirect in inner command: %q"},
		{"GIT_PAGER=cat", "missing pager export in inner command: %q"},
		{"PYTHONUNBUFFERED=1", "missing python unbuffer in inner command: %q"},
	}
	for _, c := range checks {
		if !strings.Contains(shell.lastCommand, c.needle) {
			t.Fatalf(c.failFmt, shell.lastCommand)
		}
	}
}
// TestEinoStreamingShellWrap_NoOutputTimeout wraps an inner shell that never
// emits output (it only closes its stream once ctx is done) with a 1-second
// no-output limit, and expects the inactivity message in the streamed output.
func TestEinoStreamingShellWrap_NoOutputTimeout(t *testing.T) {
	hanging := &mockStreamingShellHanging{}
	holder := einomcp.NewToolInvokeNotifyHolder()

	// fired is only used to enrich the failure message below.
	var fired string
	holder.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
		fired = content
	})

	w := &einoStreamingShellWrap{
		inner:                   hanging,
		invokeNotify:            holder,
		shellNoOutputTimeoutSec: 1,
	}
	stream, err := w.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
	if err != nil {
		t.Fatalf("ExecuteStreaming: %v", err)
	}
	defer stream.Close()

	var collected strings.Builder
	for {
		chunk, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			break
		}
		if recvErr != nil {
			t.Fatalf("recv: %v", recvErr)
		}
		if chunk == nil {
			continue
		}
		collected.WriteString(chunk.Output)
	}

	if !hanging.called {
		t.Fatal("inner shell should run (no command blacklist)")
	}
	text := collected.String()
	sawTimeout := strings.Contains(text, "没有新的输出") || strings.Contains(text, "no new output")
	if !sawTimeout {
		t.Fatalf("expected inactivity timeout message, got: %q notify=%q", text, fired)
	}
}
// mockStreamingShellPartialThenHang is a test double whose stream emits a
// single sudo password prompt line and then stays silent until ctx is done,
// at which point the stream is closed.
type mockStreamingShellPartialThenHang struct {
	called bool // set as soon as ExecuteStreaming is invoked
}

// ExecuteStreaming marks the mock as called and returns a pipe fed by a
// goroutine that sends one prompt chunk, then blocks on ctx.Done().
func (m *mockStreamingShellPartialThenHang) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
	m.called = true
	reader, writer := schema.Pipe[*filesystem.ExecuteResponse](4)
	go func() {
		defer writer.Close()
		prompt := &filesystem.ExecuteResponse{Output: "[sudo] password:\n"}
		_ = writer.Send(prompt, nil)
		<-ctx.Done()
	}()
	return reader, nil
}
// TestEinoStreamingShellWrap_InactivityAfterPartialOutput uses an inner shell
// that prints one prompt line and then goes silent. With a 1-second no-output
// limit the stream must end within 5 seconds and carry the inactivity message.
func TestEinoStreamingShellWrap_InactivityAfterPartialOutput(t *testing.T) {
	shell := &mockStreamingShellPartialThenHang{}
	w := &einoStreamingShellWrap{
		inner:                   shell,
		shellNoOutputTimeoutSec: 1,
	}

	began := time.Now()
	stream, err := w.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
	if err != nil {
		t.Fatalf("ExecuteStreaming: %v", err)
	}
	defer stream.Close()

	var collected strings.Builder
	for {
		chunk, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			break
		}
		if recvErr != nil {
			t.Fatalf("recv: %v", recvErr)
		}
		if chunk == nil {
			continue
		}
		collected.WriteString(chunk.Output)
	}

	// Generous upper bound: the limit is 1s, anything past 5s means it never fired.
	if time.Since(began) > 5*time.Second {
		t.Fatalf("expected inactivity timeout ~1s, took %v", time.Since(began))
	}
	text := collected.String()
	if !(strings.Contains(text, "没有新的输出") || strings.Contains(text, "no new output")) {
		t.Fatalf("expected inactivity message, got: %q", text)
	}
}
// mockStreamingShellHanging is a test double whose stream never produces any
// output; it is closed only once ctx is done.
type mockStreamingShellHanging struct {
	called bool // set as soon as ExecuteStreaming is invoked
}

// ExecuteStreaming marks the mock as called and returns a pipe whose writer
// is closed by a goroutine after ctx.Done() fires.
func (m *mockStreamingShellHanging) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
	m.called = true
	reader, writer := schema.Pipe[*filesystem.ExecuteResponse](4)
	go func() {
		defer writer.Close()
		<-ctx.Done()
	}()
	return reader, nil
}
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) { func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel() defer cancel()
@@ -63,10 +63,43 @@ func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName
return map[string]interface{}{} return map[string]interface{}{}
} }
// beginEinoADKFilesystemToolMonitor opens a local tool execution on the agent
// when a built-in Eino ADK filesystem tool call starts, then reports the new
// execution id together with the tool call id to rec and, when binder is
// non-nil, to binder.Bind.
//
// It does nothing when ag or rec is nil, when the trimmed tool name is empty,
// equals "execute" (case-insensitive) or is not a built-in filesystem tool,
// when the trimmed tool call id is empty, or when the agent returns an empty
// execution id.
func beginEinoADKFilesystemToolMonitor(
	ag *agent.Agent,
	rec einomcp.ExecutionRecorder,
	binder *MCPExecutionBinder,
	toolCallID, toolName string,
) {
	if ag == nil || rec == nil {
		return
	}

	tool := strings.TrimSpace(toolName)
	switch {
	case tool == "", strings.EqualFold(tool, "execute"):
		return
	case !isBuiltinEinoADKFilesystemToolName(tool):
		return
	}

	callID := strings.TrimSpace(toolCallID)
	if callID == "" {
		return
	}

	execID := ag.BeginLocalToolExecution("eino_fs::"+strings.ToLower(tool), map[string]interface{}{})
	if execID == "" {
		return
	}
	rec(execID, callID)
	if binder != nil {
		binder.Bind(callID, execID)
	}
}
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。 // recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
func recordEinoADKFilesystemToolMonitor( func recordEinoADKFilesystemToolMonitor(
ag *agent.Agent, ag *agent.Agent,
rec einomcp.ExecutionRecorder, rec einomcp.ExecutionRecorder,
binder *MCPExecutionBinder,
toolName string, toolName string,
toolCallID string, toolCallID string,
msgs []adk.Message, msgs []adk.Message,
@@ -94,8 +127,12 @@ func recordEinoADKFilesystemToolMonitor(
invErr = errors.New(t) invErr = errors.New(t)
} }
} }
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr) execID := ""
if id != "" { if binder != nil {
execID = binder.ExecutionID(toolCallID)
}
id := ag.FinishLocalToolExecution(execID, storedName, args, resultText, invErr)
if id != "" && execID == "" {
rec(id, toolCallID) rec(id, toolCallID)
} }
} }
+2 -2
View File
@@ -81,7 +81,7 @@ func RunEinoSingleChatModelAgent(
} }
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder() toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder) einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
mainDefs := ag.ToolsForRole(roleTools) mainDefs := ag.ToolsForRole(roleTools)
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName) mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
if err != nil { if err != nil {
@@ -136,7 +136,7 @@ func RunEinoSingleChatModelAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil) fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr) return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
} }
+23 -3
View File
@@ -9,6 +9,7 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/security"
localbk "github.com/cloudwego/eino-ext/adk/backend/local" localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -81,8 +82,10 @@ func subAgentFilesystemMiddleware(
loc *localbk.Local, loc *localbk.Local,
invokeNotify *einomcp.ToolInvokeNotifyHolder, invokeNotify *einomcp.ToolInvokeNotifyHolder,
einoAgentName string, einoAgentName string,
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error), beginMonitor func(toolCallID, command string) string,
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
toolTimeoutMinutes int, toolTimeoutMinutes int,
shellNoOutputTimeoutSec int,
outputChunk func(toolName, toolCallID, chunk string), outputChunk func(toolName, toolCallID, chunk string),
) (adk.ChatModelAgentMiddleware, error) { ) (adk.ChatModelAgentMiddleware, error) {
if loc == nil { if loc == nil {
@@ -91,12 +94,14 @@ func subAgentFilesystemMiddleware(
return filesystem.New(ctx, &filesystem.MiddlewareConfig{ return filesystem.New(ctx, &filesystem.MiddlewareConfig{
Backend: loc, Backend: loc,
StreamingShell: &einoStreamingShellWrap{ StreamingShell: &einoStreamingShellWrap{
inner: loc, inner: security.NewEinoStreamingShell(),
invokeNotify: invokeNotify, invokeNotify: invokeNotify,
einoAgentName: strings.TrimSpace(einoAgentName), einoAgentName: strings.TrimSpace(einoAgentName),
outputChunk: outputChunk, outputChunk: outputChunk,
recordMonitor: recordMonitor, beginMonitor: beginMonitor,
finishMonitor: finishMonitor,
toolTimeoutMinutes: toolTimeoutMinutes, toolTimeoutMinutes: toolTimeoutMinutes,
shellNoOutputTimeoutSec: shellNoOutputTimeoutSec,
}, },
}) })
} }
@@ -108,3 +113,18 @@ func agentToolTimeoutMinutes(cfg *config.Config) int {
} }
return cfg.Agent.ToolTimeoutMinutes return cfg.Agent.ToolTimeoutMinutes
} }
// agentShellNoOutputTimeoutSeconds resolves the shell no-output timeout from
// cfg.Agent.ShellNoOutputTimeoutSeconds:
//   - nil cfg or 0 -> 300 (the default, 5 minutes)
//   - negative     -> 0 (the original comment describes -1 as "disabled")
//   - positive     -> the configured number of seconds
func agentShellNoOutputTimeoutSeconds(cfg *config.Config) int {
	const fallbackSec = 300
	if cfg == nil {
		return fallbackSec
	}
	switch sec := cfg.Agent.ShellNoOutputTimeoutSeconds; {
	case sec < 0:
		return 0
	case sec == 0:
		return fallbackSec
	default:
		return sec
	}
}
@@ -46,6 +46,10 @@ func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, too
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n") sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
sb.WriteString("3) 不要臆造不存在的工具名。\n\n") sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
} }
if s := strings.TrimSpace(injectShellToolGuidance("", names)); s != "" {
sb.WriteString(s)
sb.WriteString("\n\n")
}
if s := strings.TrimSpace(instruction); s != "" { if s := strings.TrimSpace(instruction); s != "" {
sb.WriteString(s) sb.WriteString(s)
} }
+1 -1
View File
@@ -143,7 +143,7 @@ func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts } func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。 // reset 在退避重试后成功推进(流/消息完整接收)时清零计数,使后续临时错误从第 1 次退避重新开始。
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 } func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int { func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
@@ -105,6 +105,32 @@ func TestEinoTransientRunRetrierReset(t *testing.T) {
} }
} }
// TestEinoTransientRunRetrierConsecutiveFailures feeds three consecutive
// failures to tryRetry and checks that each one restarts the run and bumps
// the attempt counter by one, then checks that reset() brings it back to 0.
func TestEinoTransientRunRetrierConsecutiveFailures(t *testing.T) {
	t.Parallel()

	retrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
	var (
		ctx      = context.Background()
		failure  = errors.New("internal server error")
		loopArgs = &einoADKRunLoopArgs{}
		history  = []adk.Message{schema.UserMessage("hi")}
	)

	for n := 1; n <= 3; n++ {
		restarted, _, _, _, err := retrier.tryRetry(ctx, failure, loopArgs, history, nil, len(history))
		if err != nil {
			t.Fatalf("tryRetry attempt %d: %v", n, err)
		}
		if !restarted {
			t.Fatalf("tryRetry attempt %d: want restarted", n)
		}
		if got := retrier.attempt(); got != n {
			t.Fatalf("after failure %d: attempt=%d, want %d", n, got, n)
		}
	}

	retrier.reset()
	if retrier.attempt() != 0 {
		t.Fatalf("after successful recovery reset: attempt=%d, want 0", retrier.attempt())
	}
}
func TestAppendUserMessageIfNeeded(t *testing.T) { func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Parallel() t.Parallel()
msgs := []adk.Message{schema.UserMessage("old task")} msgs := []adk.Message{schema.UserMessage("old task")}
+15
View File
@@ -0,0 +1,15 @@
package multiagent
import "fmt"
// ExecuteExitError reports that an execute command finished with a non-zero
// exit code. Per the original comment this is an expected failure, as opposed
// to a timeout, an interruption or a stream error.
type ExecuteExitError struct {
	Code int // exit code reported for the command
}

// Error renders the error as "exit status <Code>"; a nil receiver yields
// "exit status unknown" instead of panicking.
func (e *ExecuteExitError) Error() string {
	if e != nil {
		return fmt.Sprintf("exit status %d", e.Code)
	}
	return "exit status unknown"
}
@@ -6,6 +6,7 @@ import (
"cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/project"
"cyberstrike-ai/internal/projectprompt"
) )
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。 // DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
@@ -122,7 +123,9 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
## 表达 ## 表达
在调用工具或给出计划变更前 25 句中文说明当前决策依据与期望证据形态最终对用户交付结构化结论发现摘要证据风险下一步` 在调用工具或给出计划变更前 25 句中文说明当前决策依据与期望证据形态最终对用户交付结构化结论发现摘要证据风险下一步
` + projectprompt.ShellExecExecuteGuidanceSection()
} }
// DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。 // DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。
+8 -5
View File
@@ -20,6 +20,7 @@ import (
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/project" "cyberstrike-ai/internal/project"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
"cyberstrike-ai/internal/security"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai" einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -120,7 +121,7 @@ func RunDeepAgent(
mcpIDs = append(mcpIDs, id) mcpIDs = append(mcpIDs, id)
mcpIDsMu.Unlock() mcpIDsMu.Unlock()
} }
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder) einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。 // 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
snapshotMCPIDs := func() []string { snapshotMCPIDs := func() []string {
@@ -223,7 +224,7 @@ func RunDeepAgent(
} }
if einoSkillMW != nil { if einoSkillMW != nil {
if einoFSTools && einoLoc != nil { if einoFSTools && einoLoc != nil {
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil) subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
if fsErr != nil { if fsErr != nil {
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr) return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
} }
@@ -358,12 +359,14 @@ func RunDeepAgent(
if einoLoc != nil && einoFSTools { if einoLoc != nil && einoFSTools {
deepBackend = einoLoc deepBackend = einoLoc
deepShell = &einoStreamingShellWrap{ deepShell = &einoStreamingShellWrap{
inner: einoLoc, inner: security.NewEinoStreamingShell(),
invokeNotify: toolInvokeNotify, invokeNotify: toolInvokeNotify,
einoAgentName: orchestratorName, einoAgentName: orchestratorName,
outputChunk: nil, outputChunk: nil,
recordMonitor: einoExecMonitor, beginMonitor: einoExecBegin,
finishMonitor: einoExecFinish,
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg), toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
shellNoOutputTimeoutSec: agentShellNoOutputTimeoutSeconds(appCfg),
} }
} }
@@ -428,7 +431,7 @@ func RunDeepAgent(
// 构建 filesystem 中间件(与 Deep sub-agent 一致) // 构建 filesystem 中间件(与 Deep sub-agent 一致)
var peFsMw adk.ChatModelAgentMiddleware var peFsMw adk.ChatModelAgentMiddleware
if einoSkillMW != nil && einoFSTools && einoLoc != nil { if einoSkillMW != nil && einoFSTools && einoLoc != nil {
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil) peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err) return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
} }
@@ -0,0 +1,33 @@
package multiagent
import (
"strings"
"cyberstrike-ai/internal/projectprompt"
)
// shellToolsPresent reports whether toolNames contains "exec" or "execute",
// compared case-insensitively after trimming surrounding whitespace.
func shellToolsPresent(toolNames []string) bool {
	for _, raw := range toolNames {
		name := strings.ToLower(strings.TrimSpace(raw))
		if name == "exec" || name == "execute" {
			return true
		}
	}
	return false
}
// injectShellToolGuidance appends the exec/execute guidance section to the
// end of a system prompt, but only when toolNames includes exec or execute.
//
// The instruction is returned untouched when no shell tool is present or the
// guidance section trims to empty. Otherwise the result is the trimmed
// instruction, a blank line, and the trimmed guidance; if the instruction
// trims to empty, the guidance alone is returned.
func injectShellToolGuidance(instruction string, toolNames []string) string {
	if !shellToolsPresent(toolNames) {
		return instruction
	}
	guidance := strings.TrimSpace(projectprompt.ShellExecExecuteGuidanceSection())
	base := strings.TrimSpace(instruction)
	switch {
	case guidance == "":
		return instruction
	case base == "":
		return guidance
	default:
		return base + "\n\n" + guidance
	}
}
@@ -0,0 +1,17 @@
package multiagent
import (
"strings"
"testing"
)
// TestInjectShellToolGuidance checks that the instruction is left unchanged
// without a shell tool, and that the guidance is added while the original
// instruction is kept when "exec" is in the tool list.
func TestInjectShellToolGuidance(t *testing.T) {
	if out := injectShellToolGuidance("base", []string{"nmap"}); out != "base" {
		t.Fatalf("expected unchanged, got %q", out)
	}

	out := injectShellToolGuidance("base", []string{"exec", "nmap"})
	hasGuidance := strings.Contains(out, "exec/execute")
	hasBase := strings.Contains(out, "base")
	if !hasGuidance || !hasBase {
		t.Fatalf("expected shell guidance appended, got %q", out)
	}
}
+11
View File
@@ -0,0 +1,11 @@
package projectprompt
// ShellExecExecuteGuidanceSection returns the short guidance text that
// single-agent and multi-agent system prompts append to describe how work is
// split between the exec and execute tools (kept deliberately brief).
func ShellExecExecuteGuidanceSection() string {
	return `Shellexec/execute):有专用 MCP 工具时优先专用工具;系统命令(管道、workdir、后台 &)用 execskills/ 内脚本(配合 read_file、skill)用 execute;多步扫描分拆调用,禁止一条 shell 串多个扫描器。`
}
// ShellExecExecuteGuidanceReconSuffix returns an optional one-line addition
// for the recon sub-agent prompt: prefer dedicated enumeration tools over
// long exec/execute command chains.
func ShellExecExecuteGuidanceReconSuffix() string {
	return `枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。`
}
@@ -0,0 +1,56 @@
package security
import (
"errors"
"fmt"
"os/exec"
"strings"
)
// FormatCommandFailureResult builds the failure text for a command that
// exited with exitCode. Per the original comment it matches the exec tool's
// ToolResult wording and does not include ToolErrorPrefix.
//
// The output is trimmed first. Empty output yields only the status line;
// output that already starts with the failure marker is returned as-is so it
// is never wrapped twice; anything else is appended after the status line.
func FormatCommandFailureResult(exitCode int, output string) string {
	status := fmt.Sprintf("exit status %d", exitCode)
	body := strings.TrimSpace(output)
	switch {
	case body == "":
		return fmt.Sprintf("命令执行失败: %s", status)
	case strings.HasPrefix(body, "命令执行失败:"):
		return body
	default:
		return fmt.Sprintf("命令执行失败: %s\n输出: %s", status, body)
	}
}
// FormatCommandFailureFromErr turns the error returned by exec/execute into
// the unified failure text (the original comment calls it the IsError body).
//
// A nil err returns the trimmed output. An *exec.ExitError anywhere in the
// chain is delegated to FormatCommandFailureResult with its exit code. Any
// other error is rendered with %v, followed by the trimmed output unless that
// output is empty or already starts with the failure marker.
func FormatCommandFailureFromErr(err error, output string) string {
	if err == nil {
		return strings.TrimSpace(output)
	}

	var exitErr *exec.ExitError
	if errors.As(err, &exitErr) {
		return FormatCommandFailureResult(exitErr.ExitCode(), output)
	}

	body := strings.TrimSpace(output)
	switch {
	case body == "":
		return fmt.Sprintf("命令执行失败: %v", err)
	case strings.HasPrefix(body, "命令执行失败:"):
		return body
	default:
		return fmt.Sprintf("命令执行失败: %v\n输出: %s", err, body)
	}
}
// ExecuteFailureStatusLine returns the single status line appended when a
// streaming execute ends with a non-zero exit code. Per the original comment,
// the command output itself has already been pushed through the stream. The
// line starts with a newline.
func ExecuteFailureStatusLine(exitCode int) string {
	const statusFormat = "\n命令执行失败: exit status %d"
	return fmt.Sprintf(statusFormat, exitCode)
}
// IsCommandFailureResult reports whether a tool result body contains the
// command-failure marker anywhere in its text. Per the original comment it is
// used to align isError between execute and exec.
func IsCommandFailureResult(content string) bool {
	const failureMarker = "命令执行失败:"
	return strings.Index(content, failureMarker) >= 0
}
// IsLegacyShellExitNoise reports whether s, after trimming whitespace, starts
// with the redundant "command exited with non-zero code " line that the
// original comment attributes to the legacy shell stream.
func IsLegacyShellExitNoise(s string) bool {
	const legacyPrefix = "command exited with non-zero code "
	return strings.HasPrefix(strings.TrimSpace(s), legacyPrefix)
}
@@ -0,0 +1,54 @@
package security
import (
"errors"
"os/exec"
"strings"
"testing"
)
// TestFormatCommandFailureResult covers the three output shapes: output
// appended after the status line, empty output, and already-formatted output
// that must not be wrapped a second time.
func TestFormatCommandFailureResult(t *testing.T) {
	const (
		wrapped    = "命令执行失败: exit status 1\n输出: sudo: password required"
		statusOnly = "命令执行失败: exit status 2"
		preformed  = "命令执行失败: exit status 1"
	)

	if got, want := FormatCommandFailureResult(1, "sudo: password required"), wrapped; got != want {
		t.Fatalf("got %q want %q", got, want)
	}
	if FormatCommandFailureResult(2, "") != statusOnly {
		t.Fatal("empty output format")
	}
	if FormatCommandFailureResult(1, preformed) != preformed {
		t.Fatal("should not double-wrap")
	}
}
// TestIsCommandFailureResult checks that the failure marker is detected when
// it follows other output, and that plain error output is not flagged.
func TestIsCommandFailureResult(t *testing.T) {
	withMarker := "sudo: err\n命令执行失败: exit status 1"
	if detected := IsCommandFailureResult(withMarker); !detected {
		t.Fatal("expected true")
	}

	withoutMarker := "sudo: err only"
	if detected := IsCommandFailureResult(withoutMarker); detected {
		t.Fatal("expected false")
	}
}
// TestFormatCommandFailureFromErr checks both error paths: a real
// *exec.ExitError produced by `sh -c "exit 42"`, and a plain error whose
// message and accompanying output must both appear in the result.
func TestFormatCommandFailureFromErr(t *testing.T) {
	runErr := exec.Command("sh", "-c", "exit 42").Run()
	if text := FormatCommandFailureFromErr(runErr, "oops"); text != "命令执行失败: exit status 42\n输出: oops" {
		t.Fatalf("got %q", text)
	}

	idleErr := errors.New("shell inactivity timeout (300s)")
	text := FormatCommandFailureFromErr(idleErr, "already timed out")
	for _, fragment := range []string{"shell inactivity timeout", "already timed out"} {
		if !strings.Contains(text, fragment) {
			t.Fatalf("got %q", text)
		}
	}
}
// TestIsLegacyShellExitNoise checks that the legacy exit-code line is
// recognised and that ordinary command output is not.
func TestIsLegacyShellExitNoise(t *testing.T) {
	legacyLine := "command exited with non-zero code 1\n"
	if noisy := IsLegacyShellExitNoise(legacyLine); !noisy {
		t.Fatal("expected legacy noise")
	}

	realOutput := "sudo: failed"
	if noisy := IsLegacyShellExitNoise(realOutput); noisy {
		t.Fatal("unexpected noise")
	}
}
+52 -15
View File
@@ -36,6 +36,7 @@ type Executor struct {
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找 toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
mcpServer *mcp.Server mcpServer *mcp.Server
logger *zap.Logger logger *zap.Logger
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300-1=关闭(见 SetShellNoOutputTimeoutSeconds
} }
// NewExecutor 创建新的执行器 // NewExecutor 创建新的执行器
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
return executor return executor
} }
// SetShellNoOutputTimeoutSeconds stores sec as the executor's no-output idle
// limit for the exec tool. Per the original comment it mirrors the
// agent.shell_no_output_timeout_seconds setting; the raw value is kept as
// given, with no validation here.
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
	e.shellNoOutputTimeoutSec = sec
}
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1) // buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
func (e *Executor) buildToolIndex() { func (e *Executor) buildToolIndex() {
e.toolIndex = make(map[string]*config.ToolConfig) e.toolIndex = make(map[string]*config.ToolConfig)
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
// 执行命令 // 执行命令
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...) cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
applyDefaultTerminalEnv(cmd) applyDefaultTerminalEnv(cmd)
attachNonInteractiveStdin(cmd)
_ = prepareShellCmdSession(cmd) _ = prepareShellCmdSession(cmd)
e.logger.Info("执行安全工具", e.logger.Info("执行安全工具",
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
var err error var err error
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。 // 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
output, err = streamCommandOutput(ctx, cmd, cb) output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
if err != nil && shouldRetryWithPTY(output) { if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试", e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
zap.String("tool", toolName), zap.String("tool", toolName),
@@ -797,6 +804,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
zap.String("command", command), zap.String("command", command),
) )
command = PrepareNonInteractiveShellCommand(command)
// 获取shell类型(可选,默认为sh) // 获取shell类型(可选,默认为sh)
shell := "sh" shell := "sh"
if s, ok := args["shell"].(string); ok && s != "" { if s, ok := args["shell"].(string); ok && s != "" {
@@ -820,8 +829,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
} else { } else {
cmd = exec.CommandContext(ctx, shell, "-c", command) cmd = exec.CommandContext(ctx, shell, "-c", command)
} }
applyDefaultTerminalEnv(cmd) ConfigureShellCmdForAgentExecute(cmd)
_ = prepareShellCmdSession(cmd)
// 执行命令 // 执行命令
e.logger.Info("执行系统命令", e.logger.Info("执行系统命令",
@@ -850,8 +858,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
} else { } else {
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand) pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
} }
applyDefaultTerminalEnv(pidCmd) ConfigureShellCmdForAgentExecute(pidCmd)
_ = prepareShellCmdSession(pidCmd)
// 获取stdout管道 // 获取stdout管道
stdout, err := pidCmd.StdoutPipe() stdout, err := pidCmd.StdoutPipe()
@@ -963,15 +970,14 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
var err error var err error
// 若上层提供工具输出增量回调,则边执行边流式读取。 // 若上层提供工具输出增量回调,则边执行边流式读取。
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
output, err = streamCommandOutput(ctx, cmd, cb) output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
if err != nil && shouldRetryWithPTY(output) { if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试") e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
cmd2 := exec.CommandContext(ctx, shell, "-c", command) cmd2 := exec.CommandContext(ctx, shell, "-c", command)
if workDir != "" { if workDir != "" {
cmd2.Dir = workDir cmd2.Dir = workDir
} }
applyDefaultTerminalEnv(cmd2) ConfigureShellCmdForAgentExecute(cmd2)
_ = prepareShellCmdSession(cmd2)
output, err = runCommandWithPTY(ctx, cmd2, cb) output, err = runCommandWithPTY(ctx, cmd2, cb)
} }
} else { } else {
@@ -984,8 +990,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
if workDir != "" { if workDir != "" {
cmd2.Dir = workDir cmd2.Dir = workDir
} }
applyDefaultTerminalEnv(cmd2) ConfigureShellCmdForAgentExecute(cmd2)
_ = prepareShellCmdSession(cmd2)
output, err = runCommandWithPTY(ctx, cmd2, nil) output, err = runCommandWithPTY(ctx, cmd2, nil)
} }
} }
@@ -999,7 +1004,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
Content: []mcp.Content{ Content: []mcp.Content{
{ {
Type: "text", Type: "text",
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)), Text: FormatCommandFailureFromErr(err, output),
}, },
}, },
IsError: true, IsError: true,
@@ -1024,7 +1029,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。 // streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。 // 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) { func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
if err := prepareShellCmdSession(cmd); err != nil { if err := prepareShellCmdSession(cmd); err != nil {
return "", err return "", err
} }
@@ -1091,14 +1096,45 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
lastFlush = time.Now() lastFlush = time.Now()
} }
for chunk := range chunks { idleWatch := NewShellInactivityWatch(noOutputSec)
if idleWatch != nil {
defer idleWatch.Stop()
}
fireInactivity := func() {
terminateCmdTree(cmd)
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
outBuilder.WriteString(msg)
if cb != nil {
cb(msg)
}
_ = cmd.Wait()
}
chunksLoop:
for {
var idleCh <-chan struct{}
if idleWatch != nil {
idleCh = idleWatch.Expired
}
select {
case <-idleCh:
fireInactivity()
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
case chunk, ok := <-chunks:
if !ok {
break chunksLoop
}
if chunk != "" && idleWatch != nil {
idleWatch.Bump()
}
outBuilder.WriteString(chunk) outBuilder.WriteString(chunk)
deltaBuilder.WriteString(chunk) deltaBuilder.WriteString(chunk)
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond { if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
flush() flush()
} }
} }
}
flush() flush()
// 等待命令结束,返回最终退出状态 // 等待命令结束,返回最终退出状态
@@ -1116,6 +1152,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
if cmd.Env == nil { if cmd.Env == nil {
cmd.Env = os.Environ() cmd.Env = os.Environ()
} }
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖 // 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
has := func(k string) bool { has := func(k string) bool {
prefix := k + "=" prefix := k + "="
@@ -1159,7 +1196,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
// PTY 方案为类 UnixWindows 走原逻辑 // PTY 方案为类 UnixWindows 走原逻辑
if cb != nil { if cb != nil {
return streamCommandOutput(ctx, cmd, cb) return streamCommandOutput(ctx, cmd, cb, 0)
} }
_ = prepareShellCmdSession(cmd) _ = prepareShellCmdSession(cmd)
out, err := cmd.CombinedOutput() out, err := cmd.CombinedOutput()
+21
View File
@@ -71,6 +71,27 @@ func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T)
} }
} }
// TestExecuteSystemCommand_FailureFormat runs a shell command that writes to stderr and
// exits with status 7, then checks that the result is flagged as an error and that its
// text matches FormatCommandFailureResult (with or without the trailing newline).
func TestExecuteSystemCommand_FailureFormat(t *testing.T) {
	executor, _ := setupTestExecutor(t)
	res, err := executor.executeSystemCommand(context.Background(), map[string]interface{}{
		"command": "echo fail-msg >&2; exit 7",
		"shell":   "sh",
	})
	if err != nil {
		t.Fatalf("executeSystemCommand: %v", err)
	}
	if res == nil || !res.IsError {
		t.Fatalf("expected IsError, got %+v", res)
	}
	// Guard the index below: an empty content list should fail the test, not panic it.
	if len(res.Content) == 0 {
		t.Fatalf("expected at least one content item, got %+v", res)
	}
	text := res.Content[0].Text
	if text != FormatCommandFailureResult(7, "fail-msg\n") && text != FormatCommandFailureResult(7, "fail-msg") {
		t.Fatalf("unexpected failure text: %q", text)
	}
	if !strings.Contains(text, "exit status 7") || !strings.Contains(text, "fail-msg") {
		t.Fatalf("unexpected failure text: %q", text)
	}
}
func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) { func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
pos1 := 1 pos1 := 1
executor, _ := setupTestExecutor(t) executor, _ := setupTestExecutor(t)
+200
View File
@@ -0,0 +1,200 @@
package security
import (
"context"
"errors"
"fmt"
"io"
"os/exec"
"sync"
"github.com/cloudwego/eino/adk/filesystem"
"github.com/cloudwego/eino/schema"
)
// ConfigureShellCmdForAgentExecute prepares cmd the same way the exec tool does. Per the
// original note this means: non-interactive stdin, pager/TERM environment defaults, and a
// separate process group. A nil cmd is ignored.
func ConfigureShellCmdForAgentExecute(cmd *exec.Cmd) {
	if cmd != nil {
		applyDefaultTerminalEnv(cmd)
		attachNonInteractiveStdin(cmd)
		// Session setup is best-effort here; its error is intentionally discarded.
		_ = prepareShellCmdSession(cmd)
	}
}
// TerminateShellCmdTree is the exported entry point for terminateCmdTree. Per the original
// note it makes a best-effort attempt to terminate the shell together with its child
// process group, matching what exec/execute do on timeout or cancellation.
func TerminateShellCmdTree(cmd *exec.Cmd) {
terminateCmdTree(cmd)
}
// EinoStreamingShell provides a streaming shell for the Eino ADK execute tool, aligned
// with the behavior of the exec tool: stdout and stderr are read concurrently in
// fixed-size chunks (not line by line). Per the original note, this avoids the upstream
// local.ExecuteStreaming behavior of draining stdout first, which kept stderr errors
// (for example a sudo password prompt) invisible for a long time while the UI kept
// showing the command as "running".
type EinoStreamingShell struct{}
// NewEinoStreamingShell returns a new streaming shell implementation for the execute tool.
func NewEinoStreamingShell() *EinoStreamingShell {
return &EinoStreamingShell{}
}
// ExecuteStreaming implements filesystem.StreamingShell.
//
// It rejects a nil request or an empty command with an error. Otherwise it creates a
// stream pipe with a capacity of 100, starts a goroutine that runs the command either in
// background mode (input.RunInBackendGround) or in streaming foreground mode, and returns
// the reader side immediately.
func (s *EinoStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
	if input == nil || input.Command == "" {
		return nil, fmt.Errorf("command is required")
	}
	reader, writer := schema.Pipe[*filesystem.ExecuteResponse](100)
	// Both runners share one signature, so pick the function first and launch it once.
	runner := streamShellForeground
	if input.RunInBackendGround {
		runner = runShellInBackground
	}
	go runner(ctx, input.Command, writer)
	return reader, nil
}
// runShellInBackground starts command under /bin/sh, reports "command started in
// background" on w as soon as the process has been started, and closes w.
//
// The process keeps running after this function returns. A supervisor goroutine discards
// its stdout/stderr (so the child never blocks on a full pipe), reaps it with cmd.Wait,
// and terminates the whole process tree if ctx is cancelled first. Previously the
// acknowledgement was only sent after the command had finished (or ctx was cancelled),
// which made a "background" run block the caller for the command's full duration.
//
// NOTE(review): the command is created with exec.CommandContext(ctx, ...), so the
// background process lives only as long as ctx does — confirm callers pass a context that
// outlives the single tool call.
func runShellInBackground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
	defer w.Close()

	cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
	ConfigureShellCmdForAgentExecute(cmd)

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		_ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err))
		return
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		_ = stdout.Close()
		_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
		return
	}
	if err := cmd.Start(); err != nil {
		_ = stdout.Close()
		_ = stderr.Close()
		_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
		return
	}

	// Supervise the process without blocking the acknowledgement below.
	go func() {
		done := make(chan struct{})
		go func() {
			// Pipes must be fully read before Wait, which closes them.
			drainShellPipes(stdout, stderr)
			_ = cmd.Wait()
			close(done)
		}()
		select {
		case <-done:
		case <-ctx.Done():
			TerminateShellCmdTree(cmd)
		}
	}()

	exitCode := 0
	_ = w.Send(&filesystem.ExecuteResponse{
		Output:   "command started in background\n",
		ExitCode: &exitCode,
	}, nil)
}
func drainShellPipes(stdout, stderr io.Reader) {
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(io.Discard, stdout)
}()
go func() {
defer wg.Done()
_, _ = io.Copy(io.Discard, stderr)
}()
wg.Wait()
}
// streamShellForeground runs command under /bin/sh and streams its output to w, then
// closes w.
//
// stdout and stderr are read concurrently in chunks of up to 8192 bytes and forwarded in
// arrival order, so stderr text is visible even while stdout is still open. When the
// command ends, a final response carrying the exit code is sent; if the command failed
// without producing any output, that response also carries FormatCommandFailureResult.
// Errors other than a non-zero exit are sent as a stream error.
//
// Cancelling ctx terminates the whole process tree. If the consumer closes the stream
// early (w.Send reports closed), the process tree is terminated, the remaining output is
// drained so the reader goroutines can exit, and the process is reaped with cmd.Wait —
// the earlier version returned without doing either, leaving an unreaped child and
// reader goroutines that could block forever on a full channel.
func streamShellForeground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
	defer w.Close()

	cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
	ConfigureShellCmdForAgentExecute(cmd)

	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		_ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err))
		return
	}
	stderrPipe, err := cmd.StderrPipe()
	if err != nil {
		_ = stdoutPipe.Close()
		_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
		return
	}
	if err := cmd.Start(); err != nil {
		_ = stdoutPipe.Close()
		_ = stderrPipe.Close()
		_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
		return
	}

	// Kill the process tree on cancellation; stopWatch releases this goroutine on return.
	stopWatch := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			TerminateShellCmdTree(cmd)
		case <-stopWatch:
		}
	}()
	defer close(stopWatch)

	chunks := make(chan string, 64)
	var wg sync.WaitGroup
	readFn := func(r io.Reader) {
		defer wg.Done()
		buf := make([]byte, 8192)
		for {
			n, readErr := r.Read(buf)
			if n > 0 {
				chunks <- string(buf[:n])
			}
			if readErr != nil {
				return
			}
		}
	}
	wg.Add(2)
	go readFn(stdoutPipe)
	go readFn(stderrPipe)
	// Close chunks only after both readers are done so the loop below can terminate.
	go func() {
		wg.Wait()
		close(chunks)
	}()

	hadOutput := false
	for chunk := range chunks {
		if chunk == "" {
			continue
		}
		hadOutput = true
		if w.Send(&filesystem.ExecuteResponse{Output: chunk}, nil) {
			// The consumer is gone: stop the command, unblock the readers by draining
			// whatever they still produce, then reap the process.
			TerminateShellCmdTree(cmd)
			for range chunks {
			}
			_ = cmd.Wait()
			return
		}
	}

	waitErr := cmd.Wait()
	if waitErr == nil {
		exitCode := 0
		_ = w.Send(&filesystem.ExecuteResponse{ExitCode: &exitCode}, nil)
		return
	}
	var exitError *exec.ExitError
	if errors.As(waitErr, &exitError) {
		exitCode := exitError.ExitCode()
		resp := &filesystem.ExecuteResponse{ExitCode: &exitCode}
		if !hadOutput {
			// Give the model something to read when the command failed silently.
			resp.Output = FormatCommandFailureResult(exitCode, "")
		}
		_ = w.Send(resp, nil)
		return
	}
	_ = w.Send(nil, fmt.Errorf("command failed: %w", waitErr))
}
@@ -0,0 +1,117 @@
package security
import (
"context"
"errors"
"io"
"strings"
"testing"
"time"
"github.com/cloudwego/eino/adk/filesystem"
)
// TestEinoStreamingShell_StreamsStderrBeforeStdoutEOF runs a command that only writes to
// stderr and exits with status 1, and checks that the stderr text arrives on the stream
// and that the whole run completes within 3 seconds.
func TestEinoStreamingShell_StreamsStderrBeforeStdoutEOF(t *testing.T) {
shell := NewEinoStreamingShell()
cmd := PrepareNonInteractiveShellCommand("echo err-only >&2; exit 1")
sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
start := time.Now()
var got strings.Builder
// Collect every output chunk until the stream reports EOF.
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("recv: %v", rerr)
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
}
}
if time.Since(start) > 3*time.Second {
t.Fatalf("expected fast completion, took %v", time.Since(start))
}
if !strings.Contains(got.String(), "err-only") {
t.Fatalf("expected stderr in output, got: %q", got.String())
}
}
// TestEinoStreamingShell_SudoFailsFast runs two chained sudo commands through the
// non-interactive wrapper and checks that the stream ends within 5 seconds, that the
// legacy "command exited with non-zero code" line is absent, and that the output mentions
// "sudo", "password" or "terminal".
// NOTE(review): the last assertion presumably expects sudo to refuse to run without a
// password/TTY; on a host with passwordless sudo or without sudo the output may differ —
// confirm the intended test environment.
func TestEinoStreamingShell_SudoFailsFast(t *testing.T) {
shell := NewEinoStreamingShell()
cmd := PrepareNonInteractiveShellCommand("sudo whoami && sudo cat /etc/os-release")
sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
start := time.Now()
var got strings.Builder
// Collect every output chunk until the stream reports EOF.
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("recv: %v", rerr)
}
if resp == nil {
continue
}
got.WriteString(resp.Output)
}
if time.Since(start) > 5*time.Second {
t.Fatalf("sudo should fail quickly, took %v output=%q", time.Since(start), got.String())
}
out := got.String()
if strings.Contains(out, "command exited with non-zero code") {
t.Fatalf("legacy exit line present: %q", out)
}
if !strings.Contains(out, "sudo") && !strings.Contains(out, "password") && !strings.Contains(out, "terminal") {
t.Fatalf("expected sudo error text, got: %q", out)
}
}
// TestEinoStreamingShell_StderrWhileStdoutBlocks checks that stderr text is delivered
// promptly (within 1.5 seconds) while the command itself keeps running and stdout stays
// silent. The command is bounded by a 2-second context timeout.
func TestEinoStreamingShell_StderrWhileStdoutBlocks(t *testing.T) {
shell := NewEinoStreamingShell()
// Simulates sudo: stderr produces output first while the process keeps hanging on the
// stdout side. Per the original note, the old eino local shell wrote nothing to the
// stream in this situation.
cmd := PrepareNonInteractiveShellCommand(`echo "password prompt" >&2; sleep 30`)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
sr, err := shell.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: cmd})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
start := time.Now()
var got strings.Builder
// Stop reading as soon as the prompt text has been seen, or on EOF / any stream error.
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
break
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
if strings.Contains(got.String(), "password prompt") {
break
}
}
}
if time.Since(start) > 1500*time.Millisecond {
t.Fatalf("expected stderr promptly, took %v output=%q", time.Since(start), got.String())
}
if !strings.Contains(got.String(), "password prompt") {
t.Fatalf("expected early stderr, got: %q", got.String())
}
}
+163
View File
@@ -0,0 +1,163 @@
package security
import (
"fmt"
"os"
"os/exec"
"strings"
"sync"
"time"
)
// ShellNoOutputTimeoutMessage builds the notice emitted when a command has produced no new
// stdout/stderr for idleSec seconds. Per the original note this is a soft failure whose
// text is visible to the model. The message is bilingual (Chinese, then English) and
// embeds idleSec in both parts.
func ShellNoOutputTimeoutMessage(idleSec int) string {
return fmt.Sprintf(`命令已终止超过 %d 秒没有新的输出疑似在等待交互输入或已挂起
长时静默任务请使用末尾 & 后台运行或增大 agent.shell_no_output_timeout_seconds-1=关闭此检测
Command terminated: no new output for %d seconds (possible interactive wait or hung process).`, idleSec, idleSec)
}
// ShellInactivityWatch signals on Expired when Sec seconds pass without a call to Bump.
// Every Bump restarts the countdown, so — unlike a scheme that stops timing for good
// after the first output — it still catches a command that prints something (for example
// a sudo "Password" prompt) and then hangs.
type ShellInactivityWatch struct {
	// Sec is the idle window in seconds.
	Sec int

	mu    sync.Mutex
	timer *time.Timer
	// gen identifies the currently armed timer. A timer callback only signals when its
	// own generation is still current, so callbacks from replaced or stopped timers are
	// ignored even if they were already running when Bump or Stop was called.
	gen uint64

	// Expired receives at most one pending signal (buffer size 1) when the window elapses.
	Expired chan struct{}
}

// NewShellInactivityWatch resolves noOutputSec with ResolveShellNoOutputTimeoutSeconds and
// returns an armed watch, or nil when the detection is disabled.
func NewShellInactivityWatch(noOutputSec int) *ShellInactivityWatch {
	sec := ResolveShellNoOutputTimeoutSeconds(noOutputSec)
	if sec <= 0 {
		return nil
	}
	w := &ShellInactivityWatch{
		Sec:     sec,
		Expired: make(chan struct{}, 1),
	}
	w.Bump()
	return w
}

// Bump restarts the idle countdown. It is a no-op on a nil watch or when Sec <= 0.
//
// Besides re-arming the timer it discards an expiry that was queued but not yet consumed:
// new activity means that signal is stale, and leaving it in the channel would let a
// consumer act on a timeout that no longer applies.
func (w *ShellInactivityWatch) Bump() {
	if w == nil || w.Sec <= 0 {
		return
	}
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.timer != nil {
		w.timer.Stop()
	}
	// Drop a stale, unconsumed expiry from the previous countdown.
	select {
	case <-w.Expired:
	default:
	}
	w.gen++
	armed := w.gen
	w.timer = time.AfterFunc(time.Duration(w.Sec)*time.Second, func() {
		w.mu.Lock()
		defer w.mu.Unlock()
		// Timer.Stop does not wait for a callback that has already started; the
		// generation check makes such a late callback harmless.
		if armed != w.gen {
			return
		}
		select {
		case w.Expired <- struct{}{}:
		default:
		}
	})
}

// Stop cancels the countdown. It is safe on a nil watch and may be called repeatedly.
// An expiry that was already queued on Expired is left in place.
func (w *ShellInactivityWatch) Stop() {
	if w == nil {
		return
	}
	w.mu.Lock()
	defer w.mu.Unlock()
	// Invalidate any callback that is already in flight.
	w.gen++
	if w.timer != nil {
		w.timer.Stop()
		w.timer = nil
	}
}

// ResolveShellNoOutputTimeoutSeconds maps the configured value to the effective idle
// window in seconds: 0 selects the default of 300 (5 minutes), a negative value (such as
// -1) disables the detection and yields 0, and a positive value is used as given.
func ResolveShellNoOutputTimeoutSeconds(sec int) int {
	if sec < 0 {
		return 0
	}
	if sec == 0 {
		return 300
	}
	return sec
}
// PrependNonInteractiveShellExports 为 sh -c 注入通用非交互环境(pager 等),不维护命令黑名单。
func PrependNonInteractiveShellExports(shellCommand string) string {
if strings.TrimSpace(shellCommand) == "" {
return shellCommand
}
upper := strings.ToUpper(shellCommand)
var pairs []string
add := func(key, val string) {
if strings.Contains(upper, strings.ToUpper(key)) {
return
}
pairs = append(pairs, key+"="+val)
}
add("GIT_PAGER", "cat")
add("PAGER", "cat")
add("SYSTEMD_PAGER", "cat")
add("DEBIAN_FRONTEND", "noninteractive")
if len(pairs) == 0 {
return shellCommand
}
return "export " + strings.Join(pairs, " ") + "\n" + shellCommand
}
// PrependNonInteractiveStdinRedirect 为 sh -c 关闭 stdin(与 attachNonInteractiveStdin 等价),
// 使 read/input()/sudo -S 等从 stdin 读取的程序快速失败而非挂起。已含 </dev/null 时不重复注入。
func PrependNonInteractiveStdinRedirect(shellCommand string) string {
if strings.TrimSpace(shellCommand) == "" {
return shellCommand
}
lower := strings.ToLower(shellCommand)
if strings.Contains(lower, "</dev/null") || strings.Contains(lower, "0</dev/null") {
return shellCommand
}
return "exec </dev/null\n" + shellCommand
}
// PrepareNonInteractiveShellCommand applies both non-interactive wrappers to shellCommand:
// first PrependNonInteractiveShellExports (pager and related environment variables), then
// PrependNonInteractiveStdinRedirect (stdin closed). No command blacklist is involved.
func PrepareNonInteractiveShellCommand(shellCommand string) string {
return PrependNonInteractiveStdinRedirect(PrependNonInteractiveShellExports(shellCommand))
}
// ApplyNonInteractivePagerEnv 为 exec.Cmd 补齐与 PrependNonInteractiveShellExports 一致的环境变量。
func ApplyNonInteractivePagerEnv(cmdEnv []string) []string {
if cmdEnv == nil {
cmdEnv = []string{}
}
has := func(k string) bool {
prefix := k + "="
for _, e := range cmdEnv {
if strings.HasPrefix(e, prefix) {
return true
}
}
return false
}
if !has("GIT_PAGER") {
cmdEnv = append(cmdEnv, "GIT_PAGER=cat")
}
if !has("PAGER") {
cmdEnv = append(cmdEnv, "PAGER=cat")
}
if !has("SYSTEMD_PAGER") {
cmdEnv = append(cmdEnv, "SYSTEMD_PAGER=cat")
}
if !has("DEBIAN_FRONTEND") {
cmdEnv = append(cmdEnv, "DEBIAN_FRONTEND=noninteractive")
}
return cmdEnv
}
// attachNonInteractiveStdin 关闭交互式 stdin,使部分命令快速失败而非等待输入。
func attachNonInteractiveStdin(cmd *exec.Cmd) {
if cmd == nil || cmd.Stdin != nil {
return
}
f, err := os.Open(os.DevNull)
if err != nil {
return
}
cmd.Stdin = f
}
@@ -0,0 +1,128 @@
package security
import (
"context"
"os"
"os/exec"
"strings"
"testing"
"time"
)
// TestPrependNonInteractiveShellExports checks that pager exports are added in front of a
// plain command, that the command itself stays at the end, and that a command already
// setting GIT_PAGER does not get an overriding GIT_PAGER export.
func TestPrependNonInteractiveShellExports(t *testing.T) {
	wrapped := PrependNonInteractiveShellExports("echo hi")
	for _, want := range []string{"GIT_PAGER=cat", "PAGER=cat"} {
		if !strings.Contains(wrapped, want) {
			t.Fatalf("missing pager exports: %q", wrapped)
		}
	}
	if trimmed := strings.TrimSpace(wrapped); !strings.HasSuffix(trimmed, "echo hi") {
		t.Fatalf("command suffix lost: %q", wrapped)
	}

	preset := PrependNonInteractiveShellExports("GIT_PAGER=less echo hi")
	if strings.Contains(preset, "export GIT_PAGER=cat") {
		t.Fatalf("should not override existing GIT_PAGER: %q", preset)
	}
}
// TestPrependNonInteractiveStdinRedirect checks that the stdin redirect is placed in front
// of a plain command, that the command itself stays at the end, and that a command which
// already redirects stdin from /dev/null is not wrapped again.
func TestPrependNonInteractiveStdinRedirect(t *testing.T) {
	const redirect = "exec </dev/null"

	wrapped := PrependNonInteractiveStdinRedirect("echo hi")
	if !strings.HasPrefix(wrapped, redirect) {
		t.Fatalf("missing stdin redirect: %q", wrapped)
	}
	if trimmed := strings.TrimSpace(wrapped); !strings.HasSuffix(trimmed, "echo hi") {
		t.Fatalf("command suffix lost: %q", wrapped)
	}

	already := PrependNonInteractiveStdinRedirect("cmd </dev/null")
	if strings.HasPrefix(already, redirect) {
		t.Fatalf("should not double redirect: %q", already)
	}
}
// TestPrepareNonInteractiveShellCommand checks that the combined wrapper contains both the
// stdin redirect and the pager export.
func TestPrepareNonInteractiveShellCommand(t *testing.T) {
	wrapped := PrepareNonInteractiveShellCommand("echo hi")
	expectations := []struct {
		needle string
		format string
	}{
		{"exec </dev/null", "missing stdin redirect: %q"},
		{"GIT_PAGER=cat", "missing pager export: %q"},
	}
	for _, exp := range expectations {
		if !strings.Contains(wrapped, exp.needle) {
			t.Fatalf(exp.format, wrapped)
		}
	}
}
// TestNewShellInactivityWatch creates a watch with a 1-second idle window, bumps it once,
// and expects the expiry signal within 3 seconds.
func TestNewShellInactivityWatch(t *testing.T) {
	watch := NewShellInactivityWatch(1)
	if watch == nil {
		t.Fatal("expected watch")
	}
	watch.Bump()

	deadline := time.After(3 * time.Second)
	select {
	case <-watch.Expired:
		return
	case <-deadline:
		t.Fatal("expected inactivity fire within 3s")
	}
}
// TestResolveShellNoOutputTimeoutSeconds pins the three input classes of the resolver:
// zero (default), -1 (disabled) and an explicit positive value.
func TestResolveShellNoOutputTimeoutSeconds(t *testing.T) {
	cases := []struct {
		in, want int
		failure  string
	}{
		{0, 300, "zero should default to 300"},
		{-1, 0, "-1 should disable"},
		{30, 30, "explicit value"},
	}
	for _, tc := range cases {
		if ResolveShellNoOutputTimeoutSeconds(tc.in) != tc.want {
			t.Fatal(tc.failure)
		}
	}
}
// TestNonInteractiveStdinReadExitsQuickly verifies that with `exec </dev/null` plus
// attachNonInteractiveStdin a shell `read` sees EOF immediately instead of hanging: the
// command must finish without error in under 2 seconds and print an empty value.
// Skipped in -short mode because it spawns a real shell.
func TestNonInteractiveStdinReadExitsQuickly(t *testing.T) {
if testing.Short() {
t.Skip("skipping shell integration in -short")
}
// The 5-second context is a safety net so a regression cannot hang the test run.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "sh", "-c", PrepareNonInteractiveShellCommand(`read x; echo "x=<$x>"`))
attachNonInteractiveStdin(cmd)
start := time.Now()
out, err := cmd.CombinedOutput()
elapsed := time.Since(start)
if elapsed > 2*time.Second {
t.Fatalf("read with closed stdin took %v, want <2s", elapsed)
}
if err != nil {
t.Fatalf("unexpected error: %v output=%q", err, out)
}
if !strings.Contains(string(out), "x=<>") {
t.Fatalf("unexpected output: %q", out)
}
}
// TestNonInteractiveStdinReadBlocksWithoutRedirect is the control case: when stdin is a
// pipe that is never written to, a shell `read` hangs. The test fails if the command
// finishes within 500 ms; otherwise it kills the process and cleans up.
// Skipped in -short mode because it spawns a real shell.
func TestNonInteractiveStdinReadBlocksWithoutRedirect(t *testing.T) {
if testing.Short() {
t.Skip("skipping shell integration in -short")
}
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
defer r.Close()
// Keep w open without writing anything, simulating "waiting for user input".
cmd := exec.Command("sh", "-c", `read x; echo done`)
cmd.Stdin = r
done := make(chan error, 1)
go func() { done <- cmd.Run() }()
select {
case err := <-done:
t.Fatalf("expected hang, but command finished: %v", err)
case <-time.After(500 * time.Millisecond):
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = w.Close()
<-done // wait for the goroutine to exit
}
}
+6 -2
View File
@@ -2580,6 +2580,8 @@
"agentModeSingle": "Single-agent (Eino ADK)", "agentModeSingle": "Single-agent (Eino ADK)",
"agentModeMulti": "Multi-agent (Eino)", "agentModeMulti": "Multi-agent (Eino)",
"agentModeHint": "Same as chat: Eino single-agent (ADK), or Deep / Plan-Execute / Supervisor (last three require multi_agent.enabled).", "agentModeHint": "Same as chat: Eino single-agent (ADK), or Deep / Plan-Execute / Supervisor (last three require multi_agent.enabled).",
"concurrency": "Concurrency",
"concurrencyHint": "Number of subtasks to run in parallel (1-8). Default 1 is serial; use 1-2 for scan-heavy tasks.",
"scheduleMode": "Schedule mode", "scheduleMode": "Schedule mode",
"scheduleModeManual": "Manual", "scheduleModeManual": "Manual",
"scheduleModeCron": "Cron expression", "scheduleModeCron": "Cron expression",
@@ -2594,8 +2596,8 @@
"tasksList": "Task list (one task per line)", "tasksList": "Task list (one task per line)",
"tasksListPlaceholder": "Enter task list, one per line", "tasksListPlaceholder": "Enter task list, one per line",
"tasksListPlaceholderExample": "Enter task list, one per line, for example:\nScan open ports of 192.168.1.1\nCheck if https://example.com has SQL injection\nEnumerate subdomains of example.com", "tasksListPlaceholderExample": "Enter task list, one per line, for example:\nScan open ports of 192.168.1.1\nCheck if https://example.com has SQL injection\nEnumerate subdomains of example.com",
"tasksListHint": "Enter one task command per line; the system will execute them in order. Empty lines are ignored.", "tasksListHint": "Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
"tasksListHintFull": "Hint: Enter one task command per line; the system will execute these tasks in order. Empty lines are ignored.", "tasksListHintFull": "Hint: Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
"createQueue": "Create queue" "createQueue": "Create queue"
}, },
"batchQueueDetailModal": { "batchQueueDetailModal": {
@@ -2629,6 +2631,8 @@
"scheduleToggleFailed": "Failed to update schedule toggle", "scheduleToggleFailed": "Failed to update schedule toggle",
"completedAt": "Completed at", "completedAt": "Completed at",
"taskTotal": "Total tasks", "taskTotal": "Total tasks",
"concurrency": "Concurrency",
"concurrencyEditHint": "Click to edit. Cannot change while the queue is running.",
"taskList": "Task list", "taskList": "Task list",
"startLabel": "Start", "startLabel": "Start",
"completeLabel": "Complete", "completeLabel": "Complete",
+6 -2
View File
@@ -2568,6 +2568,8 @@
"agentModeSingle": "单代理(Eino ADK", "agentModeSingle": "单代理(Eino ADK",
"agentModeMulti": "多代理(Eino", "agentModeMulti": "多代理(Eino",
"agentModeHint": "与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。", "agentModeHint": "与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。",
"concurrency": "并发数",
"concurrencyHint": "同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。",
"scheduleMode": "调度方式", "scheduleMode": "调度方式",
"scheduleModeManual": "手工执行", "scheduleModeManual": "手工执行",
"scheduleModeCron": "调度表达式(Cron", "scheduleModeCron": "调度表达式(Cron",
@@ -2582,8 +2584,8 @@
"tasksList": "任务列表(每行一个任务)", "tasksList": "任务列表(每行一个任务)",
"tasksListPlaceholder": "请输入任务列表,每行一个任务", "tasksListPlaceholder": "请输入任务列表,每行一个任务",
"tasksListPlaceholderExample": "请输入任务列表,每行一个任务,例如:\n扫描 192.168.1.1 的开放端口\n检查 https://example.com 是否存在SQL注入\n枚举 example.com 的子域名", "tasksListPlaceholderExample": "请输入任务列表,每行一个任务,例如:\n扫描 192.168.1.1 的开放端口\n检查 https://example.com 是否存在SQL注入\n枚举 example.com 的子域名",
"tasksListHint": "每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。", "tasksListHint": "每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
"tasksListHintFull": "提示:每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。", "tasksListHintFull": "提示:每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
"createQueue": "创建队列" "createQueue": "创建队列"
}, },
"batchQueueDetailModal": { "batchQueueDetailModal": {
@@ -2617,6 +2619,8 @@
"scheduleToggleFailed": "更新调度开关失败", "scheduleToggleFailed": "更新调度开关失败",
"completedAt": "完成时间", "completedAt": "完成时间",
"taskTotal": "任务总数", "taskTotal": "任务总数",
"concurrency": "并发数",
"concurrencyEditHint": "点击可修改;队列运行中不可改。",
"taskList": "任务列表", "taskList": "任务列表",
"startLabel": "开始", "startLabel": "开始",
"completeLabel": "完成", "completeLabel": "完成",
+24 -1
View File
@@ -3110,7 +3110,17 @@ async function cancelMCPToolExecutionSubmit(executionId, userNote, options = {})
if (!executionId) { if (!executionId) {
return; return;
} }
let conversationId = '';
if (typeof monitorState !== 'undefined' && Array.isArray(monitorState.executions)) {
const exec = monitorState.executions.find(e => e && e.id === executionId);
if (exec) {
conversationId = (exec.conversationId || '').trim();
}
}
try { try {
if (conversationId && typeof requestCancelWithContinue === 'function') {
await requestCancelWithContinue(conversationId, userNote || '');
} else {
const res = await apiFetch(`/api/monitor/execution/${encodeURIComponent(executionId)}/cancel`, { const res = await apiFetch(`/api/monitor/execution/${encodeURIComponent(executionId)}/cancel`, {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
@@ -3120,6 +3130,7 @@ async function cancelMCPToolExecutionSubmit(executionId, userNote, options = {})
if (!res.ok) { if (!res.ok) {
throw new Error(body.error || body.message || res.statusText); throw new Error(body.error || body.message || res.statusText);
} }
}
const okMsg = typeof window.t === 'function' ? window.t('mcpDetailModal.abortSuccess') : '已发送终止请求'; const okMsg = typeof window.t === 'function' ? window.t('mcpDetailModal.abortSuccess') : '已发送终止请求';
alert(okMsg); alert(okMsg);
if (options.refreshDetail && typeof showMCPDetail === 'function') { if (options.refreshDetail && typeof showMCPDetail === 'function') {
@@ -3136,7 +3147,7 @@ async function cancelMCPToolExecutionSubmit(executionId, userNote, options = {})
} }
/** /**
* 取消单次 MCP 工具执行监控页终止弹出说明框后提交仅取消该次 tools/call不停止整条对话/迭代任务 * 取消单次 MCP 工具执行监控页终止 conversationId 时复用对话页中断并继续弹窗与 API
* @param {string} executionId * @param {string} executionId
* @param {{ refreshDetail?: boolean }} [options] * @param {{ refreshDetail?: boolean }} [options]
*/ */
@@ -3144,6 +3155,18 @@ async function cancelMCPToolExecution(executionId, options = {}) {
if (!executionId) { if (!executionId) {
return; return;
} }
let conversationId = '';
if (typeof monitorState !== 'undefined' && Array.isArray(monitorState.executions)) {
const exec = monitorState.executions.find(e => e && e.id === executionId);
if (exec) {
conversationId = (exec.conversationId || '').trim();
}
}
if (conversationId && typeof openUserInterruptModal === 'function') {
openUserInterruptModal(null, conversationId);
window.__monitorInterruptContext = { executionId: executionId, options: options || {} };
return;
}
openMcpToolAbortModal(executionId, options); openMcpToolAbortModal(executionId, options);
} }
+36
View File
@@ -1003,6 +1003,7 @@ function openUserInterruptModal(progressId, conversationId) {
function closeUserInterruptModal() { function closeUserInterruptModal() {
userInterruptModalPending = null; userInterruptModalPending = null;
window.__monitorInterruptContext = null;
closeAppModal('user-interrupt-modal'); closeAppModal('user-interrupt-modal');
} }
@@ -1012,6 +1013,7 @@ async function submitUserInterruptContinue() {
} }
const reason = (document.getElementById('user-interrupt-reason') && document.getElementById('user-interrupt-reason').value || '').trim(); const reason = (document.getElementById('user-interrupt-reason') && document.getElementById('user-interrupt-reason').value || '').trim();
const { progressId, conversationId } = userInterruptModalPending; const { progressId, conversationId } = userInterruptModalPending;
const monitorCtx = window.__monitorInterruptContext;
closeUserInterruptModal(); closeUserInterruptModal();
const stopBtn = progressId ? document.getElementById(`${progressId}-stop-btn`) : null; const stopBtn = progressId ? document.getElementById(`${progressId}-stop-btn`) : null;
try { try {
@@ -1020,6 +1022,13 @@ async function submitUserInterruptContinue() {
stopBtn.textContent = typeof window.t === 'function' ? window.t('tasks.interruptSubmitting') : '提交中...'; stopBtn.textContent = typeof window.t === 'function' ? window.t('tasks.interruptSubmitting') : '提交中...';
} }
await requestCancelWithContinue(conversationId, reason); await requestCancelWithContinue(conversationId, reason);
if (monitorCtx && monitorCtx.executionId && typeof refreshMonitorPanel === 'function') {
const page = (typeof monitorState !== 'undefined' && monitorState.pagination && monitorState.pagination.page)
? monitorState.pagination.page
: 1;
await refreshMonitorPanel(page);
window.__monitorInterruptContext = null;
}
loadActiveTasks(); loadActiveTasks();
} catch (error) { } catch (error) {
console.error('中断并继续失败:', error); console.error('中断并继续失败:', error);
@@ -3536,6 +3545,33 @@ const monitorState = {
} }
}; };
// Handle of the active monitor polling interval, or null when polling is off.
let monitorPollTimer = null;
// Delay between two monitor refreshes, in milliseconds.
const MONITOR_POLL_INTERVAL_MS = 3000;

/**
 * Starts (or restarts) periodic refreshing of the MCP monitor panel.
 * Each tick stops the polling when the monitor page element is missing or not
 * active, skips the refresh while the document is hidden, and otherwise calls
 * refreshMonitorPanel(), ignoring a rejected refresh.
 * NOTE(review): refreshMonitorPanel is called without a page argument here —
 * confirm that this keeps the user's current pagination.
 */
function startMonitorPoll() {
    stopMonitorPoll();
    const tick = () => {
        const monitorPage = document.getElementById('page-mcp-monitor');
        const pageIsActive = !!monitorPage && monitorPage.classList.contains('active');
        if (!pageIsActive) {
            stopMonitorPoll();
            return;
        }
        if (!document.hidden && typeof refreshMonitorPanel === 'function') {
            refreshMonitorPanel().catch(() => { /* ignore */ });
        }
    };
    monitorPollTimer = setInterval(tick, MONITOR_POLL_INTERVAL_MS);
}
/**
 * Stops the monitor polling interval if one is running; otherwise does nothing.
 */
function stopMonitorPoll() {
    if (!monitorPollTimer) {
        return;
    }
    clearInterval(monitorPollTimer);
    monitorPollTimer = null;
}
function openMonitorPanel() { function openMonitorPanel() {
// 切换到MCP监控页面 // 切换到MCP监控页面
if (typeof switchPage === 'function') { if (typeof switchPage === 'function') {
+3
View File
@@ -356,6 +356,9 @@ async function initPage(pageId) {
if (typeof refreshMonitorPanel === 'function') { if (typeof refreshMonitorPanel === 'function') {
refreshMonitorPanel(); refreshMonitorPanel();
} }
if (typeof startMonitorPoll === 'function') {
startMonitorPoll();
}
break; break;
case 'mcp-management': case 'mcp-management':
// 初始化MCP管理 // 初始化MCP管理
+77
View File
@@ -990,6 +990,7 @@ async function createBatchQueue() {
const roleSelect = document.getElementById('batch-queue-role'); const roleSelect = document.getElementById('batch-queue-role');
const projectSelect = document.getElementById('batch-queue-project-id'); const projectSelect = document.getElementById('batch-queue-project-id');
const agentModeSelect = document.getElementById('batch-queue-agent-mode'); const agentModeSelect = document.getElementById('batch-queue-agent-mode');
const concurrencyInput = document.getElementById('batch-queue-concurrency');
const scheduleModeSelect = document.getElementById('batch-queue-schedule-mode'); const scheduleModeSelect = document.getElementById('batch-queue-schedule-mode');
const cronExprInput = document.getElementById('batch-queue-cron-expr'); const cronExprInput = document.getElementById('batch-queue-cron-expr');
const executeNowCheckbox = document.getElementById('batch-queue-execute-now'); const executeNowCheckbox = document.getElementById('batch-queue-execute-now');
@@ -1019,6 +1020,9 @@ async function createBatchQueue() {
const scheduleMode = scheduleModeSelect ? (scheduleModeSelect.value === 'cron' ? 'cron' : 'manual') : 'manual'; const scheduleMode = scheduleModeSelect ? (scheduleModeSelect.value === 'cron' ? 'cron' : 'manual') : 'manual';
const cronExpr = cronExprInput ? cronExprInput.value.trim() : ''; const cronExpr = cronExprInput ? cronExprInput.value.trim() : '';
const executeNow = executeNowCheckbox ? !!executeNowCheckbox.checked : false; const executeNow = executeNowCheckbox ? !!executeNowCheckbox.checked : false;
let concurrency = concurrencyInput ? parseInt(concurrencyInput.value, 10) : 1;
if (!Number.isFinite(concurrency) || concurrency < 1) concurrency = 1;
if (concurrency > 8) concurrency = 8;
if (scheduleMode === 'cron' && !cronExpr) { if (scheduleMode === 'cron' && !cronExpr) {
alert(_t('batchImportModal.cronExprRequired')); alert(_t('batchImportModal.cronExprRequired'));
return; return;
@@ -1043,6 +1047,7 @@ async function createBatchQueue() {
cronExpr, cronExpr,
executeNow, executeNow,
projectId, projectId,
concurrency,
}), }),
}); });
@@ -1489,6 +1494,7 @@ async function showBatchQueueDetail(queueId) {
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.role'))}</span><span class="bq-kv__v" id="bq-role-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditRole()" title="${escapeHtml(_t('common.edit'))}">${roleLineVal}</span>` : roleLineVal}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.role'))}</span><span class="bq-kv__v" id="bq-role-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditRole()" title="${escapeHtml(_t('common.edit'))}">${roleLineVal}</span>` : roleLineVal}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.agentMode'))}</span><span class="bq-kv__v" id="bq-agentmode-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditAgentMode()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(agentModeText)}</span>` : escapeHtml(agentModeText)}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.agentMode'))}</span><span class="bq-kv__v" id="bq-agentmode-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditAgentMode()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(agentModeText)}</span>` : escapeHtml(agentModeText)}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.scheduleMode'))}</span><span class="bq-kv__v" id="bq-schedule-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditSchedule()" title="${escapeHtml(_t('common.edit'))}">${scheduleDetail}</span>` : scheduleDetail}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.scheduleMode'))}</span><span class="bq-kv__v" id="bq-schedule-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditSchedule()" title="${escapeHtml(_t('common.edit'))}">${scheduleDetail}</span>` : scheduleDetail}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.concurrency'))}</span><span class="bq-kv__v" id="bq-concurrency-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditConcurrency()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span>` : escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.taskTotal'))}</span><span class="bq-kv__v">${queue.tasks.length}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.taskTotal'))}</span><span class="bq-kv__v">${queue.tasks.length}</span></div>
${queue.scheduleMode === 'cron' ? `<div class="bq-kv bq-kv--block"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAuto'))}</span><span class="bq-kv__v bq-kv__v--control"><label class="bq-cron-toggle"><input type="checkbox" ${queue.scheduleEnabled !== false ? 'checked' : ''} onchange="updateBatchQueueScheduleEnabled(this.checked)" /><span class="bq-cron-toggle__hint">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAutoHint'))}</span></label></span></div>` : ''} ${queue.scheduleMode === 'cron' ? `<div class="bq-kv bq-kv--block"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAuto'))}</span><span class="bq-kv__v bq-kv__v--control"><label class="bq-cron-toggle"><input type="checkbox" ${queue.scheduleEnabled !== false ? 'checked' : ''} onchange="updateBatchQueueScheduleEnabled(this.checked)" /><span class="bq-cron-toggle__hint">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAutoHint'))}</span></label></span></div>` : ''}
</section> </section>
@@ -2287,6 +2293,75 @@ async function saveInlineAgentMode() {
} }
} }
/**
 * Coerce a raw concurrency value (typically an <input> string) to an integer
 * in the range 1..8.
 *
 * @param {*} raw Value to parse with base-10 parseInt.
 * @returns {number} 1 when the value is unparseable, non-finite or below 1;
 *                   8 when it exceeds 8; otherwise the parsed integer.
 */
function normalizeBatchQueueConcurrencyInput(raw) {
    const parsed = Number.parseInt(raw, 10);
    // parseInt yields NaN for junk and Infinity for absurdly long digit runs;
    // both collapse to the minimum, as does anything below 1.
    if (!Number.isFinite(parsed) || parsed < 1) {
        return 1;
    }
    return parsed > 8 ? 8 : parsed;
}
// --- Inline edit: concurrency ---
/**
 * Replace the concurrency value shown in the queue detail panel
 * (#bq-concurrency-val) with a numeric input pre-filled from the server's
 * current value for the selected queue.
 *
 * Enter commits (by blurring the input), Escape cancels via
 * cancelAllInlineEdits(), and any other blur triggers saveInlineConcurrency().
 * If the queue detail cannot be loaded the display is left untouched and the
 * error is logged, rather than opening an editor seeded with a made-up value.
 */
function startInlineEditConcurrency() {
    const container = document.getElementById('bq-concurrency-val');
    if (!container) return;
    const queueId = batchQueuesState.currentQueueId;
    if (!queueId) return;

    apiFetch(`/api/batch-tasks/${queueId}`)
        .then((resp) => {
            // An error response has no usable queue payload; do not fall back to 1.
            if (!resp.ok) {
                throw new Error(`Failed to load batch queue ${queueId}: HTTP ${resp.status}`);
            }
            return resp.json();
        })
        .then((detail) => {
            // The user may have switched queues while the request was in flight;
            // saving reads currentQueueId, so a stale editor must not be opened.
            if (batchQueuesState.currentQueueId !== queueId) return;

            const queue = (detail && detail.queue) || {};
            const current = normalizeBatchQueueConcurrencyInput(queue.concurrency || 1);
            container.innerHTML = `<span class="bq-inline-edit-controls">
<input type="number" id="bq-edit-concurrency" min="1" max="8" value="${current}" style="width:72px;" />
</span>`;
            const inp = document.getElementById('bq-edit-concurrency');
            if (!inp) return;
            inp.focus();
            inp.select();

            // Set on Escape so the blur listener below does not save a cancelled edit.
            let cancelled = false;
            inp.addEventListener('keydown', (e) => {
                if (e.key === 'Enter') { e.preventDefault(); inp.blur(); }
                if (e.key === 'Escape') { cancelled = true; cancelAllInlineEdits(); }
            });
            inp.addEventListener('blur', () => {
                if (!cancelled) saveInlineConcurrency();
            });
        })
        .catch((err) => {
            // Network/parse failures previously surfaced as unhandled rejections.
            console.error(err);
        });
}
/**
 * Persist the concurrency typed into #bq-edit-concurrency for the currently
 * selected batch queue.
 *
 * The metadata endpoint is sent title, role and agentMode together with
 * concurrency, so the queue is re-read first and its existing values are
 * echoed back. If that read fails or returns no queue, the save is aborted:
 * continuing would send empty title/role and the default agent mode.
 *
 * Re-entrancy is blocked with the shared _bqInlineSaving flag. On success the
 * detail view and the queue list are refreshed; on failure the error is logged
 * and shown via alert().
 *
 * @returns {Promise<void>}
 */
async function saveInlineConcurrency() {
    if (_bqInlineSaving) return;
    const queueId = batchQueuesState.currentQueueId;
    if (!queueId) return;
    _bqInlineSaving = true;

    const inp = document.getElementById('bq-edit-concurrency');
    const concurrency = normalizeBatchQueueConcurrencyInput(inp ? inp.value : 1);
    try {
        const detailResp = await apiFetch(`/api/batch-tasks/${queueId}`);
        if (!detailResp.ok) {
            const detailErr = await detailResp.json().catch(() => ({}));
            throw new Error(detailErr.error || _t('tasks.updateTaskFailed'));
        }
        const detail = await detailResp.json();
        const q = detail && detail.queue;
        // Without the current metadata we cannot echo it back safely.
        if (!q) {
            throw new Error(_t('tasks.updateTaskFailed'));
        }

        const response = await apiFetch(`/api/batch-tasks/${queueId}/metadata`, {
            method: 'PUT',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                title: q.title || '',
                role: q.role || '',
                agentMode: q.agentMode || 'eino_single',
                concurrency,
            }),
        });
        if (!response.ok) {
            const result = await response.json().catch(() => ({}));
            throw new Error(result.error || _t('tasks.updateTaskFailed'));
        }
        // Release the guard before re-rendering, matching the original ordering.
        _bqInlineSaving = false;
        showBatchQueueDetail(queueId);
        refreshBatchQueues();
    } catch (e) {
        _bqInlineSaving = false;
        console.error(e);
        alert(e.message);
    }
}
// --- 单条执行 --- // --- 单条执行 ---
async function runSingleBatchTask(queueId, taskId) { async function runSingleBatchTask(queueId, taskId) {
if (!queueId || !taskId) return; if (!queueId || !taskId) return;
@@ -2441,6 +2516,8 @@ window.startInlineEditRole = startInlineEditRole;
window.saveInlineRole = saveInlineRole; window.saveInlineRole = saveInlineRole;
window.startInlineEditAgentMode = startInlineEditAgentMode; window.startInlineEditAgentMode = startInlineEditAgentMode;
window.saveInlineAgentMode = saveInlineAgentMode; window.saveInlineAgentMode = saveInlineAgentMode;
// Expose the concurrency inline-edit functions on window;
// startInlineEditConcurrency is called from an inline onclick attribute in the
// queue detail markup.
window.startInlineEditConcurrency = startInlineEditConcurrency;
window.saveInlineConcurrency = saveInlineConcurrency;
window.runSingleBatchTask = runSingleBatchTask; window.runSingleBatchTask = runSingleBatchTask;
window.startInlineEditSchedule = startInlineEditSchedule; window.startInlineEditSchedule = startInlineEditSchedule;
window.toggleInlineScheduleCron = toggleInlineScheduleCron; window.toggleInlineScheduleCron = toggleInlineScheduleCron;
+5
View File
@@ -4010,6 +4010,11 @@
</select> </select>
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.agentModeHint">与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。</div> <div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.agentModeHint">与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。</div>
</div> </div>
<div class="form-group">
<label for="batch-queue-concurrency" data-i18n="batchImportModal.concurrency">并发数</label>
<input type="number" id="batch-queue-concurrency" min="1" max="8" value="1" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;" />
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.concurrencyHint">同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。</div>
</div>
<div class="form-group"> <div class="form-group">
<label for="batch-queue-schedule-mode" data-i18n="batchImportModal.scheduleMode">调度方式</label> <label for="batch-queue-schedule-mode" data-i18n="batchImportModal.scheduleMode">调度方式</label>
<select id="batch-queue-schedule-mode" onchange="handleBatchScheduleModeChange()" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;"> <select id="batch-queue-schedule-mode" onchange="handleBatchScheduleModeChange()" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;">