package handler

import (
	"context"
	"encoding/json"
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"

	"cyberstrike-ai/internal/mcp"
	"cyberstrike-ai/internal/mcp/builtin"
	"go.uber.org/zap"
)

// Pagination bounds for the batch_task_list tool.
const (
	// batchMCPDefaultPageSize is used when page_size is missing, non-numeric or < 1.
	batchMCPDefaultPageSize = 20
	// batchMCPMaxPageSize caps page_size.
	batchMCPMaxPageSize = 100
	// batchMCPMaxPage caps page so that (page-1)*pageSize always fits in an int,
	// even on platforms where int is 32 bits. Pages past the end of the data
	// simply come back empty.
	batchMCPMaxPage = math.MaxInt32 / batchMCPMaxPageSize
)

// RegisterBatchTaskMCPTools registers the batch-task-queue MCP tools on mcpServer:
// list, get, create, start, pause, delete, the Cron schedule toggle, and
// add / update / remove of individual sub-tasks. h must be an AgentHandler whose
// DB-backed batch task manager has already been initialized. If any argument is
// nil, nothing is registered.
//
// Every handler reports failures as an IsError tool result and returns a nil Go
// error, so the message is relayed to the MCP caller rather than treated as a
// transport failure.
func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) {
	if mcpServer == nil || h == nil || logger == nil {
		return
	}

	// registered counts the tools actually registered, so the summary log line at
	// the end cannot drift from reality when tools are added or removed.
	registered := 0
	reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) {
		mcpServer.RegisterTool(tool, fn)
		registered++
	}

	// --- list ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskList,
		Description:      "列出批量任务队列,支持按状态筛选与关键字搜索。用于查看队列 id、状态、进度及 Cron 配置等。",
		ShortDescription: "列出批量任务队列",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"status": map[string]interface{}{
					"type":        "string",
					"description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled",
					"enum":        []string{"all", "pending", "running", "paused", "completed", "cancelled"},
				},
				"keyword": map[string]interface{}{
					"type":        "string",
					"description": "按队列 ID 或标题模糊搜索",
				},
				"page": map[string]interface{}{
					"type":        "integer",
					"description": "页码,从 1 开始,默认 1",
				},
				"page_size": map[string]interface{}{
					"type":        "integer",
					"description": "每页条数,默认 20,最大 100",
				},
			},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		status := mcpArgString(args, "status")
		if status == "" {
			status = "all"
		}
		keyword := mcpArgString(args, "keyword")
		// Both values come from the caller. Clamping them means huge, negative or
		// NaN inputs can neither overflow the offset below nor go through an
		// implementation-defined float-to-int conversion.
		page := batchMCPArgPositiveInt(args, "page", 1, batchMCPMaxPage)
		pageSize := batchMCPArgPositiveInt(args, "page_size", batchMCPDefaultPageSize, batchMCPMaxPageSize)
		offset := (page - 1) * pageSize

		queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword)
		if err != nil {
			return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil
		}
		totalPages := (total + pageSize - 1) / pageSize
		if totalPages == 0 {
			totalPages = 1
		}
		payload := map[string]interface{}{
			"queues":      queues,
			"total":       total,
			"page":        page,
			"page_size":   pageSize,
			"total_pages": totalPages,
		}
		logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total))
		return batchMCPJSONResult(payload)
	})

	// --- get ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskGet,
		Description:      "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。",
		ShortDescription: "获取批量任务队列详情",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
			},
			"required": []string{"queue_id"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		if qid == "" {
			return batchMCPTextResult("queue_id 不能为空", true), nil
		}
		queue, ok := h.batchTaskManager.GetBatchQueue(qid)
		if !ok {
			return batchMCPTextResult("队列不存在: "+qid, true), nil
		}
		return batchMCPJSONResult(queue)
	})

	// --- create ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskCreate,
		Description:      `创建新的批量任务队列。任务列表使用 tasks(字符串数组)或 tasks_text(多行,每行一条)。 agent_mode: single(默认)或 multi(需系统启用多代理)。schedule_mode: manual(默认)或 cron;为 cron 时必须提供 cron_expr(如 "0 */6 * * *")。 重要:创建成功后队列处于 pending,不会自动开始跑子任务。若要立即执行或手工开跑,必须再调用工具 batch_task_start(传入返回的 queue_id)。Cron 队列若需按表达式自动触发下一轮,还需保持调度开关开启(可用 batch_task_schedule_enabled)。`,
		ShortDescription: "创建批量任务队列(创建后需 batch_task_start 才会执行)",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"title": map[string]interface{}{
					"type":        "string",
					"description": "可选标题",
				},
				"role": map[string]interface{}{
					"type":        "string",
					"description": "角色名称,空表示默认",
				},
				"tasks": map[string]interface{}{
					"type":        "array",
					"description": "任务指令列表,每项一条",
					"items":       map[string]interface{}{"type": "string"},
				},
				"tasks_text": map[string]interface{}{
					"type":        "string",
					"description": "多行文本,每行一条任务(与 tasks 二选一)",
				},
				"agent_mode": map[string]interface{}{
					"type":        "string",
					"description": "single 或 multi",
					"enum":        []string{"single", "multi"},
				},
				"schedule_mode": map[string]interface{}{
					"type":        "string",
					"description": "manual 或 cron",
					"enum":        []string{"manual", "cron"},
				},
				"cron_expr": map[string]interface{}{
					"type":        "string",
					"description": "schedule_mode 为 cron 时必填",
				},
			},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		tasks, errMsg := batchMCPTasksFromArgs(args)
		if errMsg != "" {
			return batchMCPTextResult(errMsg, true), nil
		}
		title := mcpArgString(args, "title")
		role := mcpArgString(args, "role")
		agentMode := normalizeBatchQueueAgentMode(mcpArgString(args, "agent_mode"))
		scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
		cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))

		// For Cron queues, validate the expression up front and precompute the
		// first trigger time. Manual queues have no next run.
		var nextRunAt *time.Time
		if scheduleMode == "cron" {
			if cronExpr == "" {
				return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
			}
			sch, err := h.batchCronParser.Parse(cronExpr)
			if err != nil {
				return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
			}
			n := sch.Next(time.Now())
			nextRunAt = &n
		}

		queue := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
		logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks)))
		return batchMCPJSONResult(map[string]interface{}{
			"queue_id": queue.ID,
			"queue":    queue,
			"reminder": "队列已创建,当前为 pending。需要开始执行时请调用 MCP工具 batch_task_start(queue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。",
		})
	})

	// --- start ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskStart,
		Description:      `启动或继续执行批量任务队列(pending / paused)。 与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。`,
		ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
			},
			"required": []string{"queue_id"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		if qid == "" {
			return batchMCPTextResult("queue_id 不能为空", true), nil
		}
		// ok=false is reported as "queue not found" before err is inspected.
		// NOTE(review): this order mirrors the original; confirm that
		// startBatchQueueExecution never returns (false, err) for other failures.
		ok, err := h.startBatchQueueExecution(qid, false)
		if !ok {
			return batchMCPTextResult("队列不存在: "+qid, true), nil
		}
		if err != nil {
			return batchMCPTextResult("启动失败: "+err.Error(), true), nil
		}
		logger.Info("MCP batch_task_start", zap.String("queueId", qid))
		return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil
	})

	// --- pause ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskPause,
		Description:      "暂停正在运行的批量任务队列(当前子任务会被取消)。",
		ShortDescription: "暂停批量任务队列",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
			},
			"required": []string{"queue_id"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		if qid == "" {
			return batchMCPTextResult("queue_id 不能为空", true), nil
		}
		if !h.batchTaskManager.PauseQueue(qid) {
			return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil
		}
		logger.Info("MCP batch_task_pause", zap.String("queueId", qid))
		return batchMCPTextResult("队列已暂停。", false), nil
	})

	// --- delete queue ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskDelete,
		Description:      "删除批量任务队列及其子任务记录。",
		ShortDescription: "删除批量任务队列",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
			},
			"required": []string{"queue_id"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		if qid == "" {
			return batchMCPTextResult("queue_id 不能为空", true), nil
		}
		if !h.batchTaskManager.DeleteQueue(qid) {
			return batchMCPTextResult("删除失败:队列不存在", true), nil
		}
		logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
		return batchMCPTextResult("队列已删除。", false), nil
	})

	// --- schedule enabled ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskScheduleEnabled,
		Description:      `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。 仅对 schedule_mode 为 cron 的队列有意义。`,
		ShortDescription: "开关批量任务 Cron 自动调度",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
				"schedule_enabled": map[string]interface{}{
					"type":        "boolean",
					"description": "true 允许定时触发,false 仅手工执行",
				},
			},
			"required": []string{"queue_id", "schedule_enabled"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		if qid == "" {
			return batchMCPTextResult("queue_id 不能为空", true), nil
		}
		en, ok := mcpArgBool(args, "schedule_enabled")
		if !ok {
			return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil
		}
		// Check existence first so "not found" is distinguishable from a failed update.
		if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists {
			return batchMCPTextResult("队列不存在", true), nil
		}
		if !h.batchTaskManager.SetScheduleEnabled(qid, en) {
			return batchMCPTextResult("更新失败", true), nil
		}
		queue, _ := h.batchTaskManager.GetBatchQueue(qid)
		logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en))
		return batchMCPJSONResult(queue)
	})

	// --- add task ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskAdd,
		Description:      "向处于 pending 状态的队列追加一条子任务。",
		ShortDescription: "批量队列添加子任务",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
				"message": map[string]interface{}{
					"type":        "string",
					"description": "任务指令内容",
				},
			},
			"required": []string{"queue_id", "message"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		msg := strings.TrimSpace(mcpArgString(args, "message"))
		if qid == "" || msg == "" {
			return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil
		}
		task, err := h.batchTaskManager.AddTaskToQueue(qid, msg)
		if err != nil {
			return batchMCPTextResult(err.Error(), true), nil
		}
		queue, _ := h.batchTaskManager.GetBatchQueue(qid)
		logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID))
		return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue})
	})

	// --- update task ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskUpdate,
		Description:      "修改 pending 队列中仍为 pending 的子任务文案。",
		ShortDescription: "更新批量子任务内容",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
				"task_id": map[string]interface{}{
					"type":        "string",
					"description": "子任务 ID",
				},
				"message": map[string]interface{}{
					"type":        "string",
					"description": "新的任务指令",
				},
			},
			"required": []string{"queue_id", "task_id", "message"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		tid := mcpArgString(args, "task_id")
		msg := strings.TrimSpace(mcpArgString(args, "message"))
		if qid == "" || tid == "" || msg == "" {
			return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil
		}
		if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil {
			return batchMCPTextResult(err.Error(), true), nil
		}
		queue, _ := h.batchTaskManager.GetBatchQueue(qid)
		logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid))
		return batchMCPJSONResult(queue)
	})

	// --- remove task ---
	reg(mcp.Tool{
		Name:             builtin.ToolBatchTaskRemove,
		Description:      "从 pending 队列中删除仍为 pending 的子任务。",
		ShortDescription: "删除批量子任务",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"queue_id": map[string]interface{}{
					"type":        "string",
					"description": "队列 ID",
				},
				"task_id": map[string]interface{}{
					"type":        "string",
					"description": "子任务 ID",
				},
			},
			"required": []string{"queue_id", "task_id"},
		},
	}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		qid := mcpArgString(args, "queue_id")
		tid := mcpArgString(args, "task_id")
		if qid == "" || tid == "" {
			return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil
		}
		if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil {
			return batchMCPTextResult(err.Error(), true), nil
		}
		queue, _ := h.batchTaskManager.GetBatchQueue(qid)
		logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid))
		return batchMCPJSONResult(queue)
	})

	logger.Info("批量任务 MCP 工具已注册", zap.Int("count", registered))
}

