diff --git a/internal/app/app.go b/internal/app/app.go index f1671ffa..1a4c6221 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -404,6 +404,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { } configHandler.SetSkillsToolRegistrar(skillsRegistrar) + handler.RegisterBatchTaskMCPTools(mcpServer, agentHandler, log.Logger) + batchTaskToolRegistrar := func() error { + handler.RegisterBatchTaskMCPTools(mcpServer, agentHandler, log.Logger) + return nil + } + configHandler.SetBatchTaskToolRegistrar(batchTaskToolRegistrar) + // 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置) configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) { knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger) @@ -652,6 +659,7 @@ func setupRoutes( protected.GET("/batch-tasks/:queueId", agentHandler.GetBatchQueue) protected.POST("/batch-tasks/:queueId/start", agentHandler.StartBatchQueue) protected.POST("/batch-tasks/:queueId/pause", agentHandler.PauseBatchQueue) + protected.PUT("/batch-tasks/:queueId/schedule-enabled", agentHandler.SetBatchQueueScheduleEnabled) protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue) protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask) protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask) @@ -1333,8 +1341,8 @@ func registerWebshellManagementTools(mcpServer *mcp.Server, db *database.DB, web // manage_webshell_add - 添加新的 webshell 连接 addTool := mcp.Tool{ - Name: builtin.ToolManageWebshellAdd, - Description: "添加新的 WebShell 连接到管理系统。支持 PHP、ASP、ASPX、JSP 等类型的一句话木马。", + Name: builtin.ToolManageWebshellAdd, + Description: "添加新的 WebShell 连接到管理系统。支持 PHP、ASP、ASPX、JSP 等类型的一句话木马。", ShortDescription: "添加 WebShell 连接", InputSchema: map[string]interface{}{ "type": "object", @@ -1425,8 +1433,8 @@ func registerWebshellManagementTools(mcpServer *mcp.Server, db *database.DB, web // manage_webshell_update - 更新 webshell 连接 updateTool := mcp.Tool{ - Name: builtin.ToolManageWebshellUpdate, - Description: "更新已存在的 WebShell 连接信息。", + Name: builtin.ToolManageWebshellUpdate, + Description: "更新已存在的 WebShell 连接信息。", ShortDescription: "更新 WebShell 连接", InputSchema: map[string]interface{}{ "type": "object", @@ -1522,8 +1530,8 @@ func registerWebshellManagementTools(mcpServer *mcp.Server, db *database.DB, web // manage_webshell_delete - 删除 webshell 连接 deleteTool := mcp.Tool{ - Name: builtin.ToolManageWebshellDelete, - Description: "删除指定的 WebShell 连接。", + Name: builtin.ToolManageWebshellDelete, + Description: "删除指定的 WebShell 连接。", ShortDescription: "删除 WebShell 连接", InputSchema: map[string]interface{}{ "type": "object", @@ -1564,8 +1572,8 @@ func registerWebshellManagementTools(mcpServer *mcp.Server, db *database.DB, web // manage_webshell_test - 测试 webshell 连接 testTool := mcp.Tool{ - Name: builtin.ToolManageWebshellTest, - Description: "测试指定的 WebShell 连接是否可用,会尝试执行一个简单的命令(如 whoami 或 dir)。", + Name: builtin.ToolManageWebshellTest, + Description: "测试指定的 WebShell 连接是否可用,会尝试执行一个简单的命令(如 whoami 或 dir)。", ShortDescription: "测试 WebShell 连接", InputSchema: map[string]interface{}{ "type": "object", diff --git a/internal/database/batch_task.go b/internal/database/batch_task.go index 32fd8134..98c0ed29 100644 --- a/internal/database/batch_task.go +++ b/internal/database/batch_task.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "strings" "time" "go.uber.org/zap" @@ -10,14 +11,22 @@ import ( // BatchTaskQueueRow 批量任务队列数据库行 type BatchTaskQueueRow struct { - ID string - Title sql.NullString - Role sql.NullString - Status string - CreatedAt time.Time - StartedAt sql.NullTime - CompletedAt sql.NullTime - CurrentIndex int + ID string + Title sql.NullString + Role sql.NullString + AgentMode sql.NullString + ScheduleMode sql.NullString + CronExpr sql.NullString + NextRunAt sql.NullTime + ScheduleEnabled sql.NullInt64 + LastScheduleTriggerAt sql.NullTime + LastScheduleError sql.NullString + LastRunError sql.NullString + Status string + CreatedAt time.Time + StartedAt sql.NullTime + CompletedAt sql.NullTime + CurrentIndex int } // BatchTaskRow 批量任务数据库行 @@ -34,7 +43,16 @@ type BatchTaskRow struct { } // CreateBatchQueue 创建批量任务队列 -func (db *DB) CreateBatchQueue(queueID string, title string, role string, tasks []map[string]interface{}) error { +func (db *DB) CreateBatchQueue( + queueID string, + title string, + role string, + agentMode string, + scheduleMode string, + cronExpr string, + nextRunAt *time.Time, + tasks []map[string]interface{}, +) error { tx, err := db.Begin() if err != nil { return fmt.Errorf("开始事务失败: %w", err) @@ -42,9 +60,14 @@ func (db *DB) CreateBatchQueue(queueID string, title string, role string, tasks defer tx.Rollback() now := time.Now() + var nextRunAtValue interface{} + if nextRunAt != nil { + nextRunAtValue = *nextRunAt + } + _, err = tx.Exec( - "INSERT INTO batch_task_queues (id, title, role, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?)", - queueID, title, role, "pending", now, 0, + "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0, ) if err != nil { return fmt.Errorf("创建批量任务队列失败: %w", err) @@ -60,7 +83,7 @@ func (db *DB) CreateBatchQueue(queueID string, title string, role string, tasks if !ok { continue } - + _, err = tx.Exec( "INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)", taskID, queueID, message, "pending", @@ -78,9 +101,9 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { var row BatchTaskQueueRow var createdAt string err := db.QueryRow( - "SELECT id, title, role, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", + "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", queueID, - ).Scan(&row.ID, &row.Title, &row.Role, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) + ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) if err == sql.ErrNoRows { return nil, nil } @@ -104,7 +127,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { // GetAllBatchQueues 获取所有批量任务队列 func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { rows, err := db.Query( - "SELECT id, title, role, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", + "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", ) if err != nil { return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) @@ -115,7 +138,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { for rows.Next() { var row BatchTaskQueueRow var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) } parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) @@ -135,7 +158,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { // ListBatchQueues 列出批量任务队列(支持筛选和分页) func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { - query := "SELECT id, title, role, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" + query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" args := []interface{}{} // 状态筛选 @@ -163,7 +186,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat for rows.Next() { var row BatchTaskQueueRow var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) } parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) @@ -237,7 +260,7 @@ func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) { func (db *DB) UpdateBatchQueueStatus(queueID, status string) error { var err error now := time.Now() - + if status == "running" { _, err = db.Exec( "UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?", @@ -254,7 +277,7 @@ func (db *DB) UpdateBatchQueueStatus(queueID, status string) error { status, queueID, ) } - + if err != nil { return fmt.Errorf("更新批量任务队列状态失败: %w", err) } @@ -265,41 +288,41 @@ func (db *DB) UpdateBatchQueueStatus(queueID, status string) error { func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error { var err error now := time.Now() - + // 构建更新语句 var updates []string var args []interface{} - + updates = append(updates, "status = ?") args = append(args, status) - + if conversationID != "" { updates = append(updates, "conversation_id = ?") args = append(args, conversationID) } - + if result != "" { updates = append(updates, "result = ?") args = append(args, result) } - + if errorMsg != "" { updates = append(updates, "error = ?") args = append(args, errorMsg) } - + if status == "running" { updates = append(updates, "started_at = COALESCE(started_at, ?)") args = append(args, now) } - + if status == "completed" || status == "failed" || status == "cancelled" { updates = append(updates, "completed_at = COALESCE(completed_at, ?)") args = append(args, now) } - + args = append(args, queueID, taskID) - + // 构建SQL语句 sql := "UPDATE batch_tasks SET " for i, update := range updates { @@ -309,7 +332,7 @@ func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversation sql += update } sql += " WHERE queue_id = ? AND id = ?" - + _, err = db.Exec(sql, args...) if err != nil { return fmt.Errorf("更新批量任务状态失败: %w", err) @@ -329,6 +352,107 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err return nil } +// UpdateBatchQueueSchedule 更新批量任务队列调度相关信息 +func (db *DB) UpdateBatchQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) error { + var nextRunAtValue interface{} + if nextRunAt != nil { + nextRunAtValue = *nextRunAt + } + _, err := db.Exec( + "UPDATE batch_task_queues SET schedule_mode = ?, cron_expr = ?, next_run_at = ? WHERE id = ?", + scheduleMode, cronExpr, nextRunAtValue, queueID, + ) + if err != nil { + return fmt.Errorf("更新批量任务调度配置失败: %w", err) + } + return nil +} + +// UpdateBatchQueueScheduleEnabled 是否允许 Cron 自动触发(手工「开始执行」不受影响) +func (db *DB) UpdateBatchQueueScheduleEnabled(queueID string, enabled bool) error { + v := 0 + if enabled { + v = 1 + } + _, err := db.Exec( + "UPDATE batch_task_queues SET schedule_enabled = ? WHERE id = ?", + v, queueID, + ) + if err != nil { + return fmt.Errorf("更新批量任务调度开关失败: %w", err) + } + return nil +} + +// RecordBatchQueueScheduledTriggerStart 记录一次由调度触发的开始时间并清空调度层错误 +func (db *DB) RecordBatchQueueScheduledTriggerStart(queueID string, at time.Time) error { + _, err := db.Exec( + "UPDATE batch_task_queues SET last_schedule_trigger_at = ?, last_schedule_error = NULL WHERE id = ?", + at, queueID, + ) + if err != nil { + return fmt.Errorf("记录调度触发时间失败: %w", err) + } + return nil +} + +// SetBatchQueueLastScheduleError 调度启动失败等原因(如状态不允许、重置失败) +func (db *DB) SetBatchQueueLastScheduleError(queueID, msg string) error { + _, err := db.Exec( + "UPDATE batch_task_queues SET last_schedule_error = ? WHERE id = ?", + msg, queueID, + ) + if err != nil { + return fmt.Errorf("写入调度错误信息失败: %w", err) + } + return nil +} + +// SetBatchQueueLastRunError 最近一轮执行中出现的子任务失败摘要(空串表示清空) +func (db *DB) SetBatchQueueLastRunError(queueID, msg string) error { + var v interface{} + if strings.TrimSpace(msg) == "" { + v = nil + } else { + v = msg + } + _, err := db.Exec( + "UPDATE batch_task_queues SET last_run_error = ? WHERE id = ?", + v, queueID, + ) + if err != nil { + return fmt.Errorf("写入最近运行错误失败: %w", err) + } + return nil +} + +// ResetBatchQueueForRerun 重置队列和任务状态用于下一轮调度执行 +func (db *DB) ResetBatchQueueForRerun(queueID string) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("开始事务失败: %w", err) + } + defer tx.Rollback() + + _, err = tx.Exec( + "UPDATE batch_task_queues SET status = ?, current_index = 0, started_at = NULL, completed_at = NULL WHERE id = ?", + "pending", queueID, + ) + if err != nil { + return fmt.Errorf("重置批量任务队列状态失败: %w", err) + } + + _, err = tx.Exec( + "UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ?", + "pending", queueID, + ) + if err != nil { + return fmt.Errorf("重置批量任务状态失败: %w", err) + } + + return tx.Commit() +} + // UpdateBatchTaskMessage 更新批量任务消息 func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error { _, err := db.Exec( @@ -387,4 +511,3 @@ func (db *DB) DeleteBatchQueue(queueID string) error { return tx.Commit() } - diff --git a/internal/database/database.go b/internal/database/database.go index 14a3809a..39593ec4 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -205,6 +205,15 @@ func (db *DB) initTables() error { CREATE TABLE IF NOT EXISTS batch_task_queues ( id TEXT PRIMARY KEY, title TEXT, + role TEXT, + agent_mode TEXT NOT NULL DEFAULT 'single', + schedule_mode TEXT NOT NULL DEFAULT 'manual', + cron_expr TEXT, + next_run_at DATETIME, + schedule_enabled INTEGER NOT NULL DEFAULT 1, + last_schedule_trigger_at DATETIME, + last_schedule_error TEXT, + last_run_error TEXT, status TEXT NOT NULL, created_at DATETIME NOT NULL, started_at DATETIME, @@ -495,7 +504,7 @@ func (db *DB) migrateConversationGroupMappingsTable() error { return nil } -// migrateBatchTaskQueuesTable 迁移batch_task_queues表,添加title和role字段 +// migrateBatchTaskQueuesTable 迁移batch_task_queues表,补充新字段 func (db *DB) migrateBatchTaskQueuesTable() error { // 检查title字段是否存在 var count int @@ -535,6 +544,131 @@ func (db *DB) migrateBatchTaskQueuesTable() error { } } + // 检查agent_mode字段是否存在 + var agentModeCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr)) + } + } + } else if agentModeCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); err != nil { + db.logger.Warn("添加agent_mode字段失败", zap.Error(err)) + } + } + + // 检查schedule_mode字段是否存在 + var scheduleModeCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_mode'").Scan(&scheduleModeCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加schedule_mode字段失败", zap.Error(addErr)) + } + } + } else if scheduleModeCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); err != nil { + db.logger.Warn("添加schedule_mode字段失败", zap.Error(err)) + } + } + + // 检查cron_expr字段是否存在 + var cronExprCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='cron_expr'").Scan(&cronExprCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加cron_expr字段失败", zap.Error(addErr)) + } + } + } else if cronExprCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); err != nil { + db.logger.Warn("添加cron_expr字段失败", zap.Error(err)) + } + } + + // 检查next_run_at字段是否存在 + var nextRunAtCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='next_run_at'").Scan(&nextRunAtCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加next_run_at字段失败", zap.Error(addErr)) + } + } + } else if nextRunAtCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); err != nil { + db.logger.Warn("添加next_run_at字段失败", zap.Error(err)) + } + } + + // schedule_enabled:0=暂停 Cron 自动调度,1=允许(手工执行不受影响) + var scheduleEnCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_enabled'").Scan(&scheduleEnCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加schedule_enabled字段失败", zap.Error(addErr)) + } + } + } else if scheduleEnCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); err != nil { + db.logger.Warn("添加schedule_enabled字段失败", zap.Error(err)) + } + } + + var lastTrigCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_trigger_at'").Scan(&lastTrigCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(addErr)) + } + } + } else if lastTrigCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); err != nil { + db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(err)) + } + } + + var lastSchedErrCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_error'").Scan(&lastSchedErrCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_schedule_error字段失败", zap.Error(addErr)) + } + } + } else if lastSchedErrCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); err != nil { + db.logger.Warn("添加last_schedule_error字段失败", zap.Error(err)) + } + } + + var lastRunErrCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_run_error'").Scan(&lastRunErrCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_run_error字段失败", zap.Error(addErr)) + } + } + } else if lastRunErrCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); err != nil { + db.logger.Warn("添加last_run_error字段失败", zap.Error(err)) + } + } + return nil } diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 256bb6cd..e63e0c86 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -24,6 +24,7 @@ import ( "cyberstrike-ai/internal/skills" "github.com/gin-gonic/gin" + "github.com/robfig/cron/v3" "go.uber.org/zap" ) @@ -81,6 +82,9 @@ type AgentHandler struct { } skillsManager *skills.Manager // Skills管理器 agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并) + batchCronParser cron.Parser + batchRunnerMu sync.Mutex + batchRunning map[string]struct{} } // NewAgentHandler 创建新的Agent处理器 @@ -93,14 +97,18 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo logger.Warn("从数据库加载批量任务队列失败", zap.Error(err)) } - return &AgentHandler{ + handler := &AgentHandler{ agent: agent, db: db, logger: logger, tasks: NewAgentTaskManager(), batchTaskManager: batchTaskManager, config: cfg, + batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), + batchRunning: make(map[string]struct{}), } + go handler.batchQueueSchedulerLoop() + return handler } // SetKnowledgeManager 设置知识库管理器(用于记录检索日志) @@ -1575,9 +1583,26 @@ func (h *AgentHandler) ListCompletedTasks(c *gin.Context) { // BatchTaskRequest 批量任务请求 type BatchTaskRequest struct { - Title string `json:"title"` // 任务标题(可选) - Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务 - Role string `json:"role,omitempty"` // 角色名称(可选,空字符串表示默认角色) + Title string `json:"title"` // 任务标题(可选) + Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务 + Role string `json:"role,omitempty"` // 角色名称(可选,空字符串表示默认角色) + AgentMode string `json:"agentMode,omitempty"` // single | multi + ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron + CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 +} + +func normalizeBatchQueueAgentMode(mode string) string { + if strings.TrimSpace(mode) == "multi" { + return "multi" + } + return "single" +} + +func normalizeBatchQueueScheduleMode(mode string) string { + if strings.TrimSpace(mode) == "cron" { + return "cron" + } + return "manual" } // CreateBatchQueue 创建批量任务队列 @@ -1606,7 +1631,25 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) { return } - queue := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, validTasks) + agentMode := normalizeBatchQueueAgentMode(req.AgentMode) + scheduleMode := normalizeBatchQueueScheduleMode(req.ScheduleMode) + cronExpr := strings.TrimSpace(req.CronExpr) + var nextRunAt *time.Time + if scheduleMode == "cron" { + if cronExpr == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "启用 Cron 调度时,调度表达式不能为空"}) + return + } + schedule, err := h.batchCronParser.Parse(cronExpr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 Cron 表达式: " + err.Error()}) + return + } + next := schedule.Next(time.Now()) + nextRunAt = &next + } + + queue := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks) c.JSON(http.StatusOK, gin.H{ "queueId": queue.ID, "queue": queue, @@ -1699,21 +1742,15 @@ func (h *AgentHandler) ListBatchQueues(c *gin.Context) { // StartBatchQueue 开始执行批量任务队列 func (h *AgentHandler) StartBatchQueue(c *gin.Context) { queueID := c.Param("queueId") - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists { + ok, err := h.startBatchQueueExecution(queueID, false) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if !ok { c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) return } - - if queue.Status != "pending" && queue.Status != "paused" { - c.JSON(http.StatusBadRequest, gin.H{"error": "队列状态不允许启动"}) - return - } - - // 在后台执行批量任务 - go h.executeBatchQueue(queueID) - - h.batchTaskManager.UpdateQueueStatus(queueID, "running") c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID}) } @@ -1728,6 +1765,28 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"}) } +// SetBatchQueueScheduleEnabled 开启/关闭 Cron 自动调度(手工执行不受影响) +func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) { + queueID := c.Param("queueId") + if _, exists := h.batchTaskManager.GetBatchQueue(queueID); !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + var req struct { + ScheduleEnabled bool `json:"scheduleEnabled"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if !h.batchTaskManager.SetScheduleEnabled(queueID, req.ScheduleEnabled) { + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + return + } + queue, _ := h.batchTaskManager.GetBatchQueue(queueID) + c.JSON(http.StatusOK, gin.H{"queue": queue}) +} + // DeleteBatchQueue 删除批量任务队列 func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) { queueID := c.Param("queueId") @@ -1824,8 +1883,125 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) } +func (h *AgentHandler) markBatchQueueRunning(queueID string) bool { + h.batchRunnerMu.Lock() + defer h.batchRunnerMu.Unlock() + if _, exists := h.batchRunning[queueID]; exists { + return false + } + h.batchRunning[queueID] = struct{}{} + return true +} + +func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) { + h.batchRunnerMu.Lock() + defer h.batchRunnerMu.Unlock() + delete(h.batchRunning, queueID) +} + +func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { + expr := strings.TrimSpace(cronExpr) + if expr == "" { + return nil, nil + } + schedule, err := h.batchCronParser.Parse(expr) + if err != nil { + return nil, err + } + next := schedule.Next(from) + return &next, nil +} + +func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + return false, nil + } + if !h.markBatchQueueRunning(queueID) { + return true, nil + } + + if scheduled { + if queue.ScheduleMode != "cron" { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("队列未启用 cron 调度") + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + return true, err + } + if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("当前队列状态不允许被调度执行") + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + return true, err + } + if !h.batchTaskManager.ResetQueueForRerun(queueID) { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("重置队列失败") + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + return true, err + } + queue, _ = h.batchTaskManager.GetBatchQueue(queueID) + } else if queue.Status != "pending" && queue.Status != "paused" { + h.unmarkBatchQueueRunning(queueID) + return true, fmt.Errorf("队列状态不允许启动") + } + + if queue != nil && queue.AgentMode == "multi" && (h.config == nil || !h.config.MultiAgent.Enabled) { + h.unmarkBatchQueueRunning(queueID) + err := fmt.Errorf("当前队列配置为多代理,但系统未启用多代理") + if scheduled { + h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) + } + return true, err + } + + if scheduled { + h.batchTaskManager.RecordScheduledRunStart(queueID) + } + h.batchTaskManager.UpdateQueueStatus(queueID, "running") + if queue != nil && queue.ScheduleMode == "cron" { + nextRunAt, err := h.nextBatchQueueRunAt(queue.CronExpr, time.Now()) + if err == nil { + h.batchTaskManager.UpdateQueueSchedule(queueID, "cron", queue.CronExpr, nextRunAt) + } + } + + go h.executeBatchQueue(queueID) + return true, nil +} + +func (h *AgentHandler) batchQueueSchedulerLoop() { + ticker := time.NewTicker(20 * time.Second) + defer ticker.Stop() + for range ticker.C { + queues := h.batchTaskManager.GetAllQueues() + now := time.Now() + for _, queue := range queues { + if queue == nil || queue.ScheduleMode != "cron" || !queue.ScheduleEnabled || queue.Status == "cancelled" || queue.Status == "running" || queue.Status == "paused" { + continue + } + nextRunAt := queue.NextRunAt + if nextRunAt == nil { + next, err := h.nextBatchQueueRunAt(queue.CronExpr, now) + if err != nil { + h.logger.Warn("批量任务 cron 表达式无效,跳过调度", zap.String("queueId", queue.ID), zap.String("cronExpr", queue.CronExpr), zap.Error(err)) + continue + } + h.batchTaskManager.UpdateQueueSchedule(queue.ID, "cron", queue.CronExpr, next) + nextRunAt = next + } + if nextRunAt != nil && (nextRunAt.Before(now) || nextRunAt.Equal(now)) { + if _, err := h.startBatchQueueExecution(queue.ID, true); err != nil { + h.logger.Warn("自动调度批量任务失败", zap.String("queueId", queue.ID), zap.Error(err)) + } + } + } + } +} + // executeBatchQueue 执行批量任务队列 func (h *AgentHandler) executeBatchQueue(queueID string) { + defer h.unmarkBatchQueueRunning(queueID) h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID)) for { @@ -1838,7 +2014,17 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { // 获取下一个任务 task, hasNext := h.batchTaskManager.GetNextTask(queueID) if !hasNext { - // 所有任务完成 + // 所有任务完成:汇总子任务失败信息便于排障 + q, ok := h.batchTaskManager.GetBatchQueue(queueID) + lastRunErr := "" + if ok { + for _, t := range q.Tasks { + if t.Status == "failed" && t.Error != "" { + lastRunErr = t.Error + } + } + } + h.batchTaskManager.SetLastRunError(queueID, lastRunErr) h.batchTaskManager.UpdateQueueStatus(queueID, "completed") h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID)) break @@ -1918,7 +2104,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { h.batchTaskManager.SetTaskCancel(queueID, cancel) // 使用队列配置的角色工具列表(如果为空,表示使用所有工具) // 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills - useBatchMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent + useBatchMulti := false + if queue.AgentMode == "multi" { + useBatchMulti = h.config != nil && h.config.MultiAgent.Enabled + } else if queue.AgentMode == "" { + // 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关 + useBatchMulti = h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent + } var result *agent.AgentLoopResult var resultMA *multiagent.RunResult var runErr error diff --git a/internal/handler/batch_task_manager.go b/internal/handler/batch_task_manager.go index e0e6cbb7..8701476d 100644 --- a/internal/handler/batch_task_manager.go +++ b/internal/handler/batch_task_manager.go @@ -27,24 +27,32 @@ type BatchTask struct { // BatchTaskQueue 批量任务队列 type BatchTaskQueue struct { - ID string `json:"id"` - Title string `json:"title,omitempty"` - Role string `json:"role,omitempty"` // 角色名称(空字符串表示默认角色) - Tasks []*BatchTask `json:"tasks"` - Status string `json:"status"` // pending, running, paused, completed, cancelled - CreatedAt time.Time `json:"createdAt"` - StartedAt *time.Time `json:"startedAt,omitempty"` - CompletedAt *time.Time `json:"completedAt,omitempty"` - CurrentIndex int `json:"currentIndex"` - mu sync.RWMutex + ID string `json:"id"` + Title string `json:"title,omitempty"` + Role string `json:"role,omitempty"` // 角色名称(空字符串表示默认角色) + AgentMode string `json:"agentMode"` // single | multi + ScheduleMode string `json:"scheduleMode"` // manual | cron + CronExpr string `json:"cronExpr,omitempty"` + NextRunAt *time.Time `json:"nextRunAt,omitempty"` + ScheduleEnabled bool `json:"scheduleEnabled"` + LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"` + LastScheduleError string `json:"lastScheduleError,omitempty"` + LastRunError string `json:"lastRunError,omitempty"` + Tasks []*BatchTask `json:"tasks"` + Status string `json:"status"` // pending, running, paused, completed, cancelled + CreatedAt time.Time `json:"createdAt"` + StartedAt *time.Time `json:"startedAt,omitempty"` + CompletedAt *time.Time `json:"completedAt,omitempty"` + CurrentIndex int `json:"currentIndex"` + mu sync.RWMutex } // BatchTaskManager 批量任务管理器 type BatchTaskManager struct { - db *database.DB - queues map[string]*BatchTaskQueue - taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 - mu sync.RWMutex + db *database.DB + queues map[string]*BatchTaskQueue + taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 + mu sync.RWMutex } // NewBatchTaskManager 创建批量任务管理器 @@ -63,19 +71,32 @@ func (m *BatchTaskManager) SetDB(db *database.DB) { } // CreateBatchQueue 创建批量任务队列 -func (m *BatchTaskManager) CreateBatchQueue(title, role string, tasks []string) *BatchTaskQueue { +func (m *BatchTaskManager) CreateBatchQueue( + title, role, agentMode, scheduleMode, cronExpr string, + nextRunAt *time.Time, + tasks []string, +) *BatchTaskQueue { m.mu.Lock() defer m.mu.Unlock() queueID := time.Now().Format("20060102150405") + "-" + generateShortID() queue := &BatchTaskQueue{ - ID: queueID, - Title: title, - Role: role, - Tasks: make([]*BatchTask, 0, len(tasks)), - Status: "pending", - CreatedAt: time.Now(), - CurrentIndex: 0, + ID: queueID, + Title: title, + Role: role, + AgentMode: normalizeBatchQueueAgentMode(agentMode), + ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode), + CronExpr: strings.TrimSpace(cronExpr), + NextRunAt: nextRunAt, + ScheduleEnabled: true, + Tasks: make([]*BatchTask, 0, len(tasks)), + Status: "pending", + CreatedAt: time.Now(), + CurrentIndex: 0, + } + if queue.ScheduleMode != "cron" { + queue.CronExpr = "" + queue.NextRunAt = nil } // 准备数据库保存的任务数据 @@ -100,7 +121,16 @@ func (m *BatchTaskManager) CreateBatchQueue(title, role string, tasks []string) // 保存到数据库 if m.db != nil { - if err := m.db.CreateBatchQueue(queueID, title, role, dbTasks); err != nil { + if err := m.db.CreateBatchQueue( + queueID, + title, + role, + queue.AgentMode, + queue.ScheduleMode, + queue.CronExpr, + queue.NextRunAt, + dbTasks, + ); err != nil { // 如果数据库保存失败,记录错误但继续(使用内存缓存) // 这里可以添加日志记录 } @@ -151,6 +181,8 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue { queue := &BatchTaskQueue{ ID: queueRow.ID, + AgentMode: "single", + ScheduleMode: "manual", Status: queueRow.Status, CreatedAt: queueRow.CreatedAt, CurrentIndex: queueRow.CurrentIndex, @@ -163,6 +195,33 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue { if queueRow.Role.Valid { queue.Role = queueRow.Role.String } + if queueRow.AgentMode.Valid { + queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String) + } + if queueRow.ScheduleMode.Valid { + queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String) + } + if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" { + queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String) + } + if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" { + t := queueRow.NextRunAt.Time + queue.NextRunAt = &t + } + queue.ScheduleEnabled = true + if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 { + queue.ScheduleEnabled = false + } + if queueRow.LastScheduleTriggerAt.Valid { + t := queueRow.LastScheduleTriggerAt.Time + queue.LastScheduleTriggerAt = &t + } + if queueRow.LastScheduleError.Valid { + queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String) + } + if queueRow.LastRunError.Valid { + queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) + } if queueRow.StartedAt.Valid { queue.StartedAt = &queueRow.StartedAt.Time } @@ -347,6 +406,8 @@ func (m *BatchTaskManager) LoadFromDB() error { queue := &BatchTaskQueue{ ID: queueRow.ID, + AgentMode: "single", + ScheduleMode: "manual", Status: queueRow.Status, CreatedAt: queueRow.CreatedAt, CurrentIndex: queueRow.CurrentIndex, @@ -359,6 +420,33 @@ func (m *BatchTaskManager) LoadFromDB() error { if queueRow.Role.Valid { queue.Role = queueRow.Role.String } + if queueRow.AgentMode.Valid { + queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String) + } + if queueRow.ScheduleMode.Valid { + queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String) + } + if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" { + queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String) + } + if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" { + t := queueRow.NextRunAt.Time + queue.NextRunAt = &t + } + queue.ScheduleEnabled = true + if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 { + queue.ScheduleEnabled = false + } + if queueRow.LastScheduleTriggerAt.Valid { + t := queueRow.LastScheduleTriggerAt.Time + queue.LastScheduleTriggerAt = &t + } + if queueRow.LastScheduleError.Valid { + queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String) + } + if queueRow.LastRunError.Valid { + queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) + } if queueRow.StartedAt.Valid { queue.StartedAt = &queueRow.StartedAt.Time } @@ -469,6 +557,127 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { } } +// UpdateQueueSchedule 更新队列调度配置 +func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + + queue.ScheduleMode = normalizeBatchQueueScheduleMode(scheduleMode) + if queue.ScheduleMode == "cron" { + queue.CronExpr = strings.TrimSpace(cronExpr) + queue.NextRunAt = nextRunAt + } else { + queue.CronExpr = "" + queue.NextRunAt = nil + } + + if m.db != nil { + if err := m.db.UpdateBatchQueueSchedule(queueID, queue.ScheduleMode, queue.CronExpr, queue.NextRunAt); err != nil { + // 记录错误但继续(使用内存缓存) + } + } +} + +// SetScheduleEnabled 暂停/恢复 Cron 自动调度(不影响手工执行) +func (m *BatchTaskManager) SetScheduleEnabled(queueID string, enabled bool) bool { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return false + } + queue.ScheduleEnabled = enabled + if m.db != nil { + _ = m.db.UpdateBatchQueueScheduleEnabled(queueID, enabled) + } + return true +} + +// RecordScheduledRunStart Cron 触发成功、即将执行子任务时调用 +func (m *BatchTaskManager) RecordScheduledRunStart(queueID string) { + now := time.Now() + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + queue.LastScheduleTriggerAt = &now + queue.LastScheduleError = "" + if m.db != nil { + _ = m.db.RecordBatchQueueScheduledTriggerStart(queueID, now) + } +} + +// SetLastScheduleError 调度层失败(未成功开始执行) +func (m *BatchTaskManager) SetLastScheduleError(queueID, msg string) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + queue.LastScheduleError = strings.TrimSpace(msg) + if m.db != nil { + _ = m.db.SetBatchQueueLastScheduleError(queueID, queue.LastScheduleError) + } +} + +// SetLastRunError 最近一轮批量执行中的失败摘要 +func (m *BatchTaskManager) SetLastRunError(queueID, msg string) { + msg = strings.TrimSpace(msg) + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return + } + queue.LastRunError = msg + if m.db != nil { + _ = m.db.SetBatchQueueLastRunError(queueID, msg) + } +} + +// ResetQueueForRerun 重置队列与子任务状态,供 cron 下一轮执行 +func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists { + return false + } + queue.Status = "pending" + queue.CurrentIndex = 0 + queue.StartedAt = nil + queue.CompletedAt = nil + queue.NextRunAt = nil + for _, task := range queue.Tasks { + task.Status = "pending" + task.ConversationID = "" + task.StartedAt = nil + task.CompletedAt = nil + task.Error = "" + task.Result = "" + } + + if m.db != nil { + if err := m.db.ResetBatchQueueForRerun(queueID); err != nil { + return false + } + } + return true +} + // UpdateTaskMessage 更新任务消息(仅限待执行状态) func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error { m.mu.Lock() diff --git a/internal/handler/batch_task_mcp.go b/internal/handler/batch_task_mcp.go new file mode 100644 index 00000000..2463750a --- /dev/null +++ b/internal/handler/batch_task_mcp.go @@ -0,0 +1,533 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "go.uber.org/zap" +) + +// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler) +func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) { + if mcpServer == nil || h == nil || logger == nil { + return + } + + reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) { + mcpServer.RegisterTool(tool, fn) + } + + // --- list --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskList, + Description: "列出批量任务队列,支持按状态筛选与关键字搜索。用于查看队列 id、状态、进度及 Cron 配置等。", + ShortDescription: "列出批量任务队列", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "status": map[string]interface{}{ + "type": "string", + "description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled", + "enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"}, + }, + "keyword": map[string]interface{}{ + "type": "string", + "description": "按队列 ID 或标题模糊搜索", + }, + "page": map[string]interface{}{ + "type": "integer", + "description": "页码,从 1 开始,默认 1", + }, + "page_size": map[string]interface{}{ + "type": "integer", + "description": "每页条数,默认 20,最大 100", + }, + }, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + status := mcpArgString(args, "status") + if status == "" { + status = "all" + } + keyword := mcpArgString(args, "keyword") + page := int(mcpArgFloat(args, "page")) + if page <= 0 { + page = 1 + } + pageSize := int(mcpArgFloat(args, "page_size")) + if pageSize <= 0 { + pageSize = 20 + } + if pageSize > 100 { + pageSize = 100 + } + offset := (page - 1) * pageSize + queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword) + if err != nil { + return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil + } + totalPages := (total + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + payload := map[string]interface{}{ + "queues": queues, + "total": total, + "page": page, + "page_size": pageSize, + "total_pages": totalPages, + } + logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total)) + return batchMCPJSONResult(payload) + }) + + // --- get --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskGet, + Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。", + ShortDescription: "获取批量任务队列详情", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + queue, ok := h.batchTaskManager.GetBatchQueue(qid) + if !ok { + return batchMCPTextResult("队列不存在: "+qid, true), nil + } + return batchMCPJSONResult(queue) + }) + + // --- create --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskCreate, + Description: `创建新的批量任务队列。任务列表使用 tasks(字符串数组)或 tasks_text(多行,每行一条)。 +agent_mode: single(默认)或 multi(需系统启用多代理)。schedule_mode: manual(默认)或 cron;为 cron 时必须提供 cron_expr(如 "0 */6 * * *")。 +重要:创建成功后队列处于 pending,不会自动开始跑子任务。若要立即执行或手工开跑,必须再调用工具 batch_task_start(传入返回的 queue_id)。Cron 队列若需按表达式自动触发下一轮,还需保持调度开关开启(可用 batch_task_schedule_enabled)。`, + ShortDescription: "创建批量任务队列(创建后需 batch_task_start 才会执行)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "title": map[string]interface{}{ + "type": "string", + "description": "可选标题", + }, + "role": map[string]interface{}{ + "type": "string", + "description": "角色名称,空表示默认", + }, + "tasks": map[string]interface{}{ + "type": "array", + "description": "任务指令列表,每项一条", + "items": map[string]interface{}{"type": "string"}, + }, + "tasks_text": map[string]interface{}{ + "type": "string", + "description": "多行文本,每行一条任务(与 tasks 二选一)", + }, + "agent_mode": map[string]interface{}{ + "type": "string", + "description": "single 或 multi", + "enum": []string{"single", "multi"}, + }, + "schedule_mode": map[string]interface{}{ + "type": "string", + "description": "manual 或 cron", + "enum": []string{"manual", "cron"}, + }, + "cron_expr": map[string]interface{}{ + "type": "string", + "description": "schedule_mode 为 cron 时必填", + }, + }, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + tasks, errMsg := batchMCPTasksFromArgs(args) + if errMsg != "" { + return batchMCPTextResult(errMsg, true), nil + } + title := mcpArgString(args, "title") + role := mcpArgString(args, "role") + agentMode := normalizeBatchQueueAgentMode(mcpArgString(args, "agent_mode")) + scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode")) + cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr")) + var nextRunAt *time.Time + if scheduleMode == "cron" { + if cronExpr == "" { + return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil + } + sch, err := h.batchCronParser.Parse(cronExpr) + if err != nil { + return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil + } + n := sch.Next(time.Now()) + nextRunAt = &n + } + queue := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks) + logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks))) + return batchMCPJSONResult(map[string]interface{}{ + "queue_id": queue.ID, + "queue": queue, + "reminder": "队列已创建,当前为 pending。需要开始执行时请调用 MCP工具 batch_task_start(queue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。", + }) + }) + + // --- start --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskStart, + Description: `启动或继续执行批量任务队列(pending / paused)。 +与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。`, + ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + ok, err := h.startBatchQueueExecution(qid, false) + if !ok { + return batchMCPTextResult("队列不存在: "+qid, true), nil + } + if err != nil { + return batchMCPTextResult("启动失败: "+err.Error(), true), nil + } + logger.Info("MCP batch_task_start", zap.String("queueId", qid)) + return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil + }) + + // --- pause --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskPause, + Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。", + ShortDescription: "暂停批量任务队列", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + if !h.batchTaskManager.PauseQueue(qid) { + return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil + } + logger.Info("MCP batch_task_pause", zap.String("queueId", qid)) + return batchMCPTextResult("队列已暂停。", false), nil + }) + + // --- delete queue --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskDelete, + Description: "删除批量任务队列及其子任务记录。", + ShortDescription: "删除批量任务队列", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + }, + "required": []string{"queue_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + if !h.batchTaskManager.DeleteQueue(qid) { + return batchMCPTextResult("删除失败:队列不存在", true), nil + } + logger.Info("MCP batch_task_delete", zap.String("queueId", qid)) + return batchMCPTextResult("队列已删除。", false), nil + }) + + // --- schedule enabled --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskScheduleEnabled, + Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。 +仅对 schedule_mode 为 cron 的队列有意义。`, + ShortDescription: "开关批量任务 Cron 自动调度", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "schedule_enabled": map[string]interface{}{ + "type": "boolean", + "description": "true 允许定时触发,false 仅手工执行", + }, + }, + "required": []string{"queue_id", "schedule_enabled"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + if qid == "" { + return batchMCPTextResult("queue_id 不能为空", true), nil + } + en, ok := mcpArgBool(args, "schedule_enabled") + if !ok { + return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil + } + if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists { + return batchMCPTextResult("队列不存在", true), nil + } + if !h.batchTaskManager.SetScheduleEnabled(qid, en) { + return batchMCPTextResult("更新失败", true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en)) + return batchMCPJSONResult(queue) + }) + + // --- add task --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskAdd, + Description: "向处于 pending 状态的队列追加一条子任务。", + ShortDescription: "批量队列添加子任务", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "任务指令内容", + }, + }, + "required": []string{"queue_id", "message"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + msg := strings.TrimSpace(mcpArgString(args, "message")) + if qid == "" || msg == "" { + return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil + } + task, err := h.batchTaskManager.AddTaskToQueue(qid, msg) + if err != nil { + return batchMCPTextResult(err.Error(), true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID)) + return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue}) + }) + + // --- update task --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskUpdate, + Description: "修改 pending 队列中仍为 pending 的子任务文案。", + ShortDescription: "更新批量子任务内容", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "task_id": map[string]interface{}{ + "type": "string", + "description": "子任务 ID", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "新的任务指令", + }, + }, + "required": []string{"queue_id", "task_id", "message"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + tid := mcpArgString(args, "task_id") + msg := strings.TrimSpace(mcpArgString(args, "message")) + if qid == "" || tid == "" || msg == "" { + return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil + } + if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil { + return batchMCPTextResult(err.Error(), true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid)) + return batchMCPJSONResult(queue) + }) + + // --- remove task --- + reg(mcp.Tool{ + Name: builtin.ToolBatchTaskRemove, + Description: "从 pending 队列中删除仍为 pending 的子任务。", + ShortDescription: "删除批量子任务", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "queue_id": map[string]interface{}{ + "type": "string", + "description": "队列 ID", + }, + "task_id": map[string]interface{}{ + "type": "string", + "description": "子任务 ID", + }, + }, + "required": []string{"queue_id", "task_id"}, + }, + }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + qid := mcpArgString(args, "queue_id") + tid := mcpArgString(args, "task_id") + if qid == "" || tid == "" { + return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil + } + if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil { + return batchMCPTextResult(err.Error(), true), nil + } + queue, _ := h.batchTaskManager.GetBatchQueue(qid) + logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid)) + return batchMCPJSONResult(queue) + }) + + logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 10)) +} + +func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult { + return &mcp.ToolResult{ + Content: []mcp.Content{{Type: "text", Text: text}}, + IsError: isErr, + } +} + +func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) { + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil + } + return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil +} + +func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) { + if raw, ok := args["tasks"]; ok && raw != nil { + switch t := raw.(type) { + case []interface{}: + out := make([]string, 0, len(t)) + for _, x := range t { + if s, ok := x.(string); ok { + if tr := strings.TrimSpace(s); tr != "" { + out = append(out, tr) + } + } + } + if len(out) > 0 { + return out, "" + } + } + } + if txt := mcpArgString(args, "tasks_text"); txt != "" { + lines := strings.Split(txt, "\n") + out := make([]string, 0, len(lines)) + for _, line := range lines { + if tr := strings.TrimSpace(line); tr != "" { + out = append(out, tr) + } + } + if len(out) > 0 { + return out, "" + } + } + return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)" +} + +func mcpArgString(args map[string]interface{}, key string) string { + v, ok := args[key] + if !ok || v == nil { + return "" + } + switch t := v.(type) { + case string: + return strings.TrimSpace(t) + case float64: + return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64)) + case json.Number: + return strings.TrimSpace(t.String()) + default: + return strings.TrimSpace(fmt.Sprint(t)) + } +} + +func mcpArgFloat(args map[string]interface{}, key string) float64 { + v, ok := args[key] + if !ok || v == nil { + return 0 + } + switch t := v.(type) { + case float64: + return t + case int: + return float64(t) + case int64: + return float64(t) + case json.Number: + f, _ := t.Float64() + return f + case string: + f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64) + return f + default: + return 0 + } +} + +func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) { + v, exists := args[key] + if !exists { + return false, false + } + switch t := v.(type) { + case bool: + return t, true + case string: + s := strings.ToLower(strings.TrimSpace(t)) + if s == "true" || s == "1" || s == "yes" { + return true, true + } + if s == "false" || s == "0" || s == "no" { + return false, true + } + case float64: + return t != 0, true + } + return false, false +} diff --git a/internal/handler/config.go b/internal/handler/config.go index 414d6ee8..22ee983a 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -37,6 +37,9 @@ type WebshellToolRegistrar func() error // SkillsToolRegistrar Skills工具注册器接口 type SkillsToolRegistrar func() error +// BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册) +type BatchTaskToolRegistrar func() error + // RetrieverUpdater 检索器更新接口 type RetrieverUpdater interface { UpdateConfig(config *knowledge.RetrievalConfig) @@ -68,6 +71,7 @@ type ConfigHandler struct { vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选) webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选) skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选) + batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选) retrieverUpdater RetrieverUpdater // 检索器更新器(可选) knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) appUpdater AppUpdater // App更新器(可选) @@ -141,6 +145,13 @@ func (h *ConfigHandler) SetSkillsToolRegistrar(registrar SkillsToolRegistrar) { h.skillsToolRegistrar = registrar } +// SetBatchTaskToolRegistrar 设置批量任务 MCP 工具注册器 +func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.batchTaskToolRegistrar = registrar +} + // SetRetrieverUpdater 设置检索器更新器 func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) { h.mu.Lock() @@ -999,6 +1010,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { } } + // 重新注册批量任务 MCP 工具 + if h.batchTaskToolRegistrar != nil { + h.logger.Info("重新注册批量任务 MCP 工具") + if err := h.batchTaskToolRegistrar(); err != nil { + h.logger.Error("重新注册批量任务 MCP 工具失败", zap.Error(err)) + } else { + h.logger.Info("批量任务 MCP 工具已重新注册") + } + } + // 如果知识库启用,重新注册知识库工具 if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil { h.logger.Info("重新注册知识库工具") diff --git a/internal/mcp/builtin/constants.go b/internal/mcp/builtin/constants.go index 1ef554ae..3b4d37b9 100644 --- a/internal/mcp/builtin/constants.go +++ b/internal/mcp/builtin/constants.go @@ -11,14 +11,14 @@ const ( ToolSearchKnowledgeBase = "search_knowledge_base" // Skills工具 - ToolListSkills = "list_skills" - ToolReadSkill = "read_skill" + ToolListSkills = "list_skills" + ToolReadSkill = "read_skill" // WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用) - ToolWebshellExec = "webshell_exec" - ToolWebshellFileList = "webshell_file_list" - ToolWebshellFileRead = "webshell_file_read" - ToolWebshellFileWrite = "webshell_file_write" + ToolWebshellExec = "webshell_exec" + ToolWebshellFileList = "webshell_file_list" + ToolWebshellFileRead = "webshell_file_read" + ToolWebshellFileWrite = "webshell_file_write" // WebShell 连接管理工具(用于通过 MCP 管理 webshell 连接) ToolManageWebshellList = "manage_webshell_list" @@ -26,6 +26,18 @@ const ( ToolManageWebshellUpdate = "manage_webshell_update" ToolManageWebshellDelete = "manage_webshell_delete" ToolManageWebshellTest = "manage_webshell_test" + + // 批量任务队列(与 Web 端批量任务一致,供模型创建/启停/查询队列) + ToolBatchTaskList = "batch_task_list" + ToolBatchTaskGet = "batch_task_get" + ToolBatchTaskCreate = "batch_task_create" + ToolBatchTaskStart = "batch_task_start" + ToolBatchTaskPause = "batch_task_pause" + ToolBatchTaskDelete = "batch_task_delete" + ToolBatchTaskScheduleEnabled = "batch_task_schedule_enabled" + ToolBatchTaskAdd = "batch_task_add_task" + ToolBatchTaskUpdate = "batch_task_update_task" + ToolBatchTaskRemove = "batch_task_remove_task" ) // IsBuiltinTool 检查工具名称是否是内置工具 @@ -44,7 +56,17 @@ func IsBuiltinTool(toolName string) bool { ToolManageWebshellAdd, ToolManageWebshellUpdate, ToolManageWebshellDelete, - ToolManageWebshellTest: + ToolManageWebshellTest, + ToolBatchTaskList, + ToolBatchTaskGet, + ToolBatchTaskCreate, + ToolBatchTaskStart, + ToolBatchTaskPause, + ToolBatchTaskDelete, + ToolBatchTaskScheduleEnabled, + ToolBatchTaskAdd, + ToolBatchTaskUpdate, + ToolBatchTaskRemove: return true default: return false @@ -68,5 +90,15 @@ func GetAllBuiltinTools() []string { ToolManageWebshellUpdate, ToolManageWebshellDelete, ToolManageWebshellTest, + ToolBatchTaskList, + ToolBatchTaskGet, + ToolBatchTaskCreate, + ToolBatchTaskStart, + ToolBatchTaskPause, + ToolBatchTaskDelete, + ToolBatchTaskScheduleEnabled, + ToolBatchTaskAdd, + ToolBatchTaskUpdate, + ToolBatchTaskRemove, } }