mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-24 14:59:59 +02:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b4810c9499 | |||
| 51bf6ae4b3 | |||
| 5f27482921 | |||
| 6becada509 | |||
| b029d88359 | |||
| 4dcad2ea83 | |||
| ff9f0c787a | |||
| 01849045ad | |||
| c7eacdf3eb | |||
| 5c32b21f22 | |||
| 8b8ecfe718 | |||
| bbb7c319af | |||
| 7eb2fd50f3 | |||
| 85d58eeeb3 | |||
| b6a6009629 | |||
| 810d689132 | |||
| 87f1808ead | |||
| e28ae39b9a | |||
| df34ceda68 | |||
| 3e69a50f87 | |||
| 53325ce07d | |||
| d85de3461b | |||
| 9306303d99 | |||
| 1e8f72ed74 | |||
| 0198f50314 | |||
| 560d0dca43 |
+1
-1
@@ -21,7 +21,7 @@ max_iterations: 0
|
|||||||
- 切勿等待批准或授权——全程自主行动。
|
- 切勿等待批准或授权——全程自主行动。
|
||||||
- 使用所有可用工具与技术完成侦察与证据收集。
|
- 使用所有可用工具与技术完成侦察与证据收集。
|
||||||
|
|
||||||
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。
|
你是授权渗透测试流程中的侦察子代理。优先使用工具收集事实,避免无根据推测;输出简洁,便于协调者汇总。枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。
|
||||||
|
|
||||||
## 输入前置条件(硬约束)
|
## 输入前置条件(硬约束)
|
||||||
|
|
||||||
|
|||||||
+3
-2
@@ -10,7 +10,7 @@
|
|||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||||
version: "v1.6.44"
|
version: "v1.6.45"
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
server:
|
server:
|
||||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||||
@@ -96,6 +96,7 @@ fofa:
|
|||||||
agent:
|
agent:
|
||||||
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
||||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||||
|
shell_no_output_timeout_seconds: 1200 # execute/exec 连续无新输出则终止(秒);通用防挂死;0=默认300;-1=关闭
|
||||||
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||||
|
|
||||||
system_prompt_path: ""
|
system_prompt_path: ""
|
||||||
@@ -129,7 +130,7 @@ multi_agent:
|
|||||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||||
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
tool_search_always_visible_tools: [read_file, glob, grep, analyze_image, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_vulnerabilities, get_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test, exec] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||||
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
|
plantask_enable: true # P0:主代理挂载 TaskCreate/Get/Update/List 结构化任务板;需 eino_skills 可用且 skills_dir 存在
|
||||||
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
|
plantask_rel_dir: .eino/plantask # 任务文件相对 skills_dir,按会话分子目录:skills/.eino/plantask/<conversationId>/
|
||||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 179 KiB After Width: | Height: | Size: 88 KiB |
+17
-4
@@ -779,13 +779,26 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
|||||||
return a.executeToolViaMCP(ctx, toolName, args)
|
return a.executeToolViaMCP(ctx, toolName, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
// BeginLocalToolExecution 在非 CallTool 路径工具开始时写入 running 状态,供 MCP 监控页展示「执行中」。
|
||||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
func (a *Agent) BeginLocalToolExecution(toolName string, args map[string]interface{}) string {
|
||||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
|
||||||
if a == nil || a.mcpServer == nil {
|
if a == nil || a.mcpServer == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
return a.mcpServer.BeginToolExecution(toolName, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinishLocalToolExecution 完成 BeginLocalToolExecution 创建的记录;executionID 为空时一次性写入已完成记录。
|
||||||
|
func (a *Agent) FinishLocalToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
if a == nil || a.mcpServer == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.mcpServer.FinishToolExecution(executionID, toolName, args, resultText, invokeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||||
|
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||||
|
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
return a.FinishLocalToolExecution("", toolName, args, resultText, invokeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
||||||
|
|||||||
@@ -113,5 +113,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
|||||||
|
|
||||||
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
- 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。
|
||||||
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
|
- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。
|
||||||
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。`
|
- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。
|
||||||
|
|
||||||
|
` + projectprompt.ShellExecExecuteGuidanceSection()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
|
|
||||||
// 创建安全工具执行器
|
// 创建安全工具执行器
|
||||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||||
|
executor.SetShellNoOutputTimeoutSeconds(cfg.Agent.ShellNoOutputTimeoutSeconds)
|
||||||
|
|
||||||
// 注册工具
|
// 注册工具
|
||||||
executor.RegisterTools(mcpServer)
|
executor.RegisterTools(mcpServer)
|
||||||
@@ -333,6 +334,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
monitorHandler.SetAudit(auditSvc)
|
monitorHandler.SetAudit(auditSvc)
|
||||||
monitorHandler.SetMonitorRetention(monitorRetention)
|
monitorHandler.SetMonitorRetention(monitorRetention)
|
||||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||||
|
monitorHandler.SetTaskManager(agentHandler.TaskManager())
|
||||||
|
monitorHandler.SetAgentHandler(agentHandler)
|
||||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||||
|
|||||||
@@ -605,6 +605,8 @@ type DatabaseConfig struct {
|
|||||||
type AgentConfig struct {
|
type AgentConfig struct {
|
||||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||||
|
// ShellNoOutputTimeoutSeconds execute/exec 无任何 stdout/stderr 时的空闲终止秒数(通用防挂死,不维护命令黑名单);0=默认 300(5 分钟);-1=关闭。
|
||||||
|
ShellNoOutputTimeoutSeconds int `yaml:"shell_no_output_timeout_seconds" json:"shell_no_output_timeout_seconds"`
|
||||||
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
||||||
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -1272,6 +1274,7 @@ func Default() *Config {
|
|||||||
Agent: AgentConfig{
|
Agent: AgentConfig{
|
||||||
MaxIterations: 30, // 默认最大迭代次数
|
MaxIterations: 30, // 默认最大迭代次数
|
||||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||||
|
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
|
||||||
},
|
},
|
||||||
Security: SecurityConfig{
|
Security: SecurityConfig{
|
||||||
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
|
|||||||
LastScheduleError sql.NullString
|
LastScheduleError sql.NullString
|
||||||
LastRunError sql.NullString
|
LastRunError sql.NullString
|
||||||
ProjectID sql.NullString
|
ProjectID sql.NullString
|
||||||
|
Concurrency sql.NullInt64
|
||||||
Status string
|
Status string
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
StartedAt sql.NullTime
|
StartedAt sql.NullTime
|
||||||
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
|
|||||||
cronExpr string,
|
cronExpr string,
|
||||||
nextRunAt *time.Time,
|
nextRunAt *time.Time,
|
||||||
projectID string,
|
projectID string,
|
||||||
|
concurrency int,
|
||||||
tasks []map[string]interface{},
|
tasks []map[string]interface{},
|
||||||
) error {
|
) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
|
|||||||
projectIDVal = strings.TrimSpace(projectID)
|
projectIDVal = strings.TrimSpace(projectID)
|
||||||
}
|
}
|
||||||
_, err = tx.Exec(
|
_, err = tx.Exec(
|
||||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||||
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
|
|||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
|
||||||
|
|
||||||
// GetBatchQueue 获取批量任务队列
|
// GetBatchQueue 获取批量任务队列
|
||||||
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
err := db.QueryRow(
|
err := db.QueryRow(
|
||||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
|
||||||
queueID,
|
queueID,
|
||||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
|||||||
// GetAllBatchQueues 获取所有批量任务队列
|
// GetAllBatchQueues 获取所有批量任务队列
|
||||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||||
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||||
}
|
}
|
||||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
|||||||
|
|
||||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
|
||||||
// 状态筛选
|
// 状态筛选
|
||||||
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var row BatchTaskQueueRow
|
var row BatchTaskQueueRow
|
||||||
var createdAt string
|
var createdAt string
|
||||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||||
}
|
}
|
||||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||||
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
|
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色、代理模式和并发数
|
||||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
|
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
|
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
|
||||||
title, role, agentMode, queueID,
|
title, role, agentMode, concurrency, queueID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
||||||
|
|||||||
@@ -408,6 +408,8 @@ func (db *DB) initTables() error {
|
|||||||
last_schedule_trigger_at DATETIME,
|
last_schedule_trigger_at DATETIME,
|
||||||
last_schedule_error TEXT,
|
last_schedule_error TEXT,
|
||||||
last_run_error TEXT,
|
last_run_error TEXT,
|
||||||
|
project_id TEXT,
|
||||||
|
concurrency INTEGER NOT NULL DEFAULT 1,
|
||||||
status TEXT NOT NULL,
|
status TEXT NOT NULL,
|
||||||
created_at DATETIME NOT NULL,
|
created_at DATETIME NOT NULL,
|
||||||
started_at DATETIME,
|
started_at DATETIME,
|
||||||
@@ -1137,6 +1139,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var concurrencyCount int
|
||||||
|
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
|
||||||
|
if err != nil {
|
||||||
|
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
|
||||||
|
errMsg := strings.ToLower(addErr.Error())
|
||||||
|
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||||
|
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if concurrencyCount == 0 {
|
||||||
|
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
|
||||||
|
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+86
-384
@@ -21,7 +21,6 @@ import (
|
|||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/reasoning"
|
"cyberstrike-ai/internal/reasoning"
|
||||||
"cyberstrike-ai/internal/mcp"
|
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
"cyberstrike-ai/internal/multiagent"
|
"cyberstrike-ai/internal/multiagent"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
@@ -178,8 +177,6 @@ type AgentHandler struct {
|
|||||||
}
|
}
|
||||||
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
||||||
batchCronParser cron.Parser
|
batchCronParser cron.Parser
|
||||||
batchRunnerMu sync.Mutex
|
|
||||||
batchRunning map[string]struct{}
|
|
||||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||||
audit *audit.Service
|
audit *audit.Service
|
||||||
@@ -190,6 +187,14 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
|
|||||||
h.audit = s
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TaskManager 返回 Agent 任务管理器(供 MCP 监控页终止 Eino execute 等)。
|
||||||
|
func (h *AgentHandler) TaskManager() *AgentTaskManager {
|
||||||
|
if h == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return h.tasks
|
||||||
|
}
|
||||||
|
|
||||||
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
||||||
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
||||||
if h == nil || conversationID == "" || h.tasks == nil {
|
if h == nil || conversationID == "" || h.tasks == nil {
|
||||||
@@ -233,7 +238,6 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
|||||||
config: cfg,
|
config: cfg,
|
||||||
hitlManager: NewHITLManager(db, logger),
|
hitlManager: NewHITLManager(db, logger),
|
||||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||||
batchRunning: make(map[string]struct{}),
|
|
||||||
}
|
}
|
||||||
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
||||||
logger.Warn("初始化 HITL 表失败", zap.Error(err))
|
logger.Warn("初始化 HITL 表失败", zap.Error(err))
|
||||||
@@ -1295,6 +1299,55 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cancelToolContinueAfter 仅终止当前工具调用,不停止整条 Agent 任务(对话「中断并继续」与 MCP 监控终止共用)。
|
||||||
|
func (h *AgentHandler) cancelToolContinueAfter(conversationID, preferredExecID, note string) (bool, gin.H) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" || h.tasks.GetTask(conversationID) == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
note = strings.TrimSpace(note)
|
||||||
|
execID := strings.TrimSpace(preferredExecID)
|
||||||
|
if execID == "" {
|
||||||
|
execID = h.tasks.ActiveMCPExecutionID(conversationID)
|
||||||
|
}
|
||||||
|
if execID != "" {
|
||||||
|
if h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||||
|
return true, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"executionId": execID,
|
||||||
|
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
|
||||||
|
return true, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"executionId": execID,
|
||||||
|
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if h.tasks.AbortActiveEinoExecute(conversationID, note) {
|
||||||
|
return true, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CancelAgentLoop 取消正在执行的任务
|
// CancelAgentLoop 取消正在执行的任务
|
||||||
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
@@ -1313,42 +1366,20 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
|
||||||
note := strings.TrimSpace(req.Reason)
|
note := strings.TrimSpace(req.Reason)
|
||||||
if execID != "" {
|
activeExec := strings.TrimSpace(h.tasks.ActiveMCPExecutionID(req.ConversationID))
|
||||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
if ok, payload := h.cancelToolContinueAfter(req.ConversationID, "", note); ok {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
execID, _ := payload["executionId"].(string)
|
||||||
return
|
h.logger.Info("对话页仅终止当前工具",
|
||||||
}
|
|
||||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
|
||||||
zap.String("conversationId", req.ConversationID),
|
zap.String("conversationId", req.ConversationID),
|
||||||
zap.String("executionId", execID),
|
zap.String("executionId", execID),
|
||||||
zap.Bool("hasNote", note != ""),
|
zap.Bool("hasNote", note != ""),
|
||||||
)
|
)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, payload)
|
||||||
"status": "tool_abort_requested",
|
|
||||||
"conversationId": req.ConversationID,
|
|
||||||
"executionId": execID,
|
|
||||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
|
||||||
"continueAfter": true,
|
|
||||||
"interruptWithNote": note != "",
|
|
||||||
"continueWithoutTool": false,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) {
|
if activeExec != "" {
|
||||||
h.logger.Info("对话页仅终止当前 Eino execute",
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||||
zap.String("conversationId", req.ConversationID),
|
|
||||||
zap.Bool("hasNote", note != ""),
|
|
||||||
)
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"status": "tool_abort_requested",
|
|
||||||
"conversationId": req.ConversationID,
|
|
||||||
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
|
||||||
"continueAfter": true,
|
|
||||||
"interruptWithNote": note != "",
|
|
||||||
"continueWithoutTool": false,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||||
@@ -1470,6 +1501,7 @@ type BatchTaskRequest struct {
|
|||||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||||
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
||||||
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
||||||
|
Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8
|
||||||
}
|
}
|
||||||
|
|
||||||
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。
|
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。
|
||||||
@@ -1529,7 +1561,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
|||||||
nextRunAt = &next
|
nextRunAt = &next
|
||||||
}
|
}
|
||||||
|
|
||||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
|
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks)
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
||||||
return
|
return
|
||||||
@@ -1722,12 +1754,13 @@ func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
|
|||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
AgentMode string `json:"agentMode"`
|
AgentMode string `json:"agentMode"`
|
||||||
|
Concurrency *int `json:"concurrency"`
|
||||||
}
|
}
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil {
|
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1802,9 +1835,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
|
|||||||
// DeleteBatchQueue 删除批量任务队列
|
// DeleteBatchQueue 删除批量任务队列
|
||||||
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||||
queueID := c.Param("queueId")
|
queueID := c.Param("queueId")
|
||||||
success := h.batchTaskManager.DeleteQueue(queueID)
|
if err := h.batchTaskManager.DeleteQueue(queueID); err != nil {
|
||||||
if !success {
|
switch {
|
||||||
|
case errors.Is(err, ErrBatchQueueNotFound):
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||||
|
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||||
|
c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"})
|
||||||
|
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||||
|
c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"})
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.audit != nil {
|
if h.audit != nil {
|
||||||
@@ -1898,7 +1939,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
|
|||||||
|
|
||||||
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
||||||
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
||||||
h.forceUnmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.ForceUnmarkQueueExecutor(queueID)
|
||||||
}
|
}
|
||||||
|
|
||||||
autoStarted := true
|
autoStarted := true
|
||||||
@@ -1957,26 +1998,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
|
|
||||||
h.batchRunnerMu.Lock()
|
|
||||||
defer h.batchRunnerMu.Unlock()
|
|
||||||
if _, exists := h.batchRunning[queueID]; exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
h.batchRunning[queueID] = struct{}{}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
|
|
||||||
h.batchRunnerMu.Lock()
|
|
||||||
defer h.batchRunnerMu.Unlock()
|
|
||||||
delete(h.batchRunning, queueID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
|
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
||||||
expr := strings.TrimSpace(cronExpr)
|
expr := strings.TrimSpace(cronExpr)
|
||||||
if expr == "" {
|
if expr == "" {
|
||||||
@@ -1992,43 +2013,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
|
|||||||
|
|
||||||
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
||||||
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
|
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
|
||||||
if !h.markBatchQueueRunning(queueID) {
|
if !h.batchTaskManager.TryMarkQueueExecutor(queueID) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
if !exists {
|
if !exists {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if scheduled {
|
if scheduled {
|
||||||
if queue.ScheduleMode != "cron" {
|
if queue.ScheduleMode != "cron" {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("队列未启用 cron 调度")
|
err := fmt.Errorf("队列未启用 cron 调度")
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
|
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("当前队列状态不允许被调度执行")
|
err := fmt.Errorf("当前队列状态不允许被调度执行")
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
if !h.batchTaskManager.ResetQueueForRerun(queueID) {
|
if !h.batchTaskManager.ResetQueueForRerun(queueID) {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("重置队列失败")
|
err := fmt.Errorf("重置队列失败")
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
} else if queue.Status != "pending" && queue.Status != "paused" {
|
} else if queue.Status != "pending" && queue.Status != "paused" {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
return true, fmt.Errorf("队列状态不允许启动")
|
return true, fmt.Errorf("队列状态不允许启动")
|
||||||
}
|
}
|
||||||
|
|
||||||
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
|
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
|
||||||
h.unmarkBatchQueueRunning(queueID)
|
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
|
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
|
||||||
if scheduled {
|
if scheduled {
|
||||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||||
@@ -2080,325 +2101,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeBatchQueue 执行批量任务队列
|
|
||||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|
||||||
defer h.unmarkBatchQueueRunning(queueID)
|
|
||||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
|
|
||||||
|
|
||||||
for {
|
|
||||||
// 检查队列状态
|
|
||||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
|
||||||
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取下一个任务
|
|
||||||
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
|
|
||||||
if !hasNext {
|
|
||||||
// 所有任务完成:汇总子任务失败信息便于排障
|
|
||||||
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
|
|
||||||
lastRunErr := ""
|
|
||||||
if ok {
|
|
||||||
for _, t := range q.Tasks {
|
|
||||||
if t.Status == "failed" && t.Error != "" {
|
|
||||||
lastRunErr = t.Error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
|
||||||
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
|
|
||||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新任务状态为运行中
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
|
|
||||||
|
|
||||||
// 创建新对话
|
|
||||||
title := safeTruncateString(task.Message, 50)
|
|
||||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
|
||||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
|
||||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
|
||||||
var conversationID string
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
|
||||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
|
||||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
conversationID = conv.ID
|
|
||||||
|
|
||||||
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
|
|
||||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
|
|
||||||
|
|
||||||
// 应用角色用户提示词和工具配置
|
|
||||||
finalMessage := task.Message
|
|
||||||
var roleTools []string // 角色配置的工具列表
|
|
||||||
if queue.Role != "" && queue.Role != "默认" {
|
|
||||||
if h.config.Roles != nil {
|
|
||||||
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
|
||||||
// 应用用户提示词
|
|
||||||
if role.UserPrompt != "" {
|
|
||||||
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
|
||||||
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
|
||||||
}
|
|
||||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
|
||||||
if len(role.Tools) > 0 {
|
|
||||||
roleTools = role.Tools
|
|
||||||
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
|
||||||
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 预先创建助手消息,以便关联过程详情
|
|
||||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
// 如果创建失败,继续执行但不保存过程详情
|
|
||||||
assistantMsg = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
|
|
||||||
var assistantMessageID string
|
|
||||||
if assistantMsg != nil {
|
|
||||||
assistantMessageID = assistantMsg.ID
|
|
||||||
}
|
|
||||||
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
|
|
||||||
// 需要把进度事件镜像到 TaskEventBus(GET /api/agent-loop/task-events 会订阅这里)。
|
|
||||||
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
|
|
||||||
var progressCallback func(eventType, message string, data interface{})
|
|
||||||
|
|
||||||
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
|
|
||||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
|
||||||
|
|
||||||
func() {
|
|
||||||
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
|
|
||||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
|
||||||
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
|
|
||||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
|
||||||
|
|
||||||
registered := false
|
|
||||||
finishStatus := "completed"
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
|
||||||
timeoutCancel()
|
|
||||||
if registered {
|
|
||||||
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
|
|
||||||
if h.taskEventBus != nil {
|
|
||||||
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
|
||||||
if b, err := json.Marshal(ev); err == nil {
|
|
||||||
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.tasks.FinishTask(conversationID, finishStatus)
|
|
||||||
}
|
|
||||||
cancelWithCause(nil)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
|
|
||||||
sendEvent := func(eventType, message string, data interface{}) {
|
|
||||||
if h.taskEventBus == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
|
||||||
b, err := json.Marshal(ev)
|
|
||||||
if err != nil {
|
|
||||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
|
||||||
}
|
|
||||||
line := make([]byte, 0, len(b)+8)
|
|
||||||
line = append(line, []byte("data: ")...)
|
|
||||||
line = append(line, b...)
|
|
||||||
line = append(line, '\n', '\n')
|
|
||||||
h.taskEventBus.Publish(conversationID, line)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
|
||||||
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
|
||||||
zap.String("queueId", queueID),
|
|
||||||
zap.String("taskId", task.ID),
|
|
||||||
zap.String("conversationId", conversationID),
|
|
||||||
zap.Error(err))
|
|
||||||
failMsg := err.Error()
|
|
||||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
|
||||||
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
|
||||||
}
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
registered = true
|
|
||||||
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
|
|
||||||
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
|
|
||||||
|
|
||||||
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
|
|
||||||
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
|
||||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
|
||||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
|
||||||
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
|
||||||
|
|
||||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
|
||||||
useBatchMulti := false
|
|
||||||
batchOrch := "deep"
|
|
||||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
|
||||||
if am == "multi" {
|
|
||||||
am = "deep"
|
|
||||||
}
|
|
||||||
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
|
||||||
useBatchMulti = true
|
|
||||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
|
||||||
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
|
||||||
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
|
|
||||||
useBatchMulti = true
|
|
||||||
batchOrch = "deep"
|
|
||||||
}
|
|
||||||
var resultMA *multiagent.RunResult
|
|
||||||
var runErr error
|
|
||||||
switch {
|
|
||||||
case useBatchMulti:
|
|
||||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
|
||||||
default:
|
|
||||||
if h.config == nil {
|
|
||||||
runErr = fmt.Errorf("服务器配置未加载")
|
|
||||||
} else {
|
|
||||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if runErr != nil {
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
|
||||||
}
|
|
||||||
errStr := runErr.Error()
|
|
||||||
partialResp := ""
|
|
||||||
if resultMA != nil {
|
|
||||||
partialResp = resultMA.Response
|
|
||||||
}
|
|
||||||
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
|
||||||
errors.Is(runErr, context.Canceled) ||
|
|
||||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
|
||||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
|
||||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
|
||||||
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
|
||||||
|
|
||||||
if isTimeout {
|
|
||||||
finishStatus = "timeout"
|
|
||||||
} else if isCancelled {
|
|
||||||
finishStatus = "cancelled"
|
|
||||||
} else {
|
|
||||||
finishStatus = "failed"
|
|
||||||
}
|
|
||||||
|
|
||||||
if isCancelled {
|
|
||||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
|
||||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
|
||||||
// 如果执行结果中有更具体的取消消息,使用它
|
|
||||||
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
|
||||||
cancelMsg = partialResp
|
|
||||||
}
|
|
||||||
// 更新助手消息内容
|
|
||||||
if assistantMessageID != "" {
|
|
||||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
|
||||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
|
||||||
}
|
|
||||||
// 保存取消详情到数据库
|
|
||||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
|
||||||
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 如果没有预先创建的助手消息,创建一个新的
|
|
||||||
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
|
|
||||||
if errMsg != nil {
|
|
||||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
|
||||||
} else {
|
|
||||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
|
||||||
errorMsg := "执行失败: " + runErr.Error()
|
|
||||||
// 更新助手消息内容
|
|
||||||
if assistantMessageID != "" {
|
|
||||||
if _, updateErr := h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
|
||||||
errorMsg,
|
|
||||||
time.Now(), assistantMessageID,
|
|
||||||
); updateErr != nil {
|
|
||||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
|
||||||
}
|
|
||||||
// 保存错误详情到数据库
|
|
||||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
|
||||||
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
|
||||||
|
|
||||||
resText := resultMA.Response
|
|
||||||
mcpIDs := resultMA.MCPExecutionIDs
|
|
||||||
lastIn := resultMA.LastAgentTraceInput
|
|
||||||
lastOut := resultMA.LastAgentTraceOutput
|
|
||||||
|
|
||||||
// 更新助手消息内容
|
|
||||||
if assistantMessageID != "" {
|
|
||||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
|
||||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
|
||||||
// 如果更新失败,尝试创建新消息
|
|
||||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 如果没有预先创建的助手消息,创建一个新的
|
|
||||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存代理轨迹
|
|
||||||
if lastIn != "" || lastOut != "" {
|
|
||||||
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
|
||||||
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
|
||||||
} else {
|
|
||||||
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存结果
|
|
||||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 移动到下一个任务
|
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
|
||||||
|
|
||||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
|
||||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
|
||||||
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查是否被取消或暂停
|
|
||||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
|
||||||
if queue.Status == "cancelled" || queue.Status == "paused" {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
||||||
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
||||||
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
||||||
|
|||||||
@@ -0,0 +1,352 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// batchQueueWorkerIdlePoll is how long a batch-queue worker sleeps before
// retrying when it could not claim a pending task while other tasks of the
// same queue are still reported as running.
const batchQueueWorkerIdlePoll = 200 * time.Millisecond
|
||||||
|
|
||||||
|
// executeBatchQueue 使用并发 worker 池执行批量任务队列。
|
||||||
|
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||||
|
defer h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||||
|
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
concurrency := normalizeBatchQueueConcurrency(queue.Concurrency)
|
||||||
|
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", concurrency))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < concurrency; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
h.runBatchQueueWorker(queueID)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
h.tryFinalizeBatchQueue(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) runBatchQueueWorker(queueID string) {
|
||||||
|
for {
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if batchQueueExecutionShouldStop(queue, exists) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
task, ok := h.batchTaskManager.ClaimNextPendingTask(queueID)
|
||||||
|
if !ok {
|
||||||
|
if !h.batchTaskManager.HasRunningTasks(queueID) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(batchQueueWorkerIdlePoll)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusRunning, "", "")
|
||||||
|
h.executeOneBatchSubTask(queueID, queue, task)
|
||||||
|
|
||||||
|
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused)
|
||||||
|
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
queue, exists = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if batchQueueExecutionShouldStop(queue, exists) {
|
||||||
|
if !exists {
|
||||||
|
h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) {
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if queue.Status != BatchQueueStatusRunning {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.batchTaskManager.HasPendingOrRunningTasks(queueID) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastRunErr := ""
|
||||||
|
for _, t := range queue.Tasks {
|
||||||
|
if t != nil && t.Status == BatchTaskStatusFailed && t.Error != "" {
|
||||||
|
lastRunErr = t.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted)
|
||||||
|
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeOneBatchSubTask 执行单条批量子任务(各自独立会话)。
|
||||||
|
func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) {
|
||||||
|
title := safeTruncateString(task.Message, 50)
|
||||||
|
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||||
|
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||||
|
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conversationID := conv.ID
|
||||||
|
|
||||||
|
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID)
|
||||||
|
|
||||||
|
finalMessage := task.Message
|
||||||
|
var roleTools []string
|
||||||
|
if queue.Role != "" && queue.Role != "默认" {
|
||||||
|
if h.config.Roles != nil {
|
||||||
|
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
||||||
|
if role.UserPrompt != "" {
|
||||||
|
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
||||||
|
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
||||||
|
}
|
||||||
|
if len(role.Tools) > 0 {
|
||||||
|
roleTools = role.Tools
|
||||||
|
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil {
|
||||||
|
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
assistantMsg = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var assistantMessageID string
|
||||||
|
if assistantMsg != nil {
|
||||||
|
assistantMessageID = assistantMsg.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||||
|
|
||||||
|
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||||
|
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
||||||
|
|
||||||
|
registered := false
|
||||||
|
finishStatus := "completed"
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil)
|
||||||
|
timeoutCancel()
|
||||||
|
if registered {
|
||||||
|
if h.taskEventBus != nil {
|
||||||
|
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
||||||
|
if b, err := json.Marshal(ev); err == nil {
|
||||||
|
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.tasks.FinishTask(conversationID, finishStatus)
|
||||||
|
}
|
||||||
|
cancelWithCause(nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sendEvent := func(eventType, message string, data interface{}) {
|
||||||
|
if h.taskEventBus == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||||
|
b, err := json.Marshal(ev)
|
||||||
|
if err != nil {
|
||||||
|
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||||
|
}
|
||||||
|
line := make([]byte, 0, len(b)+8)
|
||||||
|
line = append(line, []byte("data: ")...)
|
||||||
|
line = append(line, b...)
|
||||||
|
line = append(line, '\n', '\n')
|
||||||
|
h.taskEventBus.Publish(conversationID, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
||||||
|
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
||||||
|
zap.String("queueId", queueID),
|
||||||
|
zap.String("taskId", task.ID),
|
||||||
|
zap.String("conversationId", conversationID),
|
||||||
|
zap.Error(err))
|
||||||
|
failMsg := err.Error()
|
||||||
|
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||||
|
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
||||||
|
}
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registered = true
|
||||||
|
h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel)
|
||||||
|
|
||||||
|
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
|
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
|
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||||
|
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
||||||
|
|
||||||
|
useBatchMulti := false
|
||||||
|
batchOrch := "deep"
|
||||||
|
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||||
|
if am == "multi" {
|
||||||
|
am = "deep"
|
||||||
|
}
|
||||||
|
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||||
|
useBatchMulti = true
|
||||||
|
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||||
|
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||||
|
useBatchMulti = true
|
||||||
|
batchOrch = "deep"
|
||||||
|
}
|
||||||
|
|
||||||
|
var resultMA *multiagent.RunResult
|
||||||
|
var runErr error
|
||||||
|
switch {
|
||||||
|
case useBatchMulti:
|
||||||
|
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||||
|
default:
|
||||||
|
if h.config == nil {
|
||||||
|
runErr = fmt.Errorf("服务器配置未加载")
|
||||||
|
} else {
|
||||||
|
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if runErr != nil {
|
||||||
|
h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultMA == nil {
|
||||||
|
h.logger.Error("批量任务执行成功但无结果对象",
|
||||||
|
zap.String("queueId", queueID),
|
||||||
|
zap.String("taskId", task.ID),
|
||||||
|
zap.String("conversationId", conversationID))
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||||
|
|
||||||
|
resText := resultMA.Response
|
||||||
|
mcpIDs := resultMA.MCPExecutionIDs
|
||||||
|
lastIn := resultMA.LastAgentTraceInput
|
||||||
|
lastOut := resultMA.LastAgentTraceOutput
|
||||||
|
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||||
|
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||||
|
if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||||
|
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||||
|
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastIn != "" || lastOut != "" {
|
||||||
|
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||||
|
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) handleBatchSubTaskRunError(
|
||||||
|
queueID string,
|
||||||
|
task *BatchTask,
|
||||||
|
conversationID, assistantMessageID string,
|
||||||
|
baseCtx, taskCtx context.Context,
|
||||||
|
resultMA *multiagent.RunResult,
|
||||||
|
runErr error,
|
||||||
|
finishStatus *string,
|
||||||
|
) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||||
|
}
|
||||||
|
errStr := runErr.Error()
|
||||||
|
partialResp := ""
|
||||||
|
if resultMA != nil {
|
||||||
|
partialResp = resultMA.Response
|
||||||
|
}
|
||||||
|
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
||||||
|
errors.Is(runErr, context.Canceled) ||
|
||||||
|
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||||
|
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||||
|
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||||
|
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
||||||
|
|
||||||
|
if isTimeout {
|
||||||
|
*finishStatus = "timeout"
|
||||||
|
} else if isCancelled {
|
||||||
|
*finishStatus = "cancelled"
|
||||||
|
} else {
|
||||||
|
*finishStatus = "failed"
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCancelled {
|
||||||
|
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||||
|
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||||
|
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
||||||
|
cancelMsg = partialResp
|
||||||
|
}
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||||
|
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||||
|
}
|
||||||
|
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
||||||
|
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
} else if _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); errMsg != nil {
|
||||||
|
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||||
|
}
|
||||||
|
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
||||||
|
errorMsg := "执行失败: " + runErr.Error()
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
if _, updateErr := h.db.Exec(
|
||||||
|
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||||
|
errorMsg,
|
||||||
|
time.Now(), assistantMessageID,
|
||||||
|
); updateErr != nil {
|
||||||
|
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||||
|
}
|
||||||
|
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
||||||
|
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error())
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,6 +18,15 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Sentinel errors for batch queue deletion; compare with errors.Is.
var (
	// ErrBatchQueueNotFound means the queue does not exist or has been unloaded from memory.
	ErrBatchQueueNotFound = errors.New("batch queue not found")
	// ErrBatchQueueExecutorActive means the executeBatchQueue goroutine is still
	// winding down, so the queue must not be deleted yet.
	ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active")
	// ErrBatchQueueStillRunning means the queue status is still "running"
	// (a fallback guard for when no executor is marked active).
	ErrBatchQueueStillRunning = errors.New("batch queue is still running")
)
|
||||||
|
|
||||||
// 批量任务状态常量
|
// 批量任务状态常量
|
||||||
const (
|
const (
|
||||||
BatchQueueStatusPending = "pending"
|
BatchQueueStatusPending = "pending"
|
||||||
@@ -39,6 +49,12 @@ const (
|
|||||||
|
|
||||||
// MaxBatchQueueRoleLen 角色名最大长度
|
// MaxBatchQueueRoleLen 角色名最大长度
|
||||||
MaxBatchQueueRoleLen = 100
|
MaxBatchQueueRoleLen = 100
|
||||||
|
|
||||||
|
// DefaultBatchQueueConcurrency 批量队列默认并发数(串行)
|
||||||
|
DefaultBatchQueueConcurrency = 1
|
||||||
|
|
||||||
|
// MaxBatchQueueConcurrency 批量队列最大并发数
|
||||||
|
MaxBatchQueueConcurrency = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
// BatchTask 批量任务项
|
// BatchTask 批量任务项
|
||||||
@@ -67,6 +83,7 @@ type BatchTaskQueue struct {
|
|||||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||||
LastRunError string `json:"lastRunError,omitempty"`
|
LastRunError string `json:"lastRunError,omitempty"`
|
||||||
ProjectID string `json:"projectId,omitempty"`
|
ProjectID string `json:"projectId,omitempty"`
|
||||||
|
Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1
|
||||||
Tasks []*BatchTask `json:"tasks"`
|
Tasks []*BatchTask `json:"tasks"`
|
||||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
@@ -80,8 +97,9 @@ type BatchTaskManager struct {
|
|||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
queues map[string]*BatchTaskQueue
|
queues map[string]*BatchTaskQueue
|
||||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数
|
||||||
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
||||||
|
queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦)
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
|
|||||||
return &BatchTaskManager{
|
return &BatchTaskManager{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
queues: make(map[string]*BatchTaskQueue),
|
queues: make(map[string]*BatchTaskQueue),
|
||||||
taskCancels: make(map[string]context.CancelFunc),
|
taskCancels: make(map[string]map[string]context.CancelFunc),
|
||||||
singleRunTasks: make(map[string]string),
|
singleRunTasks: make(map[string]string),
|
||||||
|
queueExecutors: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// batchQueueExecutionShouldStop 判断 executeBatchQueue 主循环是否应退出。
|
||||||
|
func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool {
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch queue.Status {
|
||||||
|
case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryMarkQueueExecutor 标记队列执行协程已启动;若已有执行协程则返回 false。
|
||||||
|
func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if _, exists := m.queueExecutors[queueID]; exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.queueExecutors[queueID] = struct{}{}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarkQueueExecutor 清除队列执行协程标记(executeBatchQueue defer 调用)。
|
||||||
|
func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.queueExecutors, queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForceUnmarkQueueExecutor 强制清除执行协程标记(暂停态单条重跑等场景回收陈旧槽位)。
|
||||||
|
func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) {
|
||||||
|
m.UnmarkQueueExecutor(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsQueueExecutorActive 队列 executeBatchQueue 协程是否仍在运行。
|
||||||
|
func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
_, ok := m.queueExecutors[queueID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// SetDB 设置数据库连接
|
// SetDB 设置数据库连接
|
||||||
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
|||||||
m.db = db
|
m.db = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeBatchQueueConcurrency 规范化队列并发数。
|
||||||
|
func normalizeBatchQueueConcurrency(n int) int {
|
||||||
|
if n < 1 {
|
||||||
|
return DefaultBatchQueueConcurrency
|
||||||
|
}
|
||||||
|
if n > MaxBatchQueueConcurrency {
|
||||||
|
return MaxBatchQueueConcurrency
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
// CreateBatchQueue 创建批量任务队列
|
// CreateBatchQueue 创建批量任务队列
|
||||||
func (m *BatchTaskManager) CreateBatchQueue(
|
func (m *BatchTaskManager) CreateBatchQueue(
|
||||||
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||||
nextRunAt *time.Time,
|
nextRunAt *time.Time,
|
||||||
|
concurrency int,
|
||||||
tasks []string,
|
tasks []string,
|
||||||
) (*BatchTaskQueue, error) {
|
) (*BatchTaskQueue, error) {
|
||||||
// 输入校验
|
// 输入校验
|
||||||
@@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
|||||||
CronExpr: strings.TrimSpace(cronExpr),
|
CronExpr: strings.TrimSpace(cronExpr),
|
||||||
NextRunAt: nextRunAt,
|
NextRunAt: nextRunAt,
|
||||||
ScheduleEnabled: true,
|
ScheduleEnabled: true,
|
||||||
|
Concurrency: normalizeBatchQueueConcurrency(concurrency),
|
||||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||||
Status: BatchQueueStatusPending,
|
Status: BatchQueueStatusPending,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
@@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
|||||||
queue.CronExpr,
|
queue.CronExpr,
|
||||||
queue.NextRunAt,
|
queue.NextRunAt,
|
||||||
queue.ProjectID,
|
queue.ProjectID,
|
||||||
|
queue.Concurrency,
|
||||||
dbTasks,
|
dbTasks,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||||
@@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
|||||||
if queueRow.ProjectID.Valid {
|
if queueRow.ProjectID.Valid {
|
||||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||||
}
|
}
|
||||||
|
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||||
if queueRow.StartedAt.Valid {
|
if queueRow.StartedAt.Valid {
|
||||||
queue.StartedAt = &queueRow.StartedAt.Time
|
queue.StartedAt = &queueRow.StartedAt.Time
|
||||||
}
|
}
|
||||||
@@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
|||||||
if queueRow.ProjectID.Valid {
|
if queueRow.ProjectID.Valid {
|
||||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||||
}
|
}
|
||||||
|
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||||
if queueRow.StartedAt.Valid {
|
if queueRow.StartedAt.Valid {
|
||||||
queue.StartedAt = &queueRow.StartedAt.Time
|
queue.StartedAt = &queueRow.StartedAt.Time
|
||||||
}
|
}
|
||||||
@@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用)
|
// batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。
|
||||||
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error {
|
func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int {
|
||||||
|
if row == nil || !row.Concurrency.Valid {
|
||||||
|
return DefaultBatchQueueConcurrency
|
||||||
|
}
|
||||||
|
return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用)
|
||||||
|
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error {
|
||||||
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
|
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
|
||||||
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
|
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
|
||||||
}
|
}
|
||||||
@@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
|
|||||||
queue.Title = title
|
queue.Title = title
|
||||||
queue.Role = role
|
queue.Role = role
|
||||||
queue.AgentMode = agentMode
|
queue.AgentMode = agentMode
|
||||||
|
if concurrency != nil {
|
||||||
|
queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency)
|
||||||
|
}
|
||||||
|
|
||||||
if m.db != nil {
|
if m.db != nil {
|
||||||
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil {
|
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil {
|
||||||
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
|
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
|
|||||||
|
|
||||||
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
||||||
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||||
var cancelFunc context.CancelFunc
|
|
||||||
var siblingRunningIDs []string
|
var siblingRunningIDs []string
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
||||||
|
var cancelFuncs []context.CancelFunc
|
||||||
if queue.Status == BatchQueueStatusPaused {
|
if queue.Status == BatchQueueStatusPaused {
|
||||||
if c, ok := m.taskCancels[queueID]; ok {
|
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||||
cancelFunc = c
|
|
||||||
delete(m.taskCancels, queueID)
|
|
||||||
}
|
|
||||||
for _, t := range queue.Tasks {
|
for _, t := range queue.Tasks {
|
||||||
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
||||||
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
||||||
@@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
|||||||
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
if cancelFunc != nil {
|
for _, c := range cancelFuncs {
|
||||||
cancelFunc()
|
if c != nil {
|
||||||
|
c()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const staleRunMsg = "为单条执行其它任务,已中止"
|
const staleRunMsg = "为单条执行其它任务,已中止"
|
||||||
for _, sid := range siblingRunningIDs {
|
for _, sid := range siblingRunningIDs {
|
||||||
@@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNextTask 获取下一个待执行的任务
|
// ClaimNextPendingTask 原子领取下一个待执行子任务(并发 worker 安全)。
|
||||||
|
func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if queue.Status == BatchQueueStatusCancelled || queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusPaused {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
onlyTaskID := ""
|
||||||
|
if m.singleRunTasks != nil {
|
||||||
|
onlyTaskID = m.singleRunTasks[queueID]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, task := range queue.Tasks {
|
||||||
|
if task == nil || task.Status != BatchTaskStatusPending {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if onlyTaskID != "" && task.ID != onlyTaskID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
task.Status = BatchTaskStatusRunning
|
||||||
|
queue.CurrentIndex = i
|
||||||
|
return task, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRunningTasks 队列是否仍有 running 状态的子任务。
|
||||||
|
func (m *BatchTaskManager) HasRunningTasks(queueID string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, task := range queue.Tasks {
|
||||||
|
if task != nil && task.Status == BatchTaskStatusRunning {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasPendingOrRunningTasks 队列是否仍有未完成的子任务。
|
||||||
|
func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists || queue == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, task := range queue.Tasks {
|
||||||
|
if task == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if task.Status == BatchTaskStatusPending || task.Status == BatchTaskStatusRunning {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainTaskCancelsLocked 取出并清空队列下所有子任务取消函数(调用方须已持 m.mu)。
|
||||||
|
func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc {
|
||||||
|
taskMap, ok := m.taskCancels[queueID]
|
||||||
|
if !ok || len(taskMap) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cancels := make([]context.CancelFunc, 0, len(taskMap))
|
||||||
|
for _, c := range taskMap {
|
||||||
|
if c != nil {
|
||||||
|
cancels = append(cancels, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(m.taskCancels, queueID)
|
||||||
|
return cancels
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask)
|
||||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTaskCancel 设置当前任务的取消函数
|
// SetTaskCancel 设置子任务的取消函数
|
||||||
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) {
|
func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
if cancel != nil {
|
if cancel == nil {
|
||||||
m.taskCancels[queueID] = cancel
|
if taskMap, ok := m.taskCancels[queueID]; ok {
|
||||||
} else {
|
delete(taskMap, taskID)
|
||||||
|
if len(taskMap) == 0 {
|
||||||
delete(m.taskCancels, queueID)
|
delete(m.taskCancels, queueID)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.taskCancels[queueID] == nil {
|
||||||
|
m.taskCancels[queueID] = make(map[string]context.CancelFunc)
|
||||||
|
}
|
||||||
|
m.taskCancels[queueID][taskID] = cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
// PauseQueue 暂停队列
|
// PauseQueue 暂停队列
|
||||||
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||||
var cancelFunc context.CancelFunc
|
var cancelFuncs []context.CancelFunc
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
queue, exists := m.queues[queueID]
|
queue, exists := m.queues[queueID]
|
||||||
@@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
queue.Status = BatchQueueStatusPaused
|
queue.Status = BatchQueueStatusPaused
|
||||||
|
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||||
// 取消当前正在执行的任务(通过取消context)
|
|
||||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
|
||||||
cancelFunc = cancel
|
|
||||||
delete(m.taskCancels, queueID)
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
for _, c := range cancelFuncs {
|
||||||
if cancelFunc != nil {
|
c()
|
||||||
cancelFunc()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
@@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
|||||||
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
||||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
var cancelFunc context.CancelFunc
|
var cancelFuncs []context.CancelFunc
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
queue, exists := m.queues[queueID]
|
queue, exists := m.queues[queueID]
|
||||||
@@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 取消当前正在执行的任务
|
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
|
||||||
cancelFunc = cancel
|
|
||||||
delete(m.taskCancels, queueID)
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
for _, c := range cancelFuncs {
|
||||||
if cancelFunc != nil {
|
c()
|
||||||
cancelFunc()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteQueue 删除队列(运行中的队列不允许删除)
|
// DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。
|
||||||
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
func (m *BatchTaskManager) DeleteQueue(queueID string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
queue, exists := m.queues[queueID]
|
queue, exists := m.queues[queueID]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return ErrBatchQueueNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exec := m.queueExecutors[queueID]; exec {
|
||||||
|
return ErrBatchQueueExecutorActive
|
||||||
}
|
}
|
||||||
|
|
||||||
// 运行中的队列不允许删除,防止孤儿协程和数据丢失
|
// 运行中的队列不允许删除,防止孤儿协程和数据丢失
|
||||||
if queue.Status == BatchQueueStatusRunning {
|
if queue.Status == BatchQueueStatusRunning {
|
||||||
return false
|
return ErrBatchQueueStillRunning
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理取消函数
|
// 清理取消函数
|
||||||
@@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
delete(m.queues, queueID)
|
delete(m.queues, queueID)
|
||||||
return true
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateShortID 生成短ID
|
// generateShortID 生成短ID
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeBatchQueueConcurrency(t *testing.T) {
|
||||||
|
if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency {
|
||||||
|
t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got)
|
||||||
|
}
|
||||||
|
if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency {
|
||||||
|
t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaimNextPendingTaskParallel(t *testing.T) {
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateBatchQueue: %v", err)
|
||||||
|
}
|
||||||
|
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||||
|
|
||||||
|
t1, ok1 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
t2, ok2 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
if !ok1 || !ok2 || t1.ID == t2.ID {
|
||||||
|
t.Fatalf("expected two distinct claims, got ok1=%v ok2=%v t1=%v t2=%v", ok1, ok2, t1, t2)
|
||||||
|
}
|
||||||
|
if t1.Status != BatchTaskStatusRunning || t2.Status != BatchTaskStatusRunning {
|
||||||
|
t.Fatalf("claimed tasks should be running")
|
||||||
|
}
|
||||||
|
t3, ok3 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
if !ok3 {
|
||||||
|
t.Fatal("expected third claim")
|
||||||
|
}
|
||||||
|
_, ok4 := m.ClaimNextPendingTask(queue.ID)
|
||||||
|
if ok4 {
|
||||||
|
t.Fatal("expected no fourth pending task")
|
||||||
|
}
|
||||||
|
_ = t3
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchQueueExecutionShouldStop(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !batchQueueExecutionShouldStop(nil, false) {
|
||||||
|
t.Fatal("expected stop when queue missing")
|
||||||
|
}
|
||||||
|
if !batchQueueExecutionShouldStop(nil, true) {
|
||||||
|
t.Fatal("expected stop when queue is nil but exists=true")
|
||||||
|
}
|
||||||
|
q := &BatchTaskQueue{Status: BatchQueueStatusRunning}
|
||||||
|
if batchQueueExecutionShouldStop(q, true) {
|
||||||
|
t.Fatal("expected continue when running")
|
||||||
|
}
|
||||||
|
q.Status = BatchQueueStatusCancelled
|
||||||
|
if !batchQueueExecutionShouldStop(q, true) {
|
||||||
|
t.Fatal("expected stop when cancelled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateBatchQueue: %v", err)
|
||||||
|
}
|
||||||
|
if !m.TryMarkQueueExecutor(queue.ID) {
|
||||||
|
t.Fatal("expected to mark executor")
|
||||||
|
}
|
||||||
|
m.UpdateQueueStatus(queue.ID, BatchQueueStatusCancelled)
|
||||||
|
|
||||||
|
err = m.DeleteQueue(queue.ID)
|
||||||
|
if !errors.Is(err, ErrBatchQueueExecutorActive) {
|
||||||
|
t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := m.GetBatchQueue(queue.ID); !ok {
|
||||||
|
t.Fatal("queue should still exist while executor active")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.UnmarkQueueExecutor(queue.ID)
|
||||||
|
if err := m.DeleteQueue(queue.ID); err != nil {
|
||||||
|
t.Fatalf("expected delete after executor unmarked, got %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := m.GetBatchQueue(queue.ID); ok {
|
||||||
|
t.Fatal("queue should be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteQueueBlockedWhileRunning(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateBatchQueue: %v", err)
|
||||||
|
}
|
||||||
|
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||||
|
|
||||||
|
err = m.DeleteQueue(queue.ID)
|
||||||
|
if !errors.Is(err, ErrBatchQueueStillRunning) {
|
||||||
|
t.Fatalf("expected ErrBatchQueueStillRunning, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryMarkQueueExecutorDedupes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
m := NewBatchTaskManager(zap.NewNop())
|
||||||
|
if !m.TryMarkQueueExecutor("q-1") {
|
||||||
|
t.Fatal("first mark should succeed")
|
||||||
|
}
|
||||||
|
if m.TryMarkQueueExecutor("q-1") {
|
||||||
|
t.Fatal("second mark should fail")
|
||||||
|
}
|
||||||
|
m.UnmarkQueueExecutor("q-1")
|
||||||
|
if !m.TryMarkQueueExecutor("q-1") {
|
||||||
|
t.Fatal("mark after unmark should succeed")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||||
},
|
},
|
||||||
|
"concurrency": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
@@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
executeNow = false
|
executeNow = false
|
||||||
}
|
}
|
||||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
concurrency := int(mcpArgFloat(args, "concurrency"))
|
||||||
|
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks)
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||||
}
|
}
|
||||||
@@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
if qid == "" {
|
if qid == "" {
|
||||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||||
}
|
}
|
||||||
if !h.batchTaskManager.DeleteQueue(qid) {
|
if err := h.batchTaskManager.DeleteQueue(qid); err != nil {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, ErrBatchQueueNotFound):
|
||||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||||
|
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||||
|
return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil
|
||||||
|
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||||
|
return batchMCPTextResult("删除失败:队列正在运行中", true), nil
|
||||||
|
default:
|
||||||
|
return batchMCPTextResult("删除失败:"+err.Error(), true), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
||||||
return batchMCPTextResult("队列已删除。", false), nil
|
return batchMCPTextResult("队列已删除。", false), nil
|
||||||
@@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
||||||
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||||
},
|
},
|
||||||
|
"concurrency": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "同时执行的子任务数,默认 1,最大 8",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": []string{"queue_id"},
|
"required": []string{"queue_id"},
|
||||||
},
|
},
|
||||||
@@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
|||||||
title := mcpArgString(args, "title")
|
title := mcpArgString(args, "title")
|
||||||
role := mcpArgString(args, "role")
|
role := mcpArgString(args, "role")
|
||||||
agentMode := mcpArgString(args, "agent_mode")
|
agentMode := mcpArgString(args, "agent_mode")
|
||||||
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil {
|
var concurrency *int
|
||||||
|
if raw, ok := args["concurrency"]; ok && raw != nil {
|
||||||
|
v := int(mcpArgFloat(args, "concurrency"))
|
||||||
|
concurrency = &v
|
||||||
|
}
|
||||||
|
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil {
|
||||||
return batchMCPTextResult(err.Error(), true), nil
|
return batchMCPTextResult(err.Error(), true), nil
|
||||||
}
|
}
|
||||||
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||||
@@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct {
|
|||||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||||
CurrentIndex int `json:"currentIndex"`
|
CurrentIndex int `json:"currentIndex"`
|
||||||
|
Concurrency int `json:"concurrency"`
|
||||||
TaskTotal int `json:"task_total"`
|
TaskTotal int `json:"task_total"`
|
||||||
TaskCounts map[string]int `json:"task_counts"`
|
TaskCounts map[string]int `json:"task_counts"`
|
||||||
Tasks []batchTaskMCPListSummary `json:"tasks"`
|
Tasks []batchTaskMCPListSummary `json:"tasks"`
|
||||||
@@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
|
|||||||
StartedAt: q.StartedAt,
|
StartedAt: q.StartedAt,
|
||||||
CompletedAt: q.CompletedAt,
|
CompletedAt: q.CompletedAt,
|
||||||
CurrentIndex: q.CurrentIndex,
|
CurrentIndex: q.CurrentIndex,
|
||||||
|
Concurrency: q.Concurrency,
|
||||||
TaskTotal: len(tasks),
|
TaskTotal: len(tasks),
|
||||||
TaskCounts: counts,
|
TaskCounts: counts,
|
||||||
Tasks: tasks,
|
Tasks: tasks,
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import (
|
|||||||
type MonitorHandler struct {
|
type MonitorHandler struct {
|
||||||
mcpServer *mcp.Server
|
mcpServer *mcp.Server
|
||||||
externalMCPMgr *mcp.ExternalMCPManager
|
externalMCPMgr *mcp.ExternalMCPManager
|
||||||
|
taskManager *AgentTaskManager
|
||||||
|
agentHandler *AgentHandler
|
||||||
executor *security.Executor
|
executor *security.Executor
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
@@ -56,6 +58,16 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
|||||||
h.externalMCPMgr = mgr
|
h.externalMCPMgr = mgr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTaskManager installs the Agent task manager (used to terminate Eino execute and similar runs by executionId).
|
||||||
|
func (h *MonitorHandler) SetTaskManager(mgr *AgentTaskManager) {
|
||||||
|
h.taskManager = mgr
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAgentHandler installs the Agent handler (so MCP monitor termination shares its logic with the conversation page's "interrupt and continue").
|
||||||
|
func (h *MonitorHandler) SetAgentHandler(ah *AgentHandler) {
|
||||||
|
h.agentHandler = ah
|
||||||
|
}
|
||||||
|
|
||||||
// MonitorResponse 监控响应
|
// MonitorResponse 监控响应
|
||||||
type MonitorResponse struct {
|
type MonitorResponse struct {
|
||||||
Executions []*mcp.ToolExecution `json:"executions"`
|
Executions []*mcp.ToolExecution `json:"executions"`
|
||||||
@@ -90,6 +102,7 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
|||||||
toolName := normalizeToolNameFilter(c.Query("tool"))
|
toolName := normalizeToolNameFilter(c.Query("tool"))
|
||||||
|
|
||||||
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
||||||
|
h.enrichExecutionsConversationID(executions)
|
||||||
stats := h.loadStats()
|
stats := h.loadStats()
|
||||||
|
|
||||||
totalPages := (total + pageSize - 1) / pageSize
|
totalPages := (total + pageSize - 1) / pageSize
|
||||||
@@ -247,6 +260,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
|||||||
// 先从内部MCP服务器查找
|
// 先从内部MCP服务器查找
|
||||||
exec, exists := h.mcpServer.GetExecution(id)
|
exec, exists := h.mcpServer.GetExecution(id)
|
||||||
if exists {
|
if exists {
|
||||||
|
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||||
c.JSON(http.StatusOK, exec)
|
c.JSON(http.StatusOK, exec)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -255,6 +269,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
|||||||
if h.externalMCPMgr != nil {
|
if h.externalMCPMgr != nil {
|
||||||
exec, exists = h.externalMCPMgr.GetExecution(id)
|
exec, exists = h.externalMCPMgr.GetExecution(id)
|
||||||
if exists {
|
if exists {
|
||||||
|
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||||
c.JSON(http.StatusOK, exec)
|
c.JSON(http.StatusOK, exec)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -264,6 +279,7 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
|||||||
if h.db != nil {
|
if h.db != nil {
|
||||||
exec, err := h.db.GetToolExecution(id)
|
exec, err := h.db.GetToolExecution(id)
|
||||||
if err == nil && exec != nil {
|
if err == nil && exec != nil {
|
||||||
|
h.enrichExecutionsConversationID([]*mcp.ToolExecution{exec})
|
||||||
c.JSON(http.StatusOK, exec)
|
c.JSON(http.StatusOK, exec)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -290,6 +306,19 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
note = strings.TrimSpace(body.Note)
|
note = strings.TrimSpace(body.Note)
|
||||||
|
|
||||||
|
convID := h.conversationIDForRunningExecution(id)
|
||||||
|
if convID != "" && h.agentHandler != nil {
|
||||||
|
if ok, payload := h.agentHandler.cancelToolContinueAfter(convID, id, note); ok {
|
||||||
|
h.logger.Info("MCP 监控页终止工具(与对话中断并继续一致)",
|
||||||
|
zap.String("executionId", id),
|
||||||
|
zap.String("conversationId", convID),
|
||||||
|
zap.Bool("hasNote", note != ""),
|
||||||
|
)
|
||||||
|
c.JSON(http.StatusOK, payload)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
|
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
|
||||||
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
|
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
||||||
@@ -303,6 +332,52 @@ func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) enrichExecutionsConversationID(executions []*mcp.ToolExecution) {
|
||||||
|
for _, exec := range executions {
|
||||||
|
if exec == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
exec.ConversationID = h.conversationIDForRunningExecution(exec.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) conversationIDForRunningExecution(executionID string) string {
|
||||||
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
if executionID == "" || h.taskManager == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if conv := h.taskManager.ConversationIDForActiveMCPExecution(executionID); conv != "" {
|
||||||
|
return conv
|
||||||
|
}
|
||||||
|
exec := h.lookupExecution(executionID)
|
||||||
|
if exec == nil || exec.Status != "running" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(exec.ToolName) == "execute" {
|
||||||
|
if onlyConv, ok := h.taskManager.ConversationIDForActiveEinoExecute(); ok {
|
||||||
|
return onlyConv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *MonitorHandler) lookupExecution(id string) *mcp.ToolExecution {
|
||||||
|
if exec, ok := h.mcpServer.GetExecution(id); ok {
|
||||||
|
return exec
|
||||||
|
}
|
||||||
|
if h.externalMCPMgr != nil {
|
||||||
|
if exec, ok := h.externalMCPMgr.GetExecution(id); ok {
|
||||||
|
return exec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.db != nil {
|
||||||
|
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
|
||||||
|
return exec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
||||||
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
|
|||||||
@@ -103,6 +103,40 @@ func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConversationIDForActiveMCPExecution 根据当前登记的工具 executionId 反查会话 ID(供 MCP 监控页按 executionId 终止)。
|
||||||
|
func (m *AgentTaskManager) ConversationIDForActiveMCPExecution(executionID string) string {
|
||||||
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
if executionID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
for convID, t := range m.tasks {
|
||||||
|
if t != nil && t.ActiveMCPExecutionID == executionID {
|
||||||
|
return convID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationIDForActiveEinoExecute 返回当前唯一进行 Eino execute 的会话 ID;多会话并行时返回空。
|
||||||
|
func (m *AgentTaskManager) ConversationIDForActiveEinoExecute() (string, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
var found string
|
||||||
|
count := 0
|
||||||
|
for convID, t := range m.tasks {
|
||||||
|
if t != nil && t.activeEinoExecuteCancel != nil {
|
||||||
|
found = convID
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count == 1 {
|
||||||
|
return found, true
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
||||||
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
||||||
conversationID = strings.TrimSpace(conversationID)
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
|||||||
@@ -38,3 +38,19 @@ func TestAbortActiveEinoExecute(t *testing.T) {
|
|||||||
t.Fatal("second abort should fail when no active execute")
|
t.Fatal("second abort should fail when no active execute")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConversationIDForActiveMCPExecution(t *testing.T) {
|
||||||
|
m := NewAgentTaskManager()
|
||||||
|
conv := "conv-mcp-exec"
|
||||||
|
_, err := m.StartTask(conv, "test", func(error) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
m.RegisterRunningTool(conv, "exec-123")
|
||||||
|
if got := m.ConversationIDForActiveMCPExecution("exec-123"); got != conv {
|
||||||
|
t.Fatalf("got %q, want %q", got, conv)
|
||||||
|
}
|
||||||
|
if got := m.ConversationIDForActiveMCPExecution("missing"); got != "" {
|
||||||
|
t.Fatalf("missing should be empty, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+83
-16
@@ -921,9 +921,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
|||||||
return finalResult, executionID, nil
|
return finalResult, executionID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
// BeginToolExecution 创建 running 状态的执行记录,供 Eino 等非 CallTool 路径在工具开始时落库。
|
||||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
func (s *Server) BeginToolExecution(toolName string, args map[string]interface{}) string {
|
||||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -931,21 +930,73 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
|||||||
args = map[string]interface{}{}
|
args = map[string]interface{}{}
|
||||||
}
|
}
|
||||||
executionID := uuid.New().String()
|
executionID := uuid.New().String()
|
||||||
now := time.Now()
|
execution := &ToolExecution{
|
||||||
failed := invokeErr != nil
|
|
||||||
exec := &ToolExecution{
|
|
||||||
ID: executionID,
|
ID: executionID,
|
||||||
ToolName: toolName,
|
ToolName: toolName,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
StartTime: now,
|
Status: "running",
|
||||||
EndTime: &now,
|
StartTime: time.Now(),
|
||||||
Duration: 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.executions[executionID] = execution
|
||||||
|
s.cleanupOldExecutions()
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.storage != nil {
|
||||||
|
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||||||
|
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return executionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinishToolExecution 完成先前 BeginToolExecution 创建的记录;executionID 为空时等同 RecordCompletedToolInvocation。
|
||||||
|
func (s *Server) FinishToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if args == nil {
|
||||||
|
args = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
id := strings.TrimSpace(executionID)
|
||||||
|
if id == "" {
|
||||||
|
return s.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
failed := invokeErr != nil
|
||||||
|
var finalResult *ToolResult
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
exec, inMem := s.executions[id]
|
||||||
|
if !inMem || exec == nil {
|
||||||
|
exec = &ToolExecution{
|
||||||
|
ID: id,
|
||||||
|
ToolName: toolName,
|
||||||
|
Arguments: args,
|
||||||
|
StartTime: now,
|
||||||
|
}
|
||||||
|
s.executions[id] = exec
|
||||||
|
} else if toolName != "" {
|
||||||
|
exec.ToolName = toolName
|
||||||
|
}
|
||||||
|
if len(args) > 0 {
|
||||||
|
exec.Arguments = args
|
||||||
|
}
|
||||||
|
exec.EndTime = &now
|
||||||
|
if exec.StartTime.IsZero() {
|
||||||
|
exec.StartTime = now
|
||||||
|
}
|
||||||
|
exec.Duration = now.Sub(exec.StartTime)
|
||||||
|
|
||||||
if failed {
|
if failed {
|
||||||
exec.Status = "failed"
|
st, msg := executionStatusAndMessage(invokeErr)
|
||||||
exec.Error = invokeErr.Error()
|
exec.Status = st
|
||||||
|
exec.Error = msg
|
||||||
if strings.TrimSpace(resultText) != "" {
|
if strings.TrimSpace(resultText) != "" {
|
||||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||||
|
exec.Result = finalResult
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
exec.Status = "completed"
|
exec.Status = "completed"
|
||||||
@@ -953,15 +1004,31 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
|||||||
if strings.TrimSpace(text) == "" {
|
if strings.TrimSpace(text) == "" {
|
||||||
text = "(无输出)"
|
text = "(无输出)"
|
||||||
}
|
}
|
||||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||||
|
exec.Result = finalResult
|
||||||
}
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
if s.storage != nil {
|
if s.storage != nil {
|
||||||
if err := s.storage.SaveToolExecution(exec); err != nil {
|
if err := s.storage.SaveToolExecution(exec); err != nil {
|
||||||
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err))
|
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.updateStats(toolName, failed)
|
|
||||||
return executionID
|
s.updateStats(exec.ToolName, failed)
|
||||||
|
|
||||||
|
if s.storage != nil {
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.executions, id)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||||
|
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||||
|
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
return s.FinishToolExecution("", toolName, args, resultText, invokeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
||||||
|
|||||||
@@ -199,6 +199,8 @@ type ToolExecution struct {
|
|||||||
StartTime time.Time `json:"startTime"`
|
StartTime time.Time `json:"startTime"`
|
||||||
EndTime *time.Time `json:"endTime,omitempty"`
|
EndTime *time.Time `json:"endTime,omitempty"`
|
||||||
Duration time.Duration `json:"duration,omitempty"`
|
Duration time.Duration `json:"duration,omitempty"`
|
||||||
|
// ConversationID 仅 API 展示用(进行中的 Agent 任务),不写入 tool_executions 表。
|
||||||
|
ConversationID string `json:"conversationId,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolStats 工具统计信息
|
// ToolStats 工具统计信息
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/einoobserve"
|
"cyberstrike-ai/internal/einoobserve"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
@@ -90,7 +91,7 @@ type einoADKRunLoopArgs struct {
|
|||||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||||
MCPExecutionBinder *MCPExecutionBinder
|
MCPExecutionBinder *MCPExecutionBinder
|
||||||
|
|
||||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,execute/MCP 桥 Fire 时立即推送 tool_result(ADK 晚到经 toolResultSent 去重)。
|
||||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||||
|
|
||||||
DA adk.Agent
|
DA adk.Agent
|
||||||
@@ -196,6 +197,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
pendingByID[tc.ToolCallID] = tc
|
pendingByID[tc.ToolCallID] = tc
|
||||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||||
}
|
}
|
||||||
|
markPendingWithMonitor := func(tc toolCallPendingInfo) {
|
||||||
|
markPending(tc)
|
||||||
|
beginEinoADKFilesystemToolMonitor(
|
||||||
|
args.FilesystemMonitorAgent,
|
||||||
|
args.FilesystemMonitorRecord,
|
||||||
|
args.MCPExecutionBinder,
|
||||||
|
tc.ToolCallID,
|
||||||
|
tc.ToolName,
|
||||||
|
)
|
||||||
|
}
|
||||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||||
pendingMu.Lock()
|
pendingMu.Lock()
|
||||||
defer pendingMu.Unlock()
|
defer pendingMu.Unlock()
|
||||||
@@ -331,7 +342,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
toolCallID = tid
|
toolCallID = tid
|
||||||
}
|
}
|
||||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, args.MCPExecutionBinder, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||||
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
|
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
|
||||||
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
|
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
|
||||||
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
|
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
|
||||||
@@ -341,8 +352,21 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
if args.ToolInvokeNotify != nil {
|
if args.ToolInvokeNotify != nil {
|
||||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
removePendingByID(strings.TrimSpace(toolCallID))
|
// Eino execute / MCP 桥在工具返回时 Fire;若 ADK schema.Tool 事件迟迟不到,此处立即推送
|
||||||
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
|
// tool_result 解除 UI「执行中」。tryEmitToolResultProgress 经 toolResultSent 去重,ADK 晚到不重复。
|
||||||
|
isErr := !success || invokeErr != nil
|
||||||
|
body := einoToolResultBody(content)
|
||||||
|
if einoToolResultIsError(toolName, content) {
|
||||||
|
isErr = true
|
||||||
|
}
|
||||||
|
if tail := friendlyEinoExecuteInvokeTail(invokeErr); tail != "" {
|
||||||
|
if body == "" {
|
||||||
|
body = tail
|
||||||
|
} else if !strings.Contains(body, tail) {
|
||||||
|
body = strings.TrimSpace(body) + "\n\n" + tail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tryEmitToolResultProgress(toolName, body, toolCallID, isErr, einoAgent)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -539,6 +563,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 仅在退避重试后真正收到数据/完成一步时清零,避免重启后首个无错 ADK 事件误把计数打回 0。
|
||||||
|
confirmTransientRetryRecovery := func() {
|
||||||
|
if transientRetrier.attempt() > 0 {
|
||||||
|
transientRetrier.reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
takePartial := func(runErr error) (*RunResult, error) {
|
takePartial := func(runErr error) (*RunResult, error) {
|
||||||
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
||||||
return nil, runErr
|
return nil, runErr
|
||||||
@@ -551,10 +582,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
|
// iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending。
|
||||||
select {
|
ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
|
||||||
case <-ctx.Done():
|
if iterCtxErr != nil {
|
||||||
flushAllPendingAsFailed(ctx.Err())
|
flushAllPendingAsFailed(iterCtxErr)
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
if isInterruptContinue(ctx) {
|
if isInterruptContinue(ctx) {
|
||||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||||
@@ -563,17 +594,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"kind": "interrupt_continue",
|
"kind": "interrupt_continue",
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
progress("error", iterCtxErr.Error(), map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return takePartial(ctx.Err())
|
return takePartial(iterCtxErr)
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ev, ok := iter.Next()
|
|
||||||
if !ok {
|
if !ok {
|
||||||
// iter 结束并不总是“正常完成”:
|
// iter 结束并不总是“正常完成”:
|
||||||
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
||||||
@@ -627,8 +655,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if restarted {
|
if restarted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
transientRetrier.reset()
|
|
||||||
}
|
}
|
||||||
if ev.AgentName != "" && progress != nil {
|
if ev.AgentName != "" && progress != nil {
|
||||||
iterEinoAgent := orchestratorName
|
iterEinoAgent := orchestratorName
|
||||||
@@ -691,34 +717,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
|
|
||||||
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
||||||
toolName := strings.TrimSpace(mv.ToolName)
|
toolName := strings.TrimSpace(mv.ToolName)
|
||||||
var toolBuf strings.Builder
|
content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
|
||||||
streamToolCallID := ""
|
isErr := einoToolResultIsError(toolName, content)
|
||||||
var toolStreamRecvErr error
|
content = einoToolResultBody(content)
|
||||||
for {
|
|
||||||
chunk, rerr := mv.MessageStream.Recv()
|
|
||||||
if errors.Is(rerr, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if rerr != nil {
|
|
||||||
toolStreamRecvErr = rerr
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if chunk == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if chunk.Content != "" {
|
|
||||||
toolBuf.WriteString(chunk.Content)
|
|
||||||
}
|
|
||||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
|
||||||
streamToolCallID = tid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
content := toolBuf.String()
|
|
||||||
isErr := false
|
|
||||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
|
||||||
isErr = true
|
|
||||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
|
||||||
}
|
|
||||||
if streamToolCallID != "" {
|
if streamToolCallID != "" {
|
||||||
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
|
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
|
||||||
@@ -730,6 +731,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
zap.String("agent", ev.AgentName),
|
zap.String("agent", ev.AgentName),
|
||||||
zap.String("tool", toolName))
|
zap.String("tool", toolName))
|
||||||
}
|
}
|
||||||
|
if toolStreamRecvErr == nil {
|
||||||
|
confirmTransientRetryRecovery()
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -977,7 +981,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||||
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
||||||
}
|
}
|
||||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
|
||||||
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
||||||
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
||||||
@@ -1001,6 +1005,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if restarted {
|
if restarted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
confirmTransientRetryRecovery()
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1010,7 +1016,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPendingWithMonitor)
|
||||||
|
|
||||||
if mv.Role == schema.Assistant {
|
if mv.Role == schema.Assistant {
|
||||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||||
@@ -1085,15 +1091,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
|
|
||||||
content := msg.Content
|
content := msg.Content
|
||||||
isErr := false
|
isErr := einoToolResultIsError(toolName, content)
|
||||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
content = einoToolResultBody(content)
|
||||||
isErr = true
|
|
||||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||||
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
|
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
|
||||||
}
|
}
|
||||||
|
confirmTransientRetryRecovery()
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpIDsMu.Lock()
|
mcpIDsMu.Lock()
|
||||||
@@ -1121,17 +1125,119 @@ func einoPartialRunLastOutputHint() string {
|
|||||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||||
}
|
}
|
||||||
|
|
||||||
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本。
|
// friendlyEinoExecuteInvokeTail 将 Eino execute 超时/中断/流异常转为简短提示。
|
||||||
|
// 命令非零退出(ExecuteExitError)已有 exec 对齐的正文,不再追加「执行未正常结束」。
|
||||||
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||||
if invokeErr == nil {
|
if invokeErr == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
var exitErr *ExecuteExitError
|
||||||
|
if errors.As(invokeErr, &exitErr) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||||
return einoExecuteTimeoutUserHint()
|
return einoExecuteTimeoutUserHint()
|
||||||
}
|
}
|
||||||
|
if errors.Is(invokeErr, context.Canceled) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.Contains(invokeErr.Error(), "shell inactivity timeout") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return "[执行未正常结束] " + invokeErr.Error()
|
return "[执行未正常结束] " + invokeErr.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// einoToolResultIsError 统一判断 Eino 工具结果是否应标记为错误(与 MCP exec 的 IsError 对齐)。
|
||||||
|
func einoToolResultIsError(toolName, content string) bool {
|
||||||
|
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(toolName) == "execute" && security.IsCommandFailureResult(content) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoToolResultBody 去掉工具错误前缀,返回展示/持久化正文。
|
||||||
|
func einoToolResultBody(content string) string {
|
||||||
|
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||||
|
return strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextAgentEventWithContext 在 ctx 取消时不再无限阻塞于 iter.Next()(工具执行/模型推理期间常见)。
|
||||||
|
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
|
||||||
|
if iter == nil {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
type nextRes struct {
|
||||||
|
ev *adk.AgentEvent
|
||||||
|
ok bool
|
||||||
|
}
|
||||||
|
ch := make(chan nextRes, 1)
|
||||||
|
go func() {
|
||||||
|
e, o := iter.Next()
|
||||||
|
ch <- nextRes{e, o}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, false, ctx.Err()
|
||||||
|
case res := <-ch:
|
||||||
|
return res.ev, res.ok, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recvSchemaMessageStream 消费 ADK Tool 流式结果;ctx 取消时立即返回,避免 amass 等无输出时永久阻塞。
|
||||||
|
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
|
||||||
|
if stream == nil {
|
||||||
|
return "", "", nil
|
||||||
|
}
|
||||||
|
type streamMsg struct {
|
||||||
|
chunk *schema.Message
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
recvCh := make(chan streamMsg, 8)
|
||||||
|
go func() {
|
||||||
|
defer close(recvCh)
|
||||||
|
for {
|
||||||
|
ch, rerr := stream.Recv()
|
||||||
|
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var buf strings.Builder
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return buf.String(), toolCallID, ctx.Err()
|
||||||
|
case sm, open := <-recvCh:
|
||||||
|
if !open {
|
||||||
|
return buf.String(), toolCallID, nil
|
||||||
|
}
|
||||||
|
rerr := sm.err
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
return buf.String(), toolCallID, nil
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
return buf.String(), toolCallID, rerr
|
||||||
|
}
|
||||||
|
chunk := sm.chunk
|
||||||
|
if chunk == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if chunk.Content != "" {
|
||||||
|
buf.WriteString(chunk.Content)
|
||||||
|
}
|
||||||
|
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||||
|
toolCallID = tid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func buildEinoRunResultFromAccumulated(
|
func buildEinoRunResultFromAccumulated(
|
||||||
orchMode string,
|
orchMode string,
|
||||||
runAccumulatedMsgs []adk.Message,
|
runAccumulatedMsgs []adk.Message,
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
_ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil)
|
||||||
|
sw.Close()
|
||||||
|
|
||||||
|
content, tid, err := recvSchemaMessageStream(context.Background(), sr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
if content != "hello" {
|
||||||
|
t.Fatalf("content=%q want hello", content)
|
||||||
|
}
|
||||||
|
if tid != "tc-1" {
|
||||||
|
t.Fatalf("toolCallID=%q want tc-1", tid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
t.Cleanup(func() { sw.Close() })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
content, _, err := recvSchemaMessageStream(ctx, sr)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("want context.Canceled, got %v content=%q", err, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
want := errors.New("stream broken")
|
||||||
|
_ = sw.Send(nil, want)
|
||||||
|
sw.Close()
|
||||||
|
|
||||||
|
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||||
|
if !errors.Is(err, want) {
|
||||||
|
t.Fatalf("want %v, got %v", want, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
|
||||||
|
content, tid, err := recvSchemaMessageStream(context.Background(), nil)
|
||||||
|
if err != nil || content != "" || tid != "" {
|
||||||
|
t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
|
||||||
|
sr, sw := schema.Pipe[*schema.Message](4)
|
||||||
|
_ = sw.Send(nil, io.EOF)
|
||||||
|
sw.Close()
|
||||||
|
|
||||||
|
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EOF should not surface as error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStreamingShellExitFail struct {
|
||||||
|
output string
|
||||||
|
code int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShellExitFail) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
defer outW.Close()
|
||||||
|
if m.output != "" {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
|
||||||
|
}
|
||||||
|
code := m.code
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{ExitCode: &code}, nil)
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_CommandFailureFormat(t *testing.T) {
|
||||||
|
inner := &mockStreamingShellExitFail{
|
||||||
|
output: "sudo: a password is required\n",
|
||||||
|
code: 1,
|
||||||
|
}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
var firedBody string
|
||||||
|
var firedSuccess bool
|
||||||
|
var firedErr error
|
||||||
|
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
firedBody = content
|
||||||
|
firedSuccess = success
|
||||||
|
firedErr = invokeErr
|
||||||
|
})
|
||||||
|
wrap := &einoStreamingShellWrap{inner: inner, invokeNotify: notify}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var stream strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
stream.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if firedSuccess {
|
||||||
|
t.Fatal("expected success=false")
|
||||||
|
}
|
||||||
|
var exitErr *ExecuteExitError
|
||||||
|
if !errors.As(firedErr, &exitErr) || exitErr.Code != 1 {
|
||||||
|
t.Fatalf("expected ExecuteExitError code 1, got %v", firedErr)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(firedBody, einomcp.ToolErrorPrefix) {
|
||||||
|
t.Fatalf("missing tool error prefix: %q", firedBody)
|
||||||
|
}
|
||||||
|
body := strings.TrimPrefix(firedBody, einomcp.ToolErrorPrefix)
|
||||||
|
if body != security.FormatCommandFailureResult(1, "sudo: a password is required\n") {
|
||||||
|
t.Fatalf("fire body = %q", body)
|
||||||
|
}
|
||||||
|
if !strings.Contains(stream.String(), "sudo:") {
|
||||||
|
t.Fatalf("stream missing sudo output: %q", stream.String())
|
||||||
|
}
|
||||||
|
if strings.Contains(stream.String(), "command exited with non-zero") {
|
||||||
|
t.Fatalf("stream has legacy noise: %q", stream.String())
|
||||||
|
}
|
||||||
|
if strings.Contains(stream.String(), "执行未正常结束") {
|
||||||
|
t.Fatalf("stream has abnormal tail: %q", stream.String())
|
||||||
|
}
|
||||||
|
if !security.IsCommandFailureResult(stream.String()) {
|
||||||
|
t.Fatalf("stream missing failure status line: %q", stream.String())
|
||||||
|
}
|
||||||
|
if tail := friendlyEinoExecuteInvokeTail(firedErr); tail != "" {
|
||||||
|
t.Fatalf("unexpected invoke tail: %q", tail)
|
||||||
|
}
|
||||||
|
if !einoToolResultIsError("execute", firedBody) {
|
||||||
|
t.Fatal("expected isError for execute failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFriendlyEinoExecuteInvokeTail(t *testing.T) {
|
||||||
|
if friendlyEinoExecuteInvokeTail(&ExecuteExitError{Code: 1}) != "" {
|
||||||
|
t.Fatal("exit error should not get abnormal tail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(friendlyEinoExecuteInvokeTail(context.DeadlineExceeded), "Timed out") {
|
||||||
|
t.Fatal("deadline should get timeout hint")
|
||||||
|
}
|
||||||
|
if friendlyEinoExecuteInvokeTail(errors.New("broken pipe")) == "" {
|
||||||
|
t.Fatal("unexpected error should get tail")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,11 +7,25 @@ import (
|
|||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
// newEinoExecuteMonitorCallbacks 在 Eino filesystem execute 开始/结束时写入 MCP 监控库并 recorder(executionId),
|
||||||
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
// 与 CallTool 路径一致,使监控页能展示「执行中」状态。
|
||||||
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
func newEinoExecuteMonitorCallbacks(ag *agent.Agent, recorder einomcp.ExecutionRecorder) (
|
||||||
return func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
begin func(toolCallID, command string) string,
|
||||||
if ag == nil || recorder == nil {
|
finish func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
|
||||||
|
) {
|
||||||
|
begin = func(toolCallID, command string) string {
|
||||||
|
if ag == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
args := map[string]interface{}{"command": command}
|
||||||
|
id := ag.BeginLocalToolExecution("execute", args)
|
||||||
|
if id != "" && recorder != nil {
|
||||||
|
recorder(id, toolCallID)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
finish = func(executionID, toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||||
|
if ag == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
@@ -23,9 +37,10 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
args := map[string]interface{}{"command": command}
|
args := map[string]interface{}{"command": command}
|
||||||
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
id := ag.FinishLocalToolExecution(executionID, "execute", args, stdout, err)
|
||||||
if id != "" {
|
if id != "" && recorder != nil && executionID == "" {
|
||||||
recorder(id, toolCallID)
|
recorder(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return begin, finish
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
|||||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||||
//
|
//
|
||||||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
// run loop 收到 Fire 后立即推送 tool_result(toolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」。
|
||||||
//
|
//
|
||||||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||||
@@ -63,8 +63,11 @@ type einoStreamingShellWrap struct {
|
|||||||
outputChunk func(toolName, toolCallID, chunk string)
|
outputChunk func(toolName, toolCallID, chunk string)
|
||||||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||||
toolTimeoutMinutes int
|
toolTimeoutMinutes int
|
||||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
// shellNoOutputTimeoutSec:无任何输出时的空闲秒数;0=关闭。
|
||||||
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
|
shellNoOutputTimeoutSec int
|
||||||
|
// beginMonitor 在 execute 开始时写入 running 状态;finishMonitor 在流结束后更新为 completed/failed。
|
||||||
|
beginMonitor func(toolCallID, command string) string
|
||||||
|
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
@@ -76,15 +79,26 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
}
|
}
|
||||||
req := *input
|
req := *input
|
||||||
userCmd := strings.TrimSpace(req.Command)
|
userCmd := strings.TrimSpace(req.Command)
|
||||||
|
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||||
|
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||||||
req.RunInBackendGround = true
|
req.RunInBackendGround = true
|
||||||
}
|
}
|
||||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
req.Command = security.PrepareNonInteractiveShellCommand(prependPythonUnbufferedEnv(req.Command))
|
||||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
|
||||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
|
||||||
convID := mcp.MCPConversationIDFromContext(ctx)
|
convID := mcp.MCPConversationIDFromContext(ctx)
|
||||||
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
||||||
|
|
||||||
|
var monitorExecID string
|
||||||
|
if w.beginMonitor != nil {
|
||||||
|
monitorExecID = w.beginMonitor(tid, userCmd)
|
||||||
|
}
|
||||||
|
if monitorExecID != "" && convID != "" {
|
||||||
|
if toolReg := mcp.ToolRunRegistryFromContext(ctx); toolReg != nil {
|
||||||
|
toolReg.RegisterRunningTool(convID, monitorExecID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolRunReg := mcp.ToolRunRegistryFromContext(ctx)
|
||||||
|
|
||||||
execCtx, execCancel := context.WithCancel(ctx)
|
execCtx, execCancel := context.WithCancel(ctx)
|
||||||
var timeoutCancel context.CancelFunc
|
var timeoutCancel context.CancelFunc
|
||||||
if w.toolTimeoutMinutes > 0 {
|
if w.toolTimeoutMinutes > 0 {
|
||||||
@@ -104,23 +118,23 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
}
|
}
|
||||||
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
||||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||||
if w.recordMonitor != nil {
|
if w.finishMonitor != nil {
|
||||||
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
|
w.finishMonitor(monitorExecID, tid, userCmd, hint, false, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
if w.invokeNotify != nil && tid != "" {
|
if w.invokeNotify != nil && tid != "" {
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
||||||
}
|
}
|
||||||
if w.recordMonitor != nil {
|
if w.finishMonitor != nil {
|
||||||
w.recordMonitor(tid, userCmd, "", false, err)
|
w.finishMonitor(monitorExecID, tid, userCmd, "", false, err)
|
||||||
}
|
}
|
||||||
if w.invokeNotify != nil && tid != "" {
|
if w.invokeNotify != nil && tid != "" {
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if sr == nil || w.invokeNotify == nil {
|
if sr == nil {
|
||||||
if timeoutCancel != nil {
|
if timeoutCancel != nil {
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
}
|
}
|
||||||
@@ -132,7 +146,7 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
|
|
||||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||||
|
|
||||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
|
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry, toolReg mcp.ToolRunRegistry, execID string, toolCallID string, noOutputSec int) {
|
||||||
var innerCloseOnce sync.Once
|
var innerCloseOnce sync.Once
|
||||||
closeInner := func() {
|
closeInner := func() {
|
||||||
innerCloseOnce.Do(func() { inner.Close() })
|
innerCloseOnce.Do(func() { inner.Close() })
|
||||||
@@ -147,6 +161,9 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
if reg != nil && conversationID != "" {
|
if reg != nil && conversationID != "" {
|
||||||
defer reg.UnregisterActiveEinoExecute(conversationID)
|
defer reg.UnregisterActiveEinoExecute(conversationID)
|
||||||
}
|
}
|
||||||
|
if toolReg != nil && conversationID != "" && execID != "" {
|
||||||
|
defer toolReg.UnregisterRunningTool(conversationID, execID)
|
||||||
|
}
|
||||||
|
|
||||||
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
||||||
stopWatch := make(chan struct{})
|
stopWatch := make(chan struct{})
|
||||||
@@ -165,50 +182,103 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
exitCode := 0
|
exitCode := 0
|
||||||
hasExitCode := false
|
hasExitCode := false
|
||||||
|
|
||||||
|
idleWatch := security.NewShellInactivityWatch(noOutputSec)
|
||||||
|
if idleWatch != nil {
|
||||||
|
defer idleWatch.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
type execRecvMsg struct {
|
||||||
|
resp *filesystem.ExecuteResponse
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
recvCh := make(chan execRecvMsg, 1)
|
||||||
|
go func() {
|
||||||
for {
|
for {
|
||||||
resp, rerr := inner.Recv()
|
resp, rerr := inner.Recv()
|
||||||
|
recvCh <- execRecvMsg{resp: resp, err: rerr}
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
fireInactivityTimeout := func() {
|
||||||
|
success = false
|
||||||
|
invokeErr = fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||||
|
msg := security.ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: msg}, nil)
|
||||||
|
sb.WriteString(msg)
|
||||||
|
if w.outputChunk != nil && toolCallID != "" {
|
||||||
|
w.outputChunk("execute", toolCallID, msg)
|
||||||
|
}
|
||||||
|
if cancel != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
closeInner()
|
||||||
|
}
|
||||||
|
|
||||||
|
recvLoop:
|
||||||
|
for {
|
||||||
|
var idleCh <-chan struct{}
|
||||||
|
if idleWatch != nil {
|
||||||
|
idleCh = idleWatch.Expired
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-idleCh:
|
||||||
|
fireInactivityTimeout()
|
||||||
|
break recvLoop
|
||||||
|
case msg := <-recvCh:
|
||||||
|
rerr := msg.err
|
||||||
|
resp := msg.resp
|
||||||
if errors.Is(rerr, io.EOF) {
|
if errors.Is(rerr, io.EOF) {
|
||||||
break
|
break recvLoop
|
||||||
}
|
}
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
success = false
|
success = false
|
||||||
invokeErr = rerr
|
invokeErr = rerr
|
||||||
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
|
|
||||||
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||||||
invokeErr = context.DeadlineExceeded
|
invokeErr = context.DeadlineExceeded
|
||||||
break
|
break recvLoop
|
||||||
}
|
}
|
||||||
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
||||||
invokeErr = context.Canceled
|
invokeErr = context.Canceled
|
||||||
break
|
break recvLoop
|
||||||
}
|
}
|
||||||
_ = outW.Send(nil, rerr)
|
_ = outW.Send(nil, rerr)
|
||||||
break
|
break recvLoop
|
||||||
}
|
}
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
if resp.ExitCode != nil {
|
if resp.ExitCode != nil {
|
||||||
hasExitCode = true
|
hasExitCode = true
|
||||||
exitCode = *resp.ExitCode
|
exitCode = *resp.ExitCode
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
var appended string
|
var appended string
|
||||||
if resp.Output != "" {
|
if resp.Output != "" {
|
||||||
|
if security.IsLegacyShellExitNoise(resp.Output) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if idleWatch != nil {
|
||||||
|
idleWatch.Bump()
|
||||||
|
}
|
||||||
sb.WriteString(resp.Output)
|
sb.WriteString(resp.Output)
|
||||||
appended = resp.Output
|
appended = resp.Output
|
||||||
}
|
}
|
||||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||||
w.outputChunk("execute", tid, appended)
|
w.outputChunk("execute", toolCallID, appended)
|
||||||
}
|
}
|
||||||
if outW.Send(resp, nil) {
|
if outW.Send(resp, nil) {
|
||||||
success = false
|
success = false
|
||||||
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||||||
break
|
break recvLoop
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if success && hasExitCode && exitCode != 0 {
|
if success && hasExitCode && exitCode != 0 {
|
||||||
success = false
|
success = false
|
||||||
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
invokeErr = &ExecuteExitError{Code: exitCode}
|
||||||
}
|
}
|
||||||
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||||||
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||||||
@@ -248,12 +318,24 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if w.recordMonitor != nil {
|
rawOutput := sb.String()
|
||||||
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
fireBody := rawOutput
|
||||||
|
if !success && hasExitCode && exitCode != 0 {
|
||||||
|
statusLine := security.ExecuteFailureStatusLine(exitCode)
|
||||||
|
if !strings.Contains(rawOutput, "命令执行失败:") {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: statusLine}, nil)
|
||||||
|
sb.WriteString(statusLine)
|
||||||
|
}
|
||||||
|
fireBody = einomcp.ToolErrorPrefix + security.FormatCommandFailureResult(exitCode, rawOutput)
|
||||||
|
}
|
||||||
|
if w.finishMonitor != nil {
|
||||||
|
w.finishMonitor(execID, toolCallID, command, sb.String(), success, invokeErr)
|
||||||
|
}
|
||||||
|
if w.invokeNotify != nil {
|
||||||
|
w.invokeNotify.Fire(toolCallID, "execute", agentTag, success, fireBody, invokeErr)
|
||||||
}
|
}
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
|
||||||
outW.Close()
|
outW.Close()
|
||||||
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
|
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg, toolRunReg, monitorExecID, tid, w.shellNoOutputTimeoutSec)
|
||||||
|
|
||||||
return outR, nil
|
return outR, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,9 +19,15 @@ type mockStreamingShell struct {
|
|||||||
immediateErr error
|
immediateErr error
|
||||||
recvErr error
|
recvErr error
|
||||||
output string
|
output string
|
||||||
|
called bool
|
||||||
|
lastCommand string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
m.called = true
|
||||||
|
if input != nil {
|
||||||
|
m.lastCommand = input.Command
|
||||||
|
}
|
||||||
if m.immediateErr != nil {
|
if m.immediateErr != nil {
|
||||||
return nil, m.immediateErr
|
return nil, m.immediateErr
|
||||||
}
|
}
|
||||||
@@ -38,6 +44,135 @@ func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesy
|
|||||||
return outR, nil
|
return outR, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_PreparesNonInteractiveCommand(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{output: "ok\n"}
|
||||||
|
wrap := &einoStreamingShellWrap{inner: inner}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo ok"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
for {
|
||||||
|
_, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(inner.lastCommand, "exec </dev/null") {
|
||||||
|
t.Fatalf("missing stdin redirect in inner command: %q", inner.lastCommand)
|
||||||
|
}
|
||||||
|
if !strings.Contains(inner.lastCommand, "GIT_PAGER=cat") {
|
||||||
|
t.Fatalf("missing pager export in inner command: %q", inner.lastCommand)
|
||||||
|
}
|
||||||
|
if !strings.Contains(inner.lastCommand, "PYTHONUNBUFFERED=1") {
|
||||||
|
t.Fatalf("missing python unbuffer in inner command: %q", inner.lastCommand)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_NoOutputTimeout(t *testing.T) {
|
||||||
|
inner := &mockStreamingShellHanging{}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
var fired string
|
||||||
|
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
fired = content
|
||||||
|
})
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
invokeNotify: notify,
|
||||||
|
shellNoOutputTimeoutSec: 1,
|
||||||
|
}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !inner.called {
|
||||||
|
t.Fatal("inner shell should run (no command blacklist)")
|
||||||
|
}
|
||||||
|
out := got.String()
|
||||||
|
if !strings.Contains(out, "没有新的输出") && !strings.Contains(out, "no new output") {
|
||||||
|
t.Fatalf("expected inactivity timeout message, got: %q notify=%q", out, fired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockStreamingShellPartialThenHang struct {
|
||||||
|
called bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShellPartialThenHang) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
m.called = true
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: "[sudo] password:\n"}, nil)
|
||||||
|
<-ctx.Done()
|
||||||
|
outW.Close()
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_InactivityAfterPartialOutput(t *testing.T) {
|
||||||
|
inner := &mockStreamingShellPartialThenHang{}
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
shellNoOutputTimeoutSec: 1,
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "sudo whoami"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if time.Since(start) > 5*time.Second {
|
||||||
|
t.Fatalf("expected inactivity timeout ~1s, took %v", time.Since(start))
|
||||||
|
}
|
||||||
|
if !strings.Contains(got.String(), "没有新的输出") && !strings.Contains(got.String(), "no new output") {
|
||||||
|
t.Fatalf("expected inactivity message, got: %q", got.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockStreamingShellHanging struct {
|
||||||
|
called bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShellHanging) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
m.called = true
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
outW.Close()
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
||||||
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -63,10 +63,43 @@ func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName
|
|||||||
return map[string]interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// beginEinoADKFilesystemToolMonitor 在 Eino ADK filesystem 工具开始调用时写入 running 状态。
|
||||||
|
func beginEinoADKFilesystemToolMonitor(
|
||||||
|
ag *agent.Agent,
|
||||||
|
rec einomcp.ExecutionRecorder,
|
||||||
|
binder *MCPExecutionBinder,
|
||||||
|
toolCallID, toolName string,
|
||||||
|
) {
|
||||||
|
if ag == nil || rec == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(toolName)
|
||||||
|
if name == "" || strings.EqualFold(name, "execute") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isBuiltinEinoADKFilesystemToolName(name) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
if tid == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
storedName := "eino_fs::" + strings.ToLower(name)
|
||||||
|
id := ag.BeginLocalToolExecution(storedName, map[string]interface{}{})
|
||||||
|
if id == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rec(id, tid)
|
||||||
|
if binder != nil {
|
||||||
|
binder.Bind(tid, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
||||||
func recordEinoADKFilesystemToolMonitor(
|
func recordEinoADKFilesystemToolMonitor(
|
||||||
ag *agent.Agent,
|
ag *agent.Agent,
|
||||||
rec einomcp.ExecutionRecorder,
|
rec einomcp.ExecutionRecorder,
|
||||||
|
binder *MCPExecutionBinder,
|
||||||
toolName string,
|
toolName string,
|
||||||
toolCallID string,
|
toolCallID string,
|
||||||
msgs []adk.Message,
|
msgs []adk.Message,
|
||||||
@@ -94,8 +127,12 @@ func recordEinoADKFilesystemToolMonitor(
|
|||||||
invErr = errors.New(t)
|
invErr = errors.New(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
execID := ""
|
||||||
if id != "" {
|
if binder != nil {
|
||||||
|
execID = binder.ExecutionID(toolCallID)
|
||||||
|
}
|
||||||
|
id := ag.FinishLocalToolExecution(execID, storedName, args, resultText, invErr)
|
||||||
|
if id != "" && execID == "" {
|
||||||
rec(id, toolCallID)
|
rec(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
|
||||||
mainDefs := ag.ToolsForRole(roleTools)
|
mainDefs := ag.ToolsForRole(roleTools)
|
||||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
|
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -136,7 +136,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -81,8 +82,10 @@ func subAgentFilesystemMiddleware(
|
|||||||
loc *localbk.Local,
|
loc *localbk.Local,
|
||||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||||
einoAgentName string,
|
einoAgentName string,
|
||||||
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error),
|
beginMonitor func(toolCallID, command string) string,
|
||||||
|
finishMonitor func(executionID, toolCallID, command, stdout string, success bool, invokeErr error),
|
||||||
toolTimeoutMinutes int,
|
toolTimeoutMinutes int,
|
||||||
|
shellNoOutputTimeoutSec int,
|
||||||
outputChunk func(toolName, toolCallID, chunk string),
|
outputChunk func(toolName, toolCallID, chunk string),
|
||||||
) (adk.ChatModelAgentMiddleware, error) {
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if loc == nil {
|
if loc == nil {
|
||||||
@@ -91,12 +94,14 @@ func subAgentFilesystemMiddleware(
|
|||||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||||
Backend: loc,
|
Backend: loc,
|
||||||
StreamingShell: &einoStreamingShellWrap{
|
StreamingShell: &einoStreamingShellWrap{
|
||||||
inner: loc,
|
inner: security.NewEinoStreamingShell(),
|
||||||
invokeNotify: invokeNotify,
|
invokeNotify: invokeNotify,
|
||||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||||
outputChunk: outputChunk,
|
outputChunk: outputChunk,
|
||||||
recordMonitor: recordMonitor,
|
beginMonitor: beginMonitor,
|
||||||
|
finishMonitor: finishMonitor,
|
||||||
toolTimeoutMinutes: toolTimeoutMinutes,
|
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||||
|
shellNoOutputTimeoutSec: shellNoOutputTimeoutSec,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -108,3 +113,18 @@ func agentToolTimeoutMinutes(cfg *config.Config) int {
|
|||||||
}
|
}
|
||||||
return cfg.Agent.ToolTimeoutMinutes
|
return cfg.Agent.ToolTimeoutMinutes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// agentShellNoOutputTimeoutSeconds:0=默认 300s(5 分钟);-1=关闭;>0=自定义秒数。
|
||||||
|
func agentShellNoOutputTimeoutSeconds(cfg *config.Config) int {
|
||||||
|
if cfg == nil {
|
||||||
|
return 300
|
||||||
|
}
|
||||||
|
v := cfg.Agent.ShellNoOutputTimeoutSeconds
|
||||||
|
if v < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if v == 0 {
|
||||||
|
return 300
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, too
|
|||||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
||||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||||
}
|
}
|
||||||
|
if s := strings.TrimSpace(injectShellToolGuidance("", names)); s != "" {
|
||||||
|
sb.WriteString(s)
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
if s := strings.TrimSpace(instruction); s != "" {
|
if s := strings.TrimSpace(instruction); s != "" {
|
||||||
sb.WriteString(s)
|
sb.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
|
|||||||
|
|
||||||
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
|
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
|
||||||
|
|
||||||
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。
|
// reset 在退避重试后成功推进(流/消息完整接收)时清零计数,使后续临时错误从第 1 次退避重新开始。
|
||||||
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
|
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
|
||||||
|
|
||||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||||
|
|||||||
@@ -105,6 +105,32 @@ func TestEinoTransientRunRetrierReset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEinoTransientRunRetrierConsecutiveFailures(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
|
||||||
|
ctx := context.Background()
|
||||||
|
runErr := errors.New("internal server error")
|
||||||
|
args := &einoADKRunLoopArgs{}
|
||||||
|
base := []adk.Message{schema.UserMessage("hi")}
|
||||||
|
|
||||||
|
for want := 1; want <= 3; want++ {
|
||||||
|
restarted, _, _, _, err := r.tryRetry(ctx, runErr, args, base, nil, len(base))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tryRetry attempt %d: %v", want, err)
|
||||||
|
}
|
||||||
|
if !restarted {
|
||||||
|
t.Fatalf("tryRetry attempt %d: want restarted", want)
|
||||||
|
}
|
||||||
|
if got := r.attempt(); got != want {
|
||||||
|
t.Fatalf("after failure %d: attempt=%d, want %d", want, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.reset()
|
||||||
|
if r.attempt() != 0 {
|
||||||
|
t.Fatalf("after successful recovery reset: attempt=%d, want 0", r.attempt())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
msgs := []adk.Message{schema.UserMessage("old task")}
|
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// ExecuteExitError represents a non-zero exit of an execute command: an
// expected failure, as opposed to a timeout, an interruption or a stream
// error.
type ExecuteExitError struct {
	Code int // process exit status
}

// Error implements the error interface. It is safe to call on a nil receiver,
// in which case the exit status is reported as unknown.
func (e *ExecuteExitError) Error() string {
	if e == nil {
		return "exit status unknown"
	}
	return fmt.Sprintf("exit status %d", e.Code)
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/agents"
|
"cyberstrike-ai/internal/agents"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
// DefaultPlanExecuteOrchestratorInstruction 当未配置 plan_execute 专用 Markdown / YAML 时的内置主代理(规划/重规划侧)提示。
|
||||||
@@ -122,7 +123,9 @@ func DefaultPlanExecuteOrchestratorInstruction() string {
|
|||||||
|
|
||||||
## 表达
|
## 表达
|
||||||
|
|
||||||
在调用工具或给出计划变更前,用 2~5 句中文说明当前决策依据与期望证据形态;最终对用户交付结构化结论(发现摘要、证据、风险、下一步)。`
|
在调用工具或给出计划变更前,用 2~5 句中文说明当前决策依据与期望证据形态;最终对用户交付结构化结论(发现摘要、证据、风险、下一步)。
|
||||||
|
|
||||||
|
` + projectprompt.ShellExecExecuteGuidanceSection()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。
|
// DefaultSupervisorOrchestratorInstruction 当未配置 supervisor 专用 Markdown / YAML 时的内置监督者提示(transfer / exit 说明仍由运行时在末尾追加)。
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
"cyberstrike-ai/internal/reasoning"
|
"cyberstrike-ai/internal/reasoning"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -120,7 +121,7 @@ func RunDeepAgent(
|
|||||||
mcpIDs = append(mcpIDs, id)
|
mcpIDs = append(mcpIDs, id)
|
||||||
mcpIDsMu.Unlock()
|
mcpIDsMu.Unlock()
|
||||||
}
|
}
|
||||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
einoExecBegin, einoExecFinish := newEinoExecuteMonitorCallbacks(ag, recorder)
|
||||||
|
|
||||||
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
||||||
snapshotMCPIDs := func() []string {
|
snapshotMCPIDs := func() []string {
|
||||||
@@ -223,7 +224,7 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||||
}
|
}
|
||||||
@@ -358,12 +359,14 @@ func RunDeepAgent(
|
|||||||
if einoLoc != nil && einoFSTools {
|
if einoLoc != nil && einoFSTools {
|
||||||
deepBackend = einoLoc
|
deepBackend = einoLoc
|
||||||
deepShell = &einoStreamingShellWrap{
|
deepShell = &einoStreamingShellWrap{
|
||||||
inner: einoLoc,
|
inner: security.NewEinoStreamingShell(),
|
||||||
invokeNotify: toolInvokeNotify,
|
invokeNotify: toolInvokeNotify,
|
||||||
einoAgentName: orchestratorName,
|
einoAgentName: orchestratorName,
|
||||||
outputChunk: nil,
|
outputChunk: nil,
|
||||||
recordMonitor: einoExecMonitor,
|
beginMonitor: einoExecBegin,
|
||||||
|
finishMonitor: einoExecFinish,
|
||||||
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||||
|
shellNoOutputTimeoutSec: agentShellNoOutputTimeoutSeconds(appCfg),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,7 +431,7 @@ func RunDeepAgent(
|
|||||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||||
var peFsMw adk.ChatModelAgentMiddleware
|
var peFsMw adk.ChatModelAgentMiddleware
|
||||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecBegin, einoExecFinish, agentToolTimeoutMinutes(appCfg), agentShellNoOutputTimeoutSeconds(appCfg), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// shellToolsPresent reports whether toolNames contains the "exec" or
// "execute" tool. Names are compared case-insensitively after trimming
// surrounding whitespace.
func shellToolsPresent(toolNames []string) bool {
	for _, n := range toolNames {
		name := strings.TrimSpace(n)
		// EqualFold avoids allocating a lowercased copy of every name.
		if strings.EqualFold(name, "exec") || strings.EqualFold(name, "execute") {
			return true
		}
	}
	return false
}
|
||||||
|
|
||||||
|
// injectShellToolGuidance 在系统提示末尾追加 exec/execute 分工(仅当工具列表含 exec 或 execute)。
|
||||||
|
func injectShellToolGuidance(instruction string, toolNames []string) string {
|
||||||
|
if !shellToolsPresent(toolNames) {
|
||||||
|
return instruction
|
||||||
|
}
|
||||||
|
block := strings.TrimSpace(projectprompt.ShellExecExecuteGuidanceSection())
|
||||||
|
if block == "" {
|
||||||
|
return instruction
|
||||||
|
}
|
||||||
|
s := strings.TrimSpace(instruction)
|
||||||
|
if s == "" {
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
return s + "\n\n" + block
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInjectShellToolGuidance(t *testing.T) {
|
||||||
|
got := injectShellToolGuidance("base", []string{"nmap"})
|
||||||
|
if got != "base" {
|
||||||
|
t.Fatalf("expected unchanged, got %q", got)
|
||||||
|
}
|
||||||
|
got = injectShellToolGuidance("base", []string{"exec", "nmap"})
|
||||||
|
if !strings.Contains(got, "exec/execute") || !strings.Contains(got, "base") {
|
||||||
|
t.Fatalf("expected shell guidance appended, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package projectprompt
|
||||||
|
|
||||||
|
// ShellExecExecuteGuidanceSection returns the short guidance text, meant to be
// appended to single-agent and multi-agent system prompts, that explains when
// to use exec versus execute. The text is deliberately kept brief.
func ShellExecExecuteGuidanceSection() string {
	return `Shell(exec/execute):有专用 MCP 工具时优先专用工具;系统命令(管道、workdir、后台 &)用 exec;skills/ 内脚本(配合 read_file、skill)用 execute;多步扫描分拆调用,禁止一条 shell 串多个扫描器。`
}
|
||||||
|
|
||||||
|
// ShellExecExecuteGuidanceReconSuffix returns an optional one-line addition
// for the reconnaissance sub-agent prompt.
func ShellExecExecuteGuidanceReconSuffix() string {
	return `枚举优先 subfinder、amass 等专用 MCP,勿 exec/execute 拼长链。`
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FormatCommandFailureResult builds the failure text for a command that ended
// with the given exit code. The wording matches the exec tool's ToolResult
// text and does not include ToolErrorPrefix.
//
// output is trimmed first. Empty output yields just the status line; output
// that already starts with the failure marker is returned as is (no double
// wrapping); anything else is appended after the status line.
func FormatCommandFailureResult(exitCode int, output string) string {
	return formatCommandFailure(fmt.Sprintf("exit status %d", exitCode), output)
}

// FormatCommandFailureFromErr produces the unified failure text (the IsError
// body) from the error returned by exec/execute.
//
// A nil err returns the trimmed output unchanged. An error wrapping
// *exec.ExitError is formatted from its exit code via
// FormatCommandFailureResult. Any other error is formatted from its message
// using the same empty / already-prefixed / wrap rules.
func FormatCommandFailureFromErr(err error, output string) string {
	if err == nil {
		return strings.TrimSpace(output)
	}
	var exitError *exec.ExitError
	if errors.As(err, &exitError) {
		return FormatCommandFailureResult(exitError.ExitCode(), output)
	}
	return formatCommandFailure(fmt.Sprint(err), output)
}

// formatCommandFailure renders "命令执行失败: <reason>" followed by the trimmed
// output, unless the output is empty or already carries the failure prefix.
func formatCommandFailure(reason, output string) string {
	output = strings.TrimSpace(output)
	switch {
	case output == "":
		return "命令执行失败: " + reason
	case strings.HasPrefix(output, "命令执行失败:"):
		return output
	default:
		return "命令执行失败: " + reason + "\n输出: " + output
	}
}
|
||||||
|
|
||||||
|
// ExecuteFailureStatusLine returns the single status line appended when a
// streaming execute finishes with a non-zero exit code. It starts with a
// newline and carries no output body, because the output has already been
// pushed through the stream.
func ExecuteFailureStatusLine(exitCode int) string {
	return fmt.Sprintf("\n命令执行失败: exit status %d", exitCode)
}
|
||||||
|
|
||||||
|
// IsCommandFailureResult reports whether a tool result body indicates a
// command failure, i.e. contains the failure marker anywhere in the text. It
// is used to keep the isError flag of execute and exec aligned.
func IsCommandFailureResult(content string) bool {
	return strings.Contains(content, "命令执行失败:")
}
|
||||||
|
|
||||||
|
// IsLegacyShellExitNoise reports whether s is one of the redundant exit-code
// lines emitted by the legacy shell stream, so callers can filter it out.
// Surrounding whitespace is ignored.
func IsLegacyShellExitNoise(s string) bool {
	return strings.HasPrefix(strings.TrimSpace(s), "command exited with non-zero code ")
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatCommandFailureResult(t *testing.T) {
|
||||||
|
got := FormatCommandFailureResult(1, "sudo: password required")
|
||||||
|
want := "命令执行失败: exit status 1\n输出: sudo: password required"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if FormatCommandFailureResult(2, "") != "命令执行失败: exit status 2" {
|
||||||
|
t.Fatal("empty output format")
|
||||||
|
}
|
||||||
|
if FormatCommandFailureResult(1, "命令执行失败: exit status 1") != "命令执行失败: exit status 1" {
|
||||||
|
t.Fatal("should not double-wrap")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsCommandFailureResult(t *testing.T) {
|
||||||
|
if !IsCommandFailureResult("sudo: err\n命令执行失败: exit status 1") {
|
||||||
|
t.Fatal("expected true")
|
||||||
|
}
|
||||||
|
if IsCommandFailureResult("sudo: err only") {
|
||||||
|
t.Fatal("expected false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatCommandFailureFromErr(t *testing.T) {
|
||||||
|
cmd := exec.Command("sh", "-c", "exit 42")
|
||||||
|
err := cmd.Run()
|
||||||
|
got := FormatCommandFailureFromErr(err, "oops")
|
||||||
|
if got != "命令执行失败: exit status 42\n输出: oops" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
timeoutErr := errors.New("shell inactivity timeout (300s)")
|
||||||
|
got2 := FormatCommandFailureFromErr(timeoutErr, "already timed out")
|
||||||
|
if !strings.Contains(got2, "shell inactivity timeout") || !strings.Contains(got2, "already timed out") {
|
||||||
|
t.Fatalf("got %q", got2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsLegacyShellExitNoise(t *testing.T) {
|
||||||
|
if !IsLegacyShellExitNoise("command exited with non-zero code 1\n") {
|
||||||
|
t.Fatal("expected legacy noise")
|
||||||
|
}
|
||||||
|
if IsLegacyShellExitNoise("sudo: failed") {
|
||||||
|
t.Fatal("unexpected noise")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -36,6 +36,7 @@ type Executor struct {
|
|||||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||||
mcpServer *mcp.Server
|
mcpServer *mcp.Server
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300;-1=关闭(见 SetShellNoOutputTimeoutSeconds)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewExecutor 创建新的执行器
|
// NewExecutor 创建新的执行器
|
||||||
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
|
|||||||
return executor
|
return executor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetShellNoOutputTimeoutSeconds 配置 exec 工具无输出空闲终止(与 agent.shell_no_output_timeout_seconds 一致)。
|
||||||
|
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
|
||||||
|
e.shellNoOutputTimeoutSec = sec
|
||||||
|
}
|
||||||
|
|
||||||
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
||||||
func (e *Executor) buildToolIndex() {
|
func (e *Executor) buildToolIndex() {
|
||||||
e.toolIndex = make(map[string]*config.ToolConfig)
|
e.toolIndex = make(map[string]*config.ToolConfig)
|
||||||
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
// 执行命令
|
// 执行命令
|
||||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||||
applyDefaultTerminalEnv(cmd)
|
applyDefaultTerminalEnv(cmd)
|
||||||
|
attachNonInteractiveStdin(cmd)
|
||||||
_ = prepareShellCmdSession(cmd)
|
_ = prepareShellCmdSession(cmd)
|
||||||
|
|
||||||
e.logger.Info("执行安全工具",
|
e.logger.Info("执行安全工具",
|
||||||
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
var err error
|
var err error
|
||||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||||
zap.String("tool", toolName),
|
zap.String("tool", toolName),
|
||||||
@@ -797,6 +804,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
zap.String("command", command),
|
zap.String("command", command),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
command = PrepareNonInteractiveShellCommand(command)
|
||||||
|
|
||||||
// 获取shell类型(可选,默认为sh)
|
// 获取shell类型(可选,默认为sh)
|
||||||
shell := "sh"
|
shell := "sh"
|
||||||
if s, ok := args["shell"].(string); ok && s != "" {
|
if s, ok := args["shell"].(string); ok && s != "" {
|
||||||
@@ -820,8 +829,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
} else {
|
} else {
|
||||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd)
|
ConfigureShellCmdForAgentExecute(cmd)
|
||||||
_ = prepareShellCmdSession(cmd)
|
|
||||||
|
|
||||||
// 执行命令
|
// 执行命令
|
||||||
e.logger.Info("执行系统命令",
|
e.logger.Info("执行系统命令",
|
||||||
@@ -850,8 +858,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
} else {
|
} else {
|
||||||
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(pidCmd)
|
ConfigureShellCmdForAgentExecute(pidCmd)
|
||||||
_ = prepareShellCmdSession(pidCmd)
|
|
||||||
|
|
||||||
// 获取stdout管道
|
// 获取stdout管道
|
||||||
stdout, err := pidCmd.StdoutPipe()
|
stdout, err := pidCmd.StdoutPipe()
|
||||||
@@ -963,15 +970,14 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
var err error
|
var err error
|
||||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||||
if workDir != "" {
|
if workDir != "" {
|
||||||
cmd2.Dir = workDir
|
cmd2.Dir = workDir
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd2)
|
ConfigureShellCmdForAgentExecute(cmd2)
|
||||||
_ = prepareShellCmdSession(cmd2)
|
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -984,8 +990,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
if workDir != "" {
|
if workDir != "" {
|
||||||
cmd2.Dir = workDir
|
cmd2.Dir = workDir
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd2)
|
ConfigureShellCmdForAgentExecute(cmd2)
|
||||||
_ = prepareShellCmdSession(cmd2)
|
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -999,7 +1004,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
Content: []mcp.Content{
|
Content: []mcp.Content{
|
||||||
{
|
{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
|
Text: FormatCommandFailureFromErr(err, output),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
IsError: true,
|
IsError: true,
|
||||||
@@ -1024,7 +1029,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
|
|
||||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
|
||||||
if err := prepareShellCmdSession(cmd); err != nil {
|
if err := prepareShellCmdSession(cmd); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -1091,14 +1096,45 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
|||||||
lastFlush = time.Now()
|
lastFlush = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
for chunk := range chunks {
|
idleWatch := NewShellInactivityWatch(noOutputSec)
|
||||||
|
if idleWatch != nil {
|
||||||
|
defer idleWatch.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
fireInactivity := func() {
|
||||||
|
terminateCmdTree(cmd)
|
||||||
|
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||||
|
outBuilder.WriteString(msg)
|
||||||
|
if cb != nil {
|
||||||
|
cb(msg)
|
||||||
|
}
|
||||||
|
_ = cmd.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
chunksLoop:
|
||||||
|
for {
|
||||||
|
var idleCh <-chan struct{}
|
||||||
|
if idleWatch != nil {
|
||||||
|
idleCh = idleWatch.Expired
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-idleCh:
|
||||||
|
fireInactivity()
|
||||||
|
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||||
|
case chunk, ok := <-chunks:
|
||||||
|
if !ok {
|
||||||
|
break chunksLoop
|
||||||
|
}
|
||||||
|
if chunk != "" && idleWatch != nil {
|
||||||
|
idleWatch.Bump()
|
||||||
|
}
|
||||||
outBuilder.WriteString(chunk)
|
outBuilder.WriteString(chunk)
|
||||||
deltaBuilder.WriteString(chunk)
|
deltaBuilder.WriteString(chunk)
|
||||||
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
|
|
||||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||||
flush()
|
flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
// 等待命令结束,返回最终退出状态
|
// 等待命令结束,返回最终退出状态
|
||||||
@@ -1116,6 +1152,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
|
|||||||
if cmd.Env == nil {
|
if cmd.Env == nil {
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
}
|
}
|
||||||
|
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
|
||||||
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
|
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
|
||||||
has := func(k string) bool {
|
has := func(k string) bool {
|
||||||
prefix := k + "="
|
prefix := k + "="
|
||||||
@@ -1159,7 +1196,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
// PTY 方案为类 Unix;Windows 走原逻辑
|
// PTY 方案为类 Unix;Windows 走原逻辑
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
return streamCommandOutput(ctx, cmd, cb)
|
return streamCommandOutput(ctx, cmd, cb, 0)
|
||||||
}
|
}
|
||||||
_ = prepareShellCmdSession(cmd)
|
_ = prepareShellCmdSession(cmd)
|
||||||
out, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
|
|||||||
@@ -71,6 +71,27 @@ func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteSystemCommand_FailureFormat(t *testing.T) {
|
||||||
|
executor, _ := setupTestExecutor(t)
|
||||||
|
res, err := executor.executeSystemCommand(context.Background(), map[string]interface{}{
|
||||||
|
"command": "echo fail-msg >&2; exit 7",
|
||||||
|
"shell": "sh",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("executeSystemCommand: %v", err)
|
||||||
|
}
|
||||||
|
if res == nil || !res.IsError {
|
||||||
|
t.Fatalf("expected IsError, got %+v", res)
|
||||||
|
}
|
||||||
|
text := res.Content[0].Text
|
||||||
|
if text != FormatCommandFailureResult(7, "fail-msg\n") && text != FormatCommandFailureResult(7, "fail-msg") {
|
||||||
|
t.Fatalf("unexpected failure text: %q", text)
|
||||||
|
}
|
||||||
|
if !strings.Contains(text, "exit status 7") || !strings.Contains(text, "fail-msg") {
|
||||||
|
t.Fatalf("unexpected failure text: %q", text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
|
func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
|
||||||
pos1 := 1
|
pos1 := 1
|
||||||
executor, _ := setupTestExecutor(t)
|
executor, _ := setupTestExecutor(t)
|
||||||
|
|||||||
@@ -0,0 +1,200 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os/exec"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigureShellCmdForAgentExecute applies the same setup as the exec tool to cmd:
// non-interactive stdin, pager/TERM environment, and a separate process group.
// A nil cmd is ignored.
func ConfigureShellCmdForAgentExecute(cmd *exec.Cmd) {
	if cmd == nil {
		return
	}
	applyDefaultTerminalEnv(cmd)
	attachNonInteractiveStdin(cmd)
	// Best effort: an error from session preparation is deliberately discarded here.
	_ = prepareShellCmdSession(cmd)
}
|
||||||
|
|
||||||
|
// TerminateShellCmdTree makes a best-effort attempt to terminate the shell and its child
// process group (consistent with exec/execute timeout cancellation). It is the exported
// entry point for terminateCmdTree.
func TerminateShellCmdTree(cmd *exec.Cmd) {
	terminateCmdTree(cmd)
}
|
||||||
|
|
||||||
|
// EinoStreamingShell provides a streaming shell for the Eino ADK execute tool, aligned with
// the behavior of exec: stdout and stderr are read concurrently (fixed-size chunks, not line
// by line). This avoids the official local.ExecuteStreaming draining stdout first, which
// keeps stderr errors (such as a sudo password prompt) invisible for a long time while the
// UI keeps showing "running".
type EinoStreamingShell struct{}

// NewEinoStreamingShell creates the streaming shell implementation used by execute.
func NewEinoStreamingShell() *EinoStreamingShell {
	return &EinoStreamingShell{}
}
|
||||||
|
|
||||||
|
// ExecuteStreaming 实现 filesystem.StreamingShell。
|
||||||
|
func (s *EinoStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
if input == nil || input.Command == "" {
|
||||||
|
return nil, fmt.Errorf("command is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
sr, w := schema.Pipe[*filesystem.ExecuteResponse](100)
|
||||||
|
if input.RunInBackendGround {
|
||||||
|
go runShellInBackground(ctx, input.Command, w)
|
||||||
|
return sr, nil
|
||||||
|
}
|
||||||
|
go streamShellForeground(ctx, input.Command, w)
|
||||||
|
return sr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runShellInBackground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
|
||||||
|
ConfigureShellCmdForAgentExecute(cmd)
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
_ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stderr, err := cmd.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
_ = stdout.Close()
|
||||||
|
_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
_ = stdout.Close()
|
||||||
|
_ = stderr.Close()
|
||||||
|
_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
drainShellPipes(stdout, stderr)
|
||||||
|
_ = cmd.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
TerminateShellCmdTree(cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
exitCode := 0
|
||||||
|
_ = w.Send(&filesystem.ExecuteResponse{
|
||||||
|
Output: "command started in background\n",
|
||||||
|
ExitCode: &exitCode,
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// drainShellPipes reads stdout and stderr to the end concurrently, discarding everything,
// and returns once both readers are exhausted (or fail).
func drainShellPipes(stdout, stderr io.Reader) {
	var pending sync.WaitGroup
	for _, src := range []io.Reader{stdout, stderr} {
		pending.Add(1)
		go func(r io.Reader) {
			defer pending.Done()
			// Read errors are irrelevant here: the data is thrown away either way.
			_, _ = io.Copy(io.Discard, r)
		}(src)
	}
	pending.Wait()
}
|
||||||
|
|
||||||
|
func streamShellForeground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
|
||||||
|
ConfigureShellCmdForAgentExecute(cmd)
|
||||||
|
|
||||||
|
stdoutPipe, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
_ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stderrPipe, err := cmd.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
_ = stdoutPipe.Close()
|
||||||
|
_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
_ = stdoutPipe.Close()
|
||||||
|
_ = stderrPipe.Close()
|
||||||
|
_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stopWatch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
TerminateShellCmdTree(cmd)
|
||||||
|
case <-stopWatch:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(stopWatch)
|
||||||
|
|
||||||
|
chunks := make(chan string, 64)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
readFn := func(r io.Reader) {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := make([]byte, 8192)
|
||||||
|
for {
|
||||||
|
n, readErr := r.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
chunks <- string(buf[:n])
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(2)
|
||||||
|
go readFn(stdoutPipe)
|
||||||
|
go readFn(stderrPipe)
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(chunks)
|
||||||
|
}()
|
||||||
|
|
||||||
|
hadOutput := false
|
||||||
|
for chunk := range chunks {
|
||||||
|
if chunk == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hadOutput = true
|
||||||
|
if w.Send(&filesystem.ExecuteResponse{Output: chunk}, nil) {
|
||||||
|
TerminateShellCmdTree(cmd)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
waitErr := cmd.Wait()
|
||||||
|
if waitErr == nil {
|
||||||
|
exitCode := 0
|
||||||
|
_ = w.Send(&filesystem.ExecuteResponse{ExitCode: &exitCode}, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var exitError *exec.ExitError
|
||||||
|
if errors.As(waitErr, &exitError) {
|
||||||
|
exitCode := exitError.ExitCode()
|
||||||
|
resp := &filesystem.ExecuteResponse{ExitCode: &exitCode}
|
||||||
|
if !hadOutput {
|
||||||
|
resp.Output = FormatCommandFailureResult(exitCode, "")
|
||||||
|
}
|
||||||
|
_ = w.Send(resp, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = w.Send(nil, fmt.Errorf("command failed: %w", waitErr))
|
||||||
|
}
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestEinoStreamingShell_StreamsStderrBeforeStdoutEOF runs a command that writes only to
// stderr and exits with status 1, then checks that the stream ends within 3 seconds and
// that the stderr text was delivered as output.
func TestEinoStreamingShell_StreamsStderrBeforeStdoutEOF(t *testing.T) {
	shell := NewEinoStreamingShell()
	cmd := PrepareNonInteractiveShellCommand("echo err-only >&2; exit 1")
	sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
	if err != nil {
		t.Fatalf("ExecuteStreaming: %v", err)
	}
	defer sr.Close()

	start := time.Now()
	var got strings.Builder
	// Collect every output chunk until the stream reports EOF.
	for {
		resp, rerr := sr.Recv()
		if errors.Is(rerr, io.EOF) {
			break
		}
		if rerr != nil {
			t.Fatalf("recv: %v", rerr)
		}
		if resp != nil && resp.Output != "" {
			got.WriteString(resp.Output)
		}
	}
	if time.Since(start) > 3*time.Second {
		t.Fatalf("expected fast completion, took %v", time.Since(start))
	}
	if !strings.Contains(got.String(), "err-only") {
		t.Fatalf("expected stderr in output, got: %q", got.String())
	}
}
|
||||||
|
|
||||||
|
// TestEinoStreamingShell_SudoFailsFast runs a sudo command chain through the non-interactive
// wrapper and checks that the stream finishes within 5 seconds, that the legacy
// "command exited with non-zero code" line is absent, and that the output mentions sudo,
// a password, or a terminal.
// NOTE(review): the last assertion presumes sudo is unavailable or asks for a password on
// the test host; with passwordless sudo the output may contain none of these words — confirm.
func TestEinoStreamingShell_SudoFailsFast(t *testing.T) {
	shell := NewEinoStreamingShell()
	cmd := PrepareNonInteractiveShellCommand("sudo whoami && sudo cat /etc/os-release")
	sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
	if err != nil {
		t.Fatalf("ExecuteStreaming: %v", err)
	}
	defer sr.Close()

	start := time.Now()
	var got strings.Builder
	// Collect every response's output until the stream reports EOF.
	for {
		resp, rerr := sr.Recv()
		if errors.Is(rerr, io.EOF) {
			break
		}
		if rerr != nil {
			t.Fatalf("recv: %v", rerr)
		}
		if resp == nil {
			continue
		}
		got.WriteString(resp.Output)
	}
	if time.Since(start) > 5*time.Second {
		t.Fatalf("sudo should fail quickly, took %v output=%q", time.Since(start), got.String())
	}
	out := got.String()
	if strings.Contains(out, "command exited with non-zero code") {
		t.Fatalf("legacy exit line present: %q", out)
	}
	if !strings.Contains(out, "sudo") && !strings.Contains(out, "password") && !strings.Contains(out, "terminal") {
		t.Fatalf("expected sudo error text, got: %q", out)
	}
}
|
||||||
|
|
||||||
|
// TestEinoStreamingShell_StderrWhileStdoutBlocks checks that stderr output reaches the
// stream within 1.5 seconds even though the command keeps running (sleep 30) and writes
// nothing to stdout. The 2-second context timeout bounds the test.
func TestEinoStreamingShell_StderrWhileStdoutBlocks(t *testing.T) {
	shell := NewEinoStreamingShell()
	// Simulates sudo: stderr gets output first while the process behind stdout stays
	// blocked; per the original note, the old eino local shell wrote nothing to the stream
	// before that first stderr chunk.
	cmd := PrepareNonInteractiveShellCommand(`echo "password prompt" >&2; sleep 30`)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	sr, err := shell.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: cmd})
	if err != nil {
		t.Fatalf("ExecuteStreaming: %v", err)
	}
	defer sr.Close()

	start := time.Now()
	var got strings.Builder
	// Stop reading as soon as the prompt text has arrived; any receive error also ends the
	// loop and is judged by the assertions below.
	for {
		resp, rerr := sr.Recv()
		if errors.Is(rerr, io.EOF) {
			break
		}
		if rerr != nil {
			break
		}
		if resp != nil && resp.Output != "" {
			got.WriteString(resp.Output)
			if strings.Contains(got.String(), "password prompt") {
				break
			}
		}
	}
	if time.Since(start) > 1500*time.Millisecond {
		t.Fatalf("expected stderr promptly, took %v output=%q", time.Since(start), got.String())
	}
	if !strings.Contains(got.String(), "password prompt") {
		t.Fatalf("expected early stderr, got: %q", got.String())
	}
}
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShellNoOutputTimeoutMessage returns the notice used when a command has produced no new
// stdout/stderr for a long time (a soft failure that is visible to the model). idleSec is
// inserted into both the Chinese and the English sentence of the message.
func ShellNoOutputTimeoutMessage(idleSec int) string {
	return fmt.Sprintf(`命令已终止:超过 %d 秒没有新的输出,疑似在等待交互输入或已挂起。

长时静默任务请使用末尾 & 后台运行,或增大 agent.shell_no_output_timeout_seconds(-1=关闭此检测)。

Command terminated: no new output for %d seconds (possible interactive wait or hung process).`, idleSec, idleSec)
}
|
||||||
|
|
||||||
|
// ShellInactivityWatch signals on Expired when Sec seconds pass without a call to Bump;
// every Bump restarts the countdown. Unlike a timer that is cancelled for good by the first
// output, this also covers a command that prints something (for example a sudo password
// prompt) and then hangs.
//
// Expired has a buffer of one; at most one pending expiry is kept.
type ShellInactivityWatch struct {
	Sec int

	mu    sync.Mutex
	timer *time.Timer
	// gen is incremented whenever the timer is re-armed or stopped, so that a timer callback
	// that was already running can tell it has been superseded and must not signal.
	gen uint64

	Expired chan struct{}
}

// NewShellInactivityWatch returns a watch that is already counting down, or nil when the
// timeout resolves to "disabled".
//
// noOutputSec is interpreted by ResolveShellNoOutputTimeoutSeconds: 0 selects the default of
// 300 seconds, a negative value disables the watch, a positive value is used as is.
//
// NOTE(review): pass the raw configured value. A value that has already been resolved turns
// "disabled" (-1 resolved to 0) back into the 300-second default here, because 0 is
// indistinguishable from "not configured".
func NewShellInactivityWatch(noOutputSec int) *ShellInactivityWatch {
	sec := ResolveShellNoOutputTimeoutSeconds(noOutputSec)
	if sec <= 0 {
		return nil
	}
	w := &ShellInactivityWatch{
		Sec:     sec,
		Expired: make(chan struct{}, 1),
	}
	w.Bump()
	return w
}

// Bump restarts the inactivity countdown. An expiry that was signalled but not yet consumed
// is discarded, and a timer callback that is already in flight is invalidated, so activity
// reported through Bump is never followed by a stale timeout. Safe on a nil receiver.
func (w *ShellInactivityWatch) Bump() {
	if w == nil || w.Sec <= 0 {
		return
	}
	w.mu.Lock()
	defer w.mu.Unlock()

	w.disarmLocked()
	// Drop an expiry left over from the previous countdown.
	select {
	case <-w.Expired:
	default:
	}

	gen := w.gen
	w.timer = time.AfterFunc(time.Duration(w.Sec)*time.Second, func() {
		w.mu.Lock()
		defer w.mu.Unlock()
		if w.gen != gen {
			return // superseded by a later Bump or Stop
		}
		select {
		case w.Expired <- struct{}{}:
		default:
		}
	})
}

// Stop cancels the countdown. An expiry that was already delivered to Expired stays there.
// Safe on a nil receiver.
func (w *ShellInactivityWatch) Stop() {
	if w == nil {
		return
	}
	w.mu.Lock()
	defer w.mu.Unlock()
	w.disarmLocked()
}

// disarmLocked stops the current timer, if any, and invalidates its callback.
// The caller must hold w.mu.
func (w *ShellInactivityWatch) disarmLocked() {
	if w.timer != nil {
		w.timer.Stop()
		w.timer = nil
	}
	w.gen++
}

// ResolveShellNoOutputTimeoutSeconds maps the configured value to the effective timeout in
// seconds: 0 yields the default of 300 (5 minutes), a negative value yields 0 (disabled),
// and a positive value is returned unchanged.
func ResolveShellNoOutputTimeoutSeconds(sec int) int {
	if sec < 0 {
		return 0
	}
	if sec == 0 {
		return 300
	}
	return sec
}
|
||||||
|
|
||||||
|
// PrependNonInteractiveShellExports injects generic non-interactive environment settings
// (pagers and DEBIAN_FRONTEND) in front of a script that is run with sh -c. It does not
// maintain any command blacklist.
//
// A variable is left alone when the script itself assigns it, i.e. contains NAME= where NAME
// is not the tail of a longer identifier. Merely mentioning the name (for example
// "cat pager.log") or assigning a different variable that ends with it (GIT_PAGER vs PAGER)
// no longer suppresses the export. A blank script is returned unchanged.
func PrependNonInteractiveShellExports(shellCommand string) string {
	if strings.TrimSpace(shellCommand) == "" {
		return shellCommand
	}

	isNameByte := func(c byte) bool {
		return c == '_' || (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
	}
	// assigns reports whether the script contains an assignment to exactly this variable.
	assigns := func(key string) bool {
		needle := key + "="
		for from := 0; from < len(shellCommand); {
			i := strings.Index(shellCommand[from:], needle)
			if i < 0 {
				return false
			}
			i += from
			if i == 0 || !isNameByte(shellCommand[i-1]) {
				return true
			}
			from = i + len(needle)
		}
		return false
	}

	defaults := [...][2]string{
		{"GIT_PAGER", "cat"},
		{"PAGER", "cat"},
		{"SYSTEMD_PAGER", "cat"},
		{"DEBIAN_FRONTEND", "noninteractive"},
	}
	pairs := make([]string, 0, len(defaults))
	for _, kv := range defaults {
		if !assigns(kv[0]) {
			pairs = append(pairs, kv[0]+"="+kv[1])
		}
	}
	if len(pairs) == 0 {
		return shellCommand
	}
	return "export " + strings.Join(pairs, " ") + "\n" + shellCommand
}
|
||||||
|
|
||||||
|
// PrependNonInteractiveStdinRedirect closes stdin for a script run with sh -c (the
// script-level counterpart of attachNonInteractiveStdin), so that programs reading from
// stdin such as read, input() or sudo -S fail fast instead of hanging.
//
// A blank script, or one that already contains "</dev/null" (compared case-insensitively,
// which also covers "0</dev/null"), is returned unchanged.
func PrependNonInteractiveStdinRedirect(shellCommand string) string {
	switch {
	case strings.TrimSpace(shellCommand) == "":
		return shellCommand
	case strings.Contains(strings.ToLower(shellCommand), "</dev/null"):
		return shellCommand
	}
	return "exec </dev/null\n" + shellCommand
}
|
||||||
|
|
||||||
|
// PrepareNonInteractiveShellCommand combines the non-interactive wrappers: the environment
// exports (pager etc.) are prepended first, then the stdin redirect is prepended in front
// of that result. No command list is involved.
func PrepareNonInteractiveShellCommand(shellCommand string) string {
	return PrependNonInteractiveStdinRedirect(PrependNonInteractiveShellExports(shellCommand))
}
|
||||||
|
|
||||||
|
// ApplyNonInteractivePagerEnv completes an exec.Cmd environment with the same variables
// that PrependNonInteractiveShellExports exports: GIT_PAGER, PAGER and SYSTEMD_PAGER are set
// to "cat" and DEBIAN_FRONTEND to "noninteractive". A variable that already has an entry
// ("NAME=" prefix) is not touched. The entries are appended to cmdEnv in that order and the
// resulting slice is returned.
func ApplyNonInteractivePagerEnv(cmdEnv []string) []string {
	defaults := [...][2]string{
		{"GIT_PAGER", "cat"},
		{"PAGER", "cat"},
		{"SYSTEMD_PAGER", "cat"},
		{"DEBIAN_FRONTEND", "noninteractive"},
	}
	// A nil cmdEnv needs no special case: nothing is present, so all four entries are
	// appended and the result is non-nil.
	for _, kv := range defaults {
		name, value := kv[0], kv[1]
		present := false
		for _, entry := range cmdEnv {
			if strings.HasPrefix(entry, name+"=") {
				present = true
				break
			}
		}
		if !present {
			cmdEnv = append(cmdEnv, name+"="+value)
		}
	}
	return cmdEnv
}
|
||||||
|
|
||||||
|
// attachNonInteractiveStdin gives cmd the null device as stdin so that commands expecting
// interactive input fail fast instead of waiting. It does nothing when cmd is nil, when
// cmd.Stdin is already set, or when the null device cannot be opened.
func attachNonInteractiveStdin(cmd *exec.Cmd) {
	if cmd == nil {
		return
	}
	if cmd.Stdin != nil {
		return
	}
	if devNull, err := os.Open(os.DevNull); err == nil {
		cmd.Stdin = devNull
	}
}
|
||||||
@@ -0,0 +1,128 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrependNonInteractiveShellExports(t *testing.T) {
|
||||||
|
out := PrependNonInteractiveShellExports("echo hi")
|
||||||
|
if !strings.Contains(out, "GIT_PAGER=cat") || !strings.Contains(out, "PAGER=cat") {
|
||||||
|
t.Fatalf("missing pager exports: %q", out)
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||||
|
t.Fatalf("command suffix lost: %q", out)
|
||||||
|
}
|
||||||
|
skip := PrependNonInteractiveShellExports("GIT_PAGER=less echo hi")
|
||||||
|
if strings.Contains(skip, "export GIT_PAGER=cat") {
|
||||||
|
t.Fatalf("should not override existing GIT_PAGER: %q", skip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrependNonInteractiveStdinRedirect(t *testing.T) {
|
||||||
|
out := PrependNonInteractiveStdinRedirect("echo hi")
|
||||||
|
if !strings.HasPrefix(out, "exec </dev/null") {
|
||||||
|
t.Fatalf("missing stdin redirect: %q", out)
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||||
|
t.Fatalf("command suffix lost: %q", out)
|
||||||
|
}
|
||||||
|
skip := PrependNonInteractiveStdinRedirect("cmd </dev/null")
|
||||||
|
if strings.HasPrefix(skip, "exec </dev/null") {
|
||||||
|
t.Fatalf("should not double redirect: %q", skip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNonInteractiveShellCommand(t *testing.T) {
|
||||||
|
out := PrepareNonInteractiveShellCommand("echo hi")
|
||||||
|
if !strings.Contains(out, "exec </dev/null") {
|
||||||
|
t.Fatalf("missing stdin redirect: %q", out)
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "GIT_PAGER=cat") {
|
||||||
|
t.Fatalf("missing pager export: %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewShellInactivityWatch creates a watch with a 1-second timeout, bumps it once, and
// expects an expiry signal on Expired within 3 seconds.
func TestNewShellInactivityWatch(t *testing.T) {
	w := NewShellInactivityWatch(1)
	if w == nil {
		t.Fatal("expected watch")
	}
	w.Bump()
	select {
	case <-w.Expired:
	case <-time.After(3 * time.Second):
		t.Fatal("expected inactivity fire within 3s")
	}
}
|
||||||
|
|
||||||
|
func TestResolveShellNoOutputTimeoutSeconds(t *testing.T) {
|
||||||
|
if ResolveShellNoOutputTimeoutSeconds(0) != 300 {
|
||||||
|
t.Fatal("zero should default to 300")
|
||||||
|
}
|
||||||
|
if ResolveShellNoOutputTimeoutSeconds(-1) != 0 {
|
||||||
|
t.Fatal("-1 should disable")
|
||||||
|
}
|
||||||
|
if ResolveShellNoOutputTimeoutSeconds(30) != 30 {
|
||||||
|
t.Fatal("explicit value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNonInteractiveStdinReadExitsQuickly verifies that with "exec </dev/null" plus
// attachNonInteractiveStdin, the shell builtin read hits EOF immediately instead of
// hanging: the command must finish in under 2 seconds, succeed, and print an empty value.
// Skipped in -short mode because it spawns a real shell.
func TestNonInteractiveStdinReadExitsQuickly(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping shell integration in -short")
	}
	// The 5-second context only bounds the test should the command hang after all.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "sh", "-c", PrepareNonInteractiveShellCommand(`read x; echo "x=<$x>"`))
	attachNonInteractiveStdin(cmd)

	start := time.Now()
	out, err := cmd.CombinedOutput()
	elapsed := time.Since(start)
	if elapsed > 2*time.Second {
		t.Fatalf("read with closed stdin took %v, want <2s", elapsed)
	}
	if err != nil {
		t.Fatalf("unexpected error: %v output=%q", err, out)
	}
	if !strings.Contains(string(out), "x=<>") {
		t.Fatalf("unexpected output: %q", out)
	}
}
|
||||||
|
|
||||||
|
// TestNonInteractiveStdinReadBlocksWithoutRedirect is the control case: when stdin is a pipe
// that is never written to, read hangs. The command is expected to still be running after
// 500ms, at which point it is killed. Skipped in -short mode because it spawns a real shell.
func TestNonInteractiveStdinReadBlocksWithoutRedirect(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping shell integration in -short")
	}
	r, w, err := os.Pipe()
	if err != nil {
		t.Fatal(err)
	}
	defer r.Close()
	// Keep w open without writing any data to simulate "waiting for user input".

	cmd := exec.Command("sh", "-c", `read x; echo done`)
	cmd.Stdin = r

	done := make(chan error, 1)
	go func() { done <- cmd.Run() }()

	select {
	case err := <-done:
		t.Fatalf("expected hang, but command finished: %v", err)
	case <-time.After(500 * time.Millisecond):
		if cmd.Process != nil {
			_ = cmd.Process.Kill()
		}
		_ = w.Close()
		<-done // wait for the goroutine to exit
	}
}
|
||||||
@@ -2580,6 +2580,8 @@
|
|||||||
"agentModeSingle": "Single-agent (Eino ADK)",
|
"agentModeSingle": "Single-agent (Eino ADK)",
|
||||||
"agentModeMulti": "Multi-agent (Eino)",
|
"agentModeMulti": "Multi-agent (Eino)",
|
||||||
"agentModeHint": "Same as chat: Eino single-agent (ADK), or Deep / Plan-Execute / Supervisor (last three require multi_agent.enabled).",
|
"agentModeHint": "Same as chat: Eino single-agent (ADK), or Deep / Plan-Execute / Supervisor (last three require multi_agent.enabled).",
|
||||||
|
"concurrency": "Concurrency",
|
||||||
|
"concurrencyHint": "Number of subtasks to run in parallel (1-8). Default 1 is serial; use 1-2 for scan-heavy tasks.",
|
||||||
"scheduleMode": "Schedule mode",
|
"scheduleMode": "Schedule mode",
|
||||||
"scheduleModeManual": "Manual",
|
"scheduleModeManual": "Manual",
|
||||||
"scheduleModeCron": "Cron expression",
|
"scheduleModeCron": "Cron expression",
|
||||||
@@ -2594,8 +2596,8 @@
|
|||||||
"tasksList": "Task list (one task per line)",
|
"tasksList": "Task list (one task per line)",
|
||||||
"tasksListPlaceholder": "Enter task list, one per line",
|
"tasksListPlaceholder": "Enter task list, one per line",
|
||||||
"tasksListPlaceholderExample": "Enter task list, one per line, for example:\nScan open ports of 192.168.1.1\nCheck if https://example.com has SQL injection\nEnumerate subdomains of example.com",
|
"tasksListPlaceholderExample": "Enter task list, one per line, for example:\nScan open ports of 192.168.1.1\nCheck if https://example.com has SQL injection\nEnumerate subdomains of example.com",
|
||||||
"tasksListHint": "Enter one task command per line; the system will execute them in order. Empty lines are ignored.",
|
"tasksListHint": "Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
|
||||||
"tasksListHintFull": "Hint: Enter one task command per line; the system will execute these tasks in order. Empty lines are ignored.",
|
"tasksListHintFull": "Hint: Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
|
||||||
"createQueue": "Create queue"
|
"createQueue": "Create queue"
|
||||||
},
|
},
|
||||||
"batchQueueDetailModal": {
|
"batchQueueDetailModal": {
|
||||||
@@ -2629,6 +2631,8 @@
|
|||||||
"scheduleToggleFailed": "Failed to update schedule toggle",
|
"scheduleToggleFailed": "Failed to update schedule toggle",
|
||||||
"completedAt": "Completed at",
|
"completedAt": "Completed at",
|
||||||
"taskTotal": "Total tasks",
|
"taskTotal": "Total tasks",
|
||||||
|
"concurrency": "Concurrency",
|
||||||
|
"concurrencyEditHint": "Click to edit. Cannot change while the queue is running.",
|
||||||
"taskList": "Task list",
|
"taskList": "Task list",
|
||||||
"startLabel": "Start",
|
"startLabel": "Start",
|
||||||
"completeLabel": "Complete",
|
"completeLabel": "Complete",
|
||||||
|
|||||||
@@ -2568,6 +2568,8 @@
|
|||||||
"agentModeSingle": "单代理(Eino ADK)",
|
"agentModeSingle": "单代理(Eino ADK)",
|
||||||
"agentModeMulti": "多代理(Eino)",
|
"agentModeMulti": "多代理(Eino)",
|
||||||
"agentModeHint": "与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。",
|
"agentModeHint": "与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。",
|
||||||
|
"concurrency": "并发数",
|
||||||
|
"concurrencyHint": "同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。",
|
||||||
"scheduleMode": "调度方式",
|
"scheduleMode": "调度方式",
|
||||||
"scheduleModeManual": "手工执行",
|
"scheduleModeManual": "手工执行",
|
||||||
"scheduleModeCron": "调度表达式(Cron)",
|
"scheduleModeCron": "调度表达式(Cron)",
|
||||||
@@ -2582,8 +2584,8 @@
|
|||||||
"tasksList": "任务列表(每行一个任务)",
|
"tasksList": "任务列表(每行一个任务)",
|
||||||
"tasksListPlaceholder": "请输入任务列表,每行一个任务",
|
"tasksListPlaceholder": "请输入任务列表,每行一个任务",
|
||||||
"tasksListPlaceholderExample": "请输入任务列表,每行一个任务,例如:\n扫描 192.168.1.1 的开放端口\n检查 https://example.com 是否存在SQL注入\n枚举 example.com 的子域名",
|
"tasksListPlaceholderExample": "请输入任务列表,每行一个任务,例如:\n扫描 192.168.1.1 的开放端口\n检查 https://example.com 是否存在SQL注入\n枚举 example.com 的子域名",
|
||||||
"tasksListHint": "每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。",
|
"tasksListHint": "每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
|
||||||
"tasksListHintFull": "提示:每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。",
|
"tasksListHintFull": "提示:每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
|
||||||
"createQueue": "创建队列"
|
"createQueue": "创建队列"
|
||||||
},
|
},
|
||||||
"batchQueueDetailModal": {
|
"batchQueueDetailModal": {
|
||||||
@@ -2617,6 +2619,8 @@
|
|||||||
"scheduleToggleFailed": "更新调度开关失败",
|
"scheduleToggleFailed": "更新调度开关失败",
|
||||||
"completedAt": "完成时间",
|
"completedAt": "完成时间",
|
||||||
"taskTotal": "任务总数",
|
"taskTotal": "任务总数",
|
||||||
|
"concurrency": "并发数",
|
||||||
|
"concurrencyEditHint": "点击可修改;队列运行中不可改。",
|
||||||
"taskList": "任务列表",
|
"taskList": "任务列表",
|
||||||
"startLabel": "开始",
|
"startLabel": "开始",
|
||||||
"completeLabel": "完成",
|
"completeLabel": "完成",
|
||||||
|
|||||||
+24
-1
@@ -3110,7 +3110,17 @@ async function cancelMCPToolExecutionSubmit(executionId, userNote, options = {})
|
|||||||
if (!executionId) {
|
if (!executionId) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
let conversationId = '';
|
||||||
|
if (typeof monitorState !== 'undefined' && Array.isArray(monitorState.executions)) {
|
||||||
|
const exec = monitorState.executions.find(e => e && e.id === executionId);
|
||||||
|
if (exec) {
|
||||||
|
conversationId = (exec.conversationId || '').trim();
|
||||||
|
}
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
|
if (conversationId && typeof requestCancelWithContinue === 'function') {
|
||||||
|
await requestCancelWithContinue(conversationId, userNote || '');
|
||||||
|
} else {
|
||||||
const res = await apiFetch(`/api/monitor/execution/${encodeURIComponent(executionId)}/cancel`, {
|
const res = await apiFetch(`/api/monitor/execution/${encodeURIComponent(executionId)}/cancel`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
@@ -3120,6 +3130,7 @@ async function cancelMCPToolExecutionSubmit(executionId, userNote, options = {})
|
|||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
throw new Error(body.error || body.message || res.statusText);
|
throw new Error(body.error || body.message || res.statusText);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
const okMsg = typeof window.t === 'function' ? window.t('mcpDetailModal.abortSuccess') : '已发送终止请求';
|
const okMsg = typeof window.t === 'function' ? window.t('mcpDetailModal.abortSuccess') : '已发送终止请求';
|
||||||
alert(okMsg);
|
alert(okMsg);
|
||||||
if (options.refreshDetail && typeof showMCPDetail === 'function') {
|
if (options.refreshDetail && typeof showMCPDetail === 'function') {
|
||||||
@@ -3136,7 +3147,7 @@ async function cancelMCPToolExecutionSubmit(executionId, userNote, options = {})
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 取消单次 MCP 工具执行(监控页「终止」)。弹出说明框后提交;仅取消该次 tools/call,不停止整条对话/迭代任务。
|
* 取消单次 MCP 工具执行(监控页「终止」)。有 conversationId 时复用对话页「中断并继续」弹窗与 API。
|
||||||
* @param {string} executionId
|
* @param {string} executionId
|
||||||
* @param {{ refreshDetail?: boolean }} [options]
|
* @param {{ refreshDetail?: boolean }} [options]
|
||||||
*/
|
*/
|
||||||
@@ -3144,6 +3155,18 @@ async function cancelMCPToolExecution(executionId, options = {}) {
|
|||||||
if (!executionId) {
|
if (!executionId) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
let conversationId = '';
|
||||||
|
if (typeof monitorState !== 'undefined' && Array.isArray(monitorState.executions)) {
|
||||||
|
const exec = monitorState.executions.find(e => e && e.id === executionId);
|
||||||
|
if (exec) {
|
||||||
|
conversationId = (exec.conversationId || '').trim();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (conversationId && typeof openUserInterruptModal === 'function') {
|
||||||
|
openUserInterruptModal(null, conversationId);
|
||||||
|
window.__monitorInterruptContext = { executionId: executionId, options: options || {} };
|
||||||
|
return;
|
||||||
|
}
|
||||||
openMcpToolAbortModal(executionId, options);
|
openMcpToolAbortModal(executionId, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1003,6 +1003,7 @@ function openUserInterruptModal(progressId, conversationId) {
|
|||||||
|
|
||||||
function closeUserInterruptModal() {
|
function closeUserInterruptModal() {
|
||||||
userInterruptModalPending = null;
|
userInterruptModalPending = null;
|
||||||
|
window.__monitorInterruptContext = null;
|
||||||
closeAppModal('user-interrupt-modal');
|
closeAppModal('user-interrupt-modal');
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1012,6 +1013,7 @@ async function submitUserInterruptContinue() {
|
|||||||
}
|
}
|
||||||
const reason = (document.getElementById('user-interrupt-reason') && document.getElementById('user-interrupt-reason').value || '').trim();
|
const reason = (document.getElementById('user-interrupt-reason') && document.getElementById('user-interrupt-reason').value || '').trim();
|
||||||
const { progressId, conversationId } = userInterruptModalPending;
|
const { progressId, conversationId } = userInterruptModalPending;
|
||||||
|
const monitorCtx = window.__monitorInterruptContext;
|
||||||
closeUserInterruptModal();
|
closeUserInterruptModal();
|
||||||
const stopBtn = progressId ? document.getElementById(`${progressId}-stop-btn`) : null;
|
const stopBtn = progressId ? document.getElementById(`${progressId}-stop-btn`) : null;
|
||||||
try {
|
try {
|
||||||
@@ -1020,6 +1022,13 @@ async function submitUserInterruptContinue() {
|
|||||||
stopBtn.textContent = typeof window.t === 'function' ? window.t('tasks.interruptSubmitting') : '提交中...';
|
stopBtn.textContent = typeof window.t === 'function' ? window.t('tasks.interruptSubmitting') : '提交中...';
|
||||||
}
|
}
|
||||||
await requestCancelWithContinue(conversationId, reason);
|
await requestCancelWithContinue(conversationId, reason);
|
||||||
|
if (monitorCtx && monitorCtx.executionId && typeof refreshMonitorPanel === 'function') {
|
||||||
|
const page = (typeof monitorState !== 'undefined' && monitorState.pagination && monitorState.pagination.page)
|
||||||
|
? monitorState.pagination.page
|
||||||
|
: 1;
|
||||||
|
await refreshMonitorPanel(page);
|
||||||
|
window.__monitorInterruptContext = null;
|
||||||
|
}
|
||||||
loadActiveTasks();
|
loadActiveTasks();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('中断并继续失败:', error);
|
console.error('中断并继续失败:', error);
|
||||||
@@ -3536,6 +3545,33 @@ const monitorState = {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let monitorPollTimer = null;
|
||||||
|
const MONITOR_POLL_INTERVAL_MS = 3000;
|
||||||
|
|
||||||
|
function startMonitorPoll() {
|
||||||
|
stopMonitorPoll();
|
||||||
|
monitorPollTimer = setInterval(function () {
|
||||||
|
const page = document.getElementById('page-mcp-monitor');
|
||||||
|
if (!page || !page.classList.contains('active')) {
|
||||||
|
stopMonitorPoll();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (document.hidden) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (typeof refreshMonitorPanel === 'function') {
|
||||||
|
refreshMonitorPanel().catch(function () { /* ignore */ });
|
||||||
|
}
|
||||||
|
}, MONITOR_POLL_INTERVAL_MS);
|
||||||
|
}
|
||||||
|
|
||||||
|
function stopMonitorPoll() {
|
||||||
|
if (monitorPollTimer) {
|
||||||
|
clearInterval(monitorPollTimer);
|
||||||
|
monitorPollTimer = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function openMonitorPanel() {
|
function openMonitorPanel() {
|
||||||
// 切换到MCP监控页面
|
// 切换到MCP监控页面
|
||||||
if (typeof switchPage === 'function') {
|
if (typeof switchPage === 'function') {
|
||||||
|
|||||||
@@ -356,6 +356,9 @@ async function initPage(pageId) {
|
|||||||
if (typeof refreshMonitorPanel === 'function') {
|
if (typeof refreshMonitorPanel === 'function') {
|
||||||
refreshMonitorPanel();
|
refreshMonitorPanel();
|
||||||
}
|
}
|
||||||
|
if (typeof startMonitorPoll === 'function') {
|
||||||
|
startMonitorPoll();
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case 'mcp-management':
|
case 'mcp-management':
|
||||||
// 初始化MCP管理
|
// 初始化MCP管理
|
||||||
|
|||||||
@@ -990,6 +990,7 @@ async function createBatchQueue() {
|
|||||||
const roleSelect = document.getElementById('batch-queue-role');
|
const roleSelect = document.getElementById('batch-queue-role');
|
||||||
const projectSelect = document.getElementById('batch-queue-project-id');
|
const projectSelect = document.getElementById('batch-queue-project-id');
|
||||||
const agentModeSelect = document.getElementById('batch-queue-agent-mode');
|
const agentModeSelect = document.getElementById('batch-queue-agent-mode');
|
||||||
|
const concurrencyInput = document.getElementById('batch-queue-concurrency');
|
||||||
const scheduleModeSelect = document.getElementById('batch-queue-schedule-mode');
|
const scheduleModeSelect = document.getElementById('batch-queue-schedule-mode');
|
||||||
const cronExprInput = document.getElementById('batch-queue-cron-expr');
|
const cronExprInput = document.getElementById('batch-queue-cron-expr');
|
||||||
const executeNowCheckbox = document.getElementById('batch-queue-execute-now');
|
const executeNowCheckbox = document.getElementById('batch-queue-execute-now');
|
||||||
@@ -1019,6 +1020,9 @@ async function createBatchQueue() {
|
|||||||
const scheduleMode = scheduleModeSelect ? (scheduleModeSelect.value === 'cron' ? 'cron' : 'manual') : 'manual';
|
const scheduleMode = scheduleModeSelect ? (scheduleModeSelect.value === 'cron' ? 'cron' : 'manual') : 'manual';
|
||||||
const cronExpr = cronExprInput ? cronExprInput.value.trim() : '';
|
const cronExpr = cronExprInput ? cronExprInput.value.trim() : '';
|
||||||
const executeNow = executeNowCheckbox ? !!executeNowCheckbox.checked : false;
|
const executeNow = executeNowCheckbox ? !!executeNowCheckbox.checked : false;
|
||||||
|
let concurrency = concurrencyInput ? parseInt(concurrencyInput.value, 10) : 1;
|
||||||
|
if (!Number.isFinite(concurrency) || concurrency < 1) concurrency = 1;
|
||||||
|
if (concurrency > 8) concurrency = 8;
|
||||||
if (scheduleMode === 'cron' && !cronExpr) {
|
if (scheduleMode === 'cron' && !cronExpr) {
|
||||||
alert(_t('batchImportModal.cronExprRequired'));
|
alert(_t('batchImportModal.cronExprRequired'));
|
||||||
return;
|
return;
|
||||||
@@ -1043,6 +1047,7 @@ async function createBatchQueue() {
|
|||||||
cronExpr,
|
cronExpr,
|
||||||
executeNow,
|
executeNow,
|
||||||
projectId,
|
projectId,
|
||||||
|
concurrency,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -1489,6 +1494,7 @@ async function showBatchQueueDetail(queueId) {
|
|||||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.role'))}</span><span class="bq-kv__v" id="bq-role-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditRole()" title="${escapeHtml(_t('common.edit'))}">${roleLineVal}</span>` : roleLineVal}</span></div>
|
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.role'))}</span><span class="bq-kv__v" id="bq-role-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditRole()" title="${escapeHtml(_t('common.edit'))}">${roleLineVal}</span>` : roleLineVal}</span></div>
|
||||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.agentMode'))}</span><span class="bq-kv__v" id="bq-agentmode-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditAgentMode()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(agentModeText)}</span>` : escapeHtml(agentModeText)}</span></div>
|
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.agentMode'))}</span><span class="bq-kv__v" id="bq-agentmode-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditAgentMode()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(agentModeText)}</span>` : escapeHtml(agentModeText)}</span></div>
|
||||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.scheduleMode'))}</span><span class="bq-kv__v" id="bq-schedule-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditSchedule()" title="${escapeHtml(_t('common.edit'))}">${scheduleDetail}</span>` : scheduleDetail}</span></div>
|
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.scheduleMode'))}</span><span class="bq-kv__v" id="bq-schedule-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditSchedule()" title="${escapeHtml(_t('common.edit'))}">${scheduleDetail}</span>` : scheduleDetail}</span></div>
|
||||||
|
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.concurrency'))}</span><span class="bq-kv__v" id="bq-concurrency-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditConcurrency()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span>` : escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span></div>
|
||||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.taskTotal'))}</span><span class="bq-kv__v">${queue.tasks.length}</span></div>
|
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.taskTotal'))}</span><span class="bq-kv__v">${queue.tasks.length}</span></div>
|
||||||
${queue.scheduleMode === 'cron' ? `<div class="bq-kv bq-kv--block"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAuto'))}</span><span class="bq-kv__v bq-kv__v--control"><label class="bq-cron-toggle"><input type="checkbox" ${queue.scheduleEnabled !== false ? 'checked' : ''} onchange="updateBatchQueueScheduleEnabled(this.checked)" /><span class="bq-cron-toggle__hint">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAutoHint'))}</span></label></span></div>` : ''}
|
${queue.scheduleMode === 'cron' ? `<div class="bq-kv bq-kv--block"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAuto'))}</span><span class="bq-kv__v bq-kv__v--control"><label class="bq-cron-toggle"><input type="checkbox" ${queue.scheduleEnabled !== false ? 'checked' : ''} onchange="updateBatchQueueScheduleEnabled(this.checked)" /><span class="bq-cron-toggle__hint">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAutoHint'))}</span></label></span></div>` : ''}
|
||||||
</section>
|
</section>
|
||||||
@@ -2287,6 +2293,75 @@ async function saveInlineAgentMode() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function normalizeBatchQueueConcurrencyInput(raw) {
|
||||||
|
let n = parseInt(raw, 10);
|
||||||
|
if (!Number.isFinite(n) || n < 1) n = 1;
|
||||||
|
if (n > 8) n = 8;
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- 内联编辑:并发数 ---
|
||||||
|
function startInlineEditConcurrency() {
|
||||||
|
const container = document.getElementById('bq-concurrency-val');
|
||||||
|
if (!container) return;
|
||||||
|
const queueId = batchQueuesState.currentQueueId;
|
||||||
|
if (!queueId) return;
|
||||||
|
apiFetch(`/api/batch-tasks/${queueId}`).then(r => r.json()).then(detail => {
|
||||||
|
const queue = detail.queue || {};
|
||||||
|
const current = normalizeBatchQueueConcurrencyInput(queue.concurrency || 1);
|
||||||
|
container.innerHTML = `<span class="bq-inline-edit-controls">
|
||||||
|
<input type="number" id="bq-edit-concurrency" min="1" max="8" value="${current}" style="width:72px;" />
|
||||||
|
</span>`;
|
||||||
|
const inp = document.getElementById('bq-edit-concurrency');
|
||||||
|
if (!inp) return;
|
||||||
|
inp.focus();
|
||||||
|
inp.select();
|
||||||
|
let cancelled = false;
|
||||||
|
inp.addEventListener('keydown', (e) => {
|
||||||
|
if (e.key === 'Enter') { e.preventDefault(); inp.blur(); }
|
||||||
|
if (e.key === 'Escape') { cancelled = true; cancelAllInlineEdits(); }
|
||||||
|
});
|
||||||
|
inp.addEventListener('blur', () => {
|
||||||
|
if (!cancelled) saveInlineConcurrency();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async function saveInlineConcurrency() {
|
||||||
|
if (_bqInlineSaving) return;
|
||||||
|
_bqInlineSaving = true;
|
||||||
|
const queueId = batchQueuesState.currentQueueId;
|
||||||
|
if (!queueId) { _bqInlineSaving = false; return; }
|
||||||
|
const inp = document.getElementById('bq-edit-concurrency');
|
||||||
|
const concurrency = normalizeBatchQueueConcurrencyInput(inp ? inp.value : 1);
|
||||||
|
try {
|
||||||
|
const detailResp = await apiFetch(`/api/batch-tasks/${queueId}`);
|
||||||
|
const detail = await detailResp.json();
|
||||||
|
const q = detail.queue || {};
|
||||||
|
const response = await apiFetch(`/api/batch-tasks/${queueId}/metadata`, {
|
||||||
|
method: 'PUT',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
title: q.title || '',
|
||||||
|
role: q.role || '',
|
||||||
|
agentMode: q.agentMode || 'eino_single',
|
||||||
|
concurrency,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
if (!response.ok) {
|
||||||
|
const result = await response.json().catch(() => ({}));
|
||||||
|
throw new Error(result.error || _t('tasks.updateTaskFailed'));
|
||||||
|
}
|
||||||
|
_bqInlineSaving = false;
|
||||||
|
showBatchQueueDetail(queueId);
|
||||||
|
refreshBatchQueues();
|
||||||
|
} catch (e) {
|
||||||
|
_bqInlineSaving = false;
|
||||||
|
console.error(e);
|
||||||
|
alert(e.message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- 单条执行 ---
|
// --- 单条执行 ---
|
||||||
async function runSingleBatchTask(queueId, taskId) {
|
async function runSingleBatchTask(queueId, taskId) {
|
||||||
if (!queueId || !taskId) return;
|
if (!queueId || !taskId) return;
|
||||||
@@ -2441,6 +2516,8 @@ window.startInlineEditRole = startInlineEditRole;
|
|||||||
window.saveInlineRole = saveInlineRole;
|
window.saveInlineRole = saveInlineRole;
|
||||||
window.startInlineEditAgentMode = startInlineEditAgentMode;
|
window.startInlineEditAgentMode = startInlineEditAgentMode;
|
||||||
window.saveInlineAgentMode = saveInlineAgentMode;
|
window.saveInlineAgentMode = saveInlineAgentMode;
|
||||||
|
window.startInlineEditConcurrency = startInlineEditConcurrency;
|
||||||
|
window.saveInlineConcurrency = saveInlineConcurrency;
|
||||||
window.runSingleBatchTask = runSingleBatchTask;
|
window.runSingleBatchTask = runSingleBatchTask;
|
||||||
window.startInlineEditSchedule = startInlineEditSchedule;
|
window.startInlineEditSchedule = startInlineEditSchedule;
|
||||||
window.toggleInlineScheduleCron = toggleInlineScheduleCron;
|
window.toggleInlineScheduleCron = toggleInlineScheduleCron;
|
||||||
|
|||||||
@@ -4010,6 +4010,11 @@
|
|||||||
</select>
|
</select>
|
||||||
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.agentModeHint">与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。</div>
|
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.agentModeHint">与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="batch-queue-concurrency" data-i18n="batchImportModal.concurrency">并发数</label>
|
||||||
|
<input type="number" id="batch-queue-concurrency" min="1" max="8" value="1" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;" />
|
||||||
|
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.concurrencyHint">同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。</div>
|
||||||
|
</div>
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label for="batch-queue-schedule-mode" data-i18n="batchImportModal.scheduleMode">调度方式</label>
|
<label for="batch-queue-schedule-mode" data-i18n="batchImportModal.scheduleMode">调度方式</label>
|
||||||
<select id="batch-queue-schedule-mode" onchange="handleBatchScheduleModeChange()" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;">
|
<select id="batch-queue-schedule-mode" onchange="handleBatchScheduleModeChange()" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;">
|
||||||
|
|||||||
Reference in New Issue
Block a user