// batchMCPTextResult wraps text in a single-item text tool result. isErr sets
// the result's IsError flag.
func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult {
	return &mcp.ToolResult{
		Content: []mcp.Content{{Type: "text", Text: text}},
		IsError: isErr,
	}
}

// batchMCPJSONResult returns v as indented JSON in a text tool result. If
// encoding fails, the failure is returned as an error result (IsError) and the
// Go error stays nil.
func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) {
	b, err := json.MarshalIndent(v, "", "  ")
	if err != nil {
		return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil
	}
	return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil
}

// batchMCPTasksFromArgs extracts the task list for batch_task_create.
//
// "tasks" is tried first: a JSON array ([]interface{}) or []string whose
// non-string items are skipped and whose strings are trimmed, with blanks
// dropped. If that yields nothing, "tasks_text" is split on newlines and treated
// the same way. The second return value is a non-empty user-facing message when
// neither source produced any task.
func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) {
	switch raw := args["tasks"].(type) {
	case []interface{}:
		strs := make([]string, 0, len(raw))
		for _, x := range raw {
			if s, ok := x.(string); ok {
				strs = append(strs, s)
			}
		}
		if out := batchMCPNonBlankTrimmed(strs); len(out) > 0 {
			return out, ""
		}
	case []string:
		// Accept typed slices too (e.g. arguments built in-process, not decoded from JSON).
		if out := batchMCPNonBlankTrimmed(raw); len(out) > 0 {
			return out, ""
		}
	}
	if txt := mcpArgString(args, "tasks_text"); txt != "" {
		if out := batchMCPNonBlankTrimmed(strings.Split(txt, "\n")); len(out) > 0 {
			return out, ""
		}
	}
	return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)"
}

// batchMCPNonBlankTrimmed returns items with surrounding whitespace trimmed and
// empty results removed, preserving order.
func batchMCPNonBlankTrimmed(items []string) []string {
	out := make([]string, 0, len(items))
	for _, s := range items {
		if tr := strings.TrimSpace(s); tr != "" {
			out = append(out, tr)
		}
	}
	return out
}

// batchMCPArgPositiveInt reads args[key] via mcpArgFloat and converts it to an
// int in [1, limit]. Missing, non-numeric, NaN and values below 1 return def.
// Values above limit (including +Inf) return limit. Fractional values are
// truncated. Clamping in float space keeps the float-to-int conversion within
// range, where Go defines its result.
func batchMCPArgPositiveInt(args map[string]interface{}, key string, def, limit int) int {
	f := mcpArgFloat(args, key)
	switch {
	case math.IsNaN(f) || f < 1:
		return def
	case f > float64(limit):
		return limit
	default:
		return int(f)
	}
}

// mcpArgString returns args[key] as a trimmed string. Strings are trimmed,
// float64 and json.Number values are formatted without exponent padding, and
// any other type goes through fmt.Sprint. A missing or nil value yields "".
func mcpArgString(args map[string]interface{}, key string) string {
	v, ok := args[key]
	if !ok || v == nil {
		return ""
	}
	switch t := v.(type) {
	case string:
		return strings.TrimSpace(t)
	case float64:
		return strconv.FormatFloat(t, 'f', -1, 64)
	case json.Number:
		return strings.TrimSpace(t.String())
	default:
		return strings.TrimSpace(fmt.Sprint(t))
	}
}

// mcpArgFloat returns args[key] as a float64. float64, int, int64, json.Number
// and numeric strings are converted. Anything else, a missing/nil value, or an
// unparsable string yields 0.
func mcpArgFloat(args map[string]interface{}, key string) float64 {
	v, ok := args[key]
	if !ok || v == nil {
		return 0
	}
	switch t := v.(type) {
	case float64:
		return t
	case int:
		return float64(t)
	case int64:
		return float64(t)
	case json.Number:
		f, _ := t.Float64() // unparsable numbers deliberately fall back to 0
		return f
	case string:
		f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64) // same fallback as above
		return f
	default:
		return 0
	}
}

// mcpArgBool returns args[key] as a bool. ok reports whether the value could be
// interpreted as one. Accepted forms are:
//   - bool
//   - the strings "true"/"1"/"yes" and "false"/"0"/"no" (case-insensitive, trimmed)
//   - numbers (float64, int, int64, json.Number), where non-zero means true
func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) {
	v, exists := args[key]
	if !exists || v == nil {
		return false, false
	}
	switch t := v.(type) {
	case bool:
		return t, true
	case string:
		switch strings.ToLower(strings.TrimSpace(t)) {
		case "true", "1", "yes":
			return true, true
		case "false", "0", "no":
			return false, true
		}
	case float64:
		return t != 0, true
	case int:
		return t != 0, true
	case int64:
		return t != 0, true
	case json.Number:
		if f, err := t.Float64(); err == nil {
			return f != 0, true
		}
	}
	return false, false
}