mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 18:26:38 +02:00
563 lines
19 KiB
Go
563 lines
19 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/mcp"
|
||
"cyberstrike-ai/internal/mcp/builtin"
|
||
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler)
|
||
func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) {
|
||
if mcpServer == nil || h == nil || logger == nil {
|
||
return
|
||
}
|
||
|
||
reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) {
|
||
mcpServer.RegisterTool(tool, fn)
|
||
}
|
||
|
||
// --- list ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskList,
|
||
Description: "列出批量任务队列,支持按状态筛选与关键字搜索。用于查看队列 id、状态、进度及 Cron 配置等。",
|
||
ShortDescription: "列出批量任务队列",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"status": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled",
|
||
"enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"},
|
||
},
|
||
"keyword": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "按队列 ID 或标题模糊搜索",
|
||
},
|
||
"page": map[string]interface{}{
|
||
"type": "integer",
|
||
"description": "页码,从 1 开始,默认 1",
|
||
},
|
||
"page_size": map[string]interface{}{
|
||
"type": "integer",
|
||
"description": "每页条数,默认 20,最大 100",
|
||
},
|
||
},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
status := mcpArgString(args, "status")
|
||
if status == "" {
|
||
status = "all"
|
||
}
|
||
keyword := mcpArgString(args, "keyword")
|
||
page := int(mcpArgFloat(args, "page"))
|
||
if page <= 0 {
|
||
page = 1
|
||
}
|
||
pageSize := int(mcpArgFloat(args, "page_size"))
|
||
if pageSize <= 0 {
|
||
pageSize = 20
|
||
}
|
||
if pageSize > 100 {
|
||
pageSize = 100
|
||
}
|
||
offset := (page - 1) * pageSize
|
||
queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword)
|
||
if err != nil {
|
||
return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil
|
||
}
|
||
totalPages := (total + pageSize - 1) / pageSize
|
||
if totalPages == 0 {
|
||
totalPages = 1
|
||
}
|
||
payload := map[string]interface{}{
|
||
"queues": queues,
|
||
"total": total,
|
||
"page": page,
|
||
"page_size": pageSize,
|
||
"total_pages": totalPages,
|
||
}
|
||
logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total))
|
||
return batchMCPJSONResult(payload)
|
||
})
|
||
|
||
// --- get ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskGet,
|
||
Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。",
|
||
ShortDescription: "获取批量任务队列详情",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
},
|
||
"required": []string{"queue_id"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
if qid == "" {
|
||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||
}
|
||
queue, ok := h.batchTaskManager.GetBatchQueue(qid)
|
||
if !ok {
|
||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||
}
|
||
return batchMCPJSONResult(queue)
|
||
})
|
||
|
||
// --- create ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskCreate,
|
||
Description: `创建新的批量任务队列。任务列表使用 tasks(字符串数组)或 tasks_text(多行,每行一条)。
|
||
agent_mode: single(默认)或 multi(需系统启用多代理)。schedule_mode: manual(默认)或 cron;为 cron 时必须提供 cron_expr(如 "0 */6 * * *")。
|
||
默认创建后不会立即执行。可通过 execute_now=true 在创建后立即启动;也可后续调用 batch_task_start 手工启动。Cron 队列若需按表达式自动触发下一轮,还需保持调度开关开启(可用 batch_task_schedule_enabled)。`,
|
||
ShortDescription: "创建批量任务队列(可选立即执行)",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"title": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "可选标题",
|
||
},
|
||
"role": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "角色名称,空表示默认",
|
||
},
|
||
"tasks": map[string]interface{}{
|
||
"type": "array",
|
||
"description": "任务指令列表,每项一条",
|
||
"items": map[string]interface{}{"type": "string"},
|
||
},
|
||
"tasks_text": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "多行文本,每行一条任务(与 tasks 二选一)",
|
||
},
|
||
"agent_mode": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "single 或 multi",
|
||
"enum": []string{"single", "multi"},
|
||
},
|
||
"schedule_mode": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "manual 或 cron",
|
||
"enum": []string{"manual", "cron"},
|
||
},
|
||
"cron_expr": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "schedule_mode 为 cron 时必填",
|
||
},
|
||
"execute_now": map[string]interface{}{
|
||
"type": "boolean",
|
||
"description": "是否创建后立即执行,默认 false",
|
||
},
|
||
},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
tasks, errMsg := batchMCPTasksFromArgs(args)
|
||
if errMsg != "" {
|
||
return batchMCPTextResult(errMsg, true), nil
|
||
}
|
||
title := mcpArgString(args, "title")
|
||
role := mcpArgString(args, "role")
|
||
agentMode := normalizeBatchQueueAgentMode(mcpArgString(args, "agent_mode"))
|
||
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
|
||
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
|
||
var nextRunAt *time.Time
|
||
if scheduleMode == "cron" {
|
||
if cronExpr == "" {
|
||
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
|
||
}
|
||
sch, err := h.batchCronParser.Parse(cronExpr)
|
||
if err != nil {
|
||
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
|
||
}
|
||
n := sch.Next(time.Now())
|
||
nextRunAt = &n
|
||
}
|
||
executeNow, ok := mcpArgBool(args, "execute_now")
|
||
if !ok {
|
||
executeNow = false
|
||
}
|
||
queue := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
|
||
started := false
|
||
if executeNow {
|
||
ok, err := h.startBatchQueueExecution(queue.ID, false)
|
||
if !ok {
|
||
return batchMCPTextResult("队列不存在: "+queue.ID, true), nil
|
||
}
|
||
if err != nil {
|
||
return batchMCPTextResult("创建成功但启动失败: "+err.Error(), true), nil
|
||
}
|
||
started = true
|
||
if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists {
|
||
queue = refreshed
|
||
}
|
||
}
|
||
logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks)))
|
||
return batchMCPJSONResult(map[string]interface{}{
|
||
"queue_id": queue.ID,
|
||
"queue": queue,
|
||
"started": started,
|
||
"execute_now": executeNow,
|
||
"reminder": func() string {
|
||
if started {
|
||
return "队列已创建并立即启动。"
|
||
}
|
||
return "队列已创建,当前为 pending。需要开始执行时请调用 MCP 工具 batch_task_start(queue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。"
|
||
}(),
|
||
})
|
||
})
|
||
|
||
// --- start ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskStart,
|
||
Description: `启动或继续执行批量任务队列(pending / paused)。
|
||
与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。`,
|
||
ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
},
|
||
"required": []string{"queue_id"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
if qid == "" {
|
||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||
}
|
||
ok, err := h.startBatchQueueExecution(qid, false)
|
||
if !ok {
|
||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||
}
|
||
if err != nil {
|
||
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
|
||
}
|
||
logger.Info("MCP batch_task_start", zap.String("queueId", qid))
|
||
return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil
|
||
})
|
||
|
||
// --- pause ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskPause,
|
||
Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。",
|
||
ShortDescription: "暂停批量任务队列",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
},
|
||
"required": []string{"queue_id"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
if qid == "" {
|
||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||
}
|
||
if !h.batchTaskManager.PauseQueue(qid) {
|
||
return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil
|
||
}
|
||
logger.Info("MCP batch_task_pause", zap.String("queueId", qid))
|
||
return batchMCPTextResult("队列已暂停。", false), nil
|
||
})
|
||
|
||
// --- delete queue ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskDelete,
|
||
Description: "删除批量任务队列及其子任务记录。",
|
||
ShortDescription: "删除批量任务队列",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
},
|
||
"required": []string{"queue_id"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
if qid == "" {
|
||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||
}
|
||
if !h.batchTaskManager.DeleteQueue(qid) {
|
||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||
}
|
||
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
||
return batchMCPTextResult("队列已删除。", false), nil
|
||
})
|
||
|
||
// --- schedule enabled ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskScheduleEnabled,
|
||
Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。
|
||
仅对 schedule_mode 为 cron 的队列有意义。`,
|
||
ShortDescription: "开关批量任务 Cron 自动调度",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
"schedule_enabled": map[string]interface{}{
|
||
"type": "boolean",
|
||
"description": "true 允许定时触发,false 仅手工执行",
|
||
},
|
||
},
|
||
"required": []string{"queue_id", "schedule_enabled"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
if qid == "" {
|
||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||
}
|
||
en, ok := mcpArgBool(args, "schedule_enabled")
|
||
if !ok {
|
||
return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil
|
||
}
|
||
if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists {
|
||
return batchMCPTextResult("队列不存在", true), nil
|
||
}
|
||
if !h.batchTaskManager.SetScheduleEnabled(qid, en) {
|
||
return batchMCPTextResult("更新失败", true), nil
|
||
}
|
||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||
logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en))
|
||
return batchMCPJSONResult(queue)
|
||
})
|
||
|
||
// --- add task ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskAdd,
|
||
Description: "向处于 pending 状态的队列追加一条子任务。",
|
||
ShortDescription: "批量队列添加子任务",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
"message": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "任务指令内容",
|
||
},
|
||
},
|
||
"required": []string{"queue_id", "message"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
msg := strings.TrimSpace(mcpArgString(args, "message"))
|
||
if qid == "" || msg == "" {
|
||
return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil
|
||
}
|
||
task, err := h.batchTaskManager.AddTaskToQueue(qid, msg)
|
||
if err != nil {
|
||
return batchMCPTextResult(err.Error(), true), nil
|
||
}
|
||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||
logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID))
|
||
return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue})
|
||
})
|
||
|
||
// --- update task ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskUpdate,
|
||
Description: "修改 pending 队列中仍为 pending 的子任务文案。",
|
||
ShortDescription: "更新批量子任务内容",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
"task_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "子任务 ID",
|
||
},
|
||
"message": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "新的任务指令",
|
||
},
|
||
},
|
||
"required": []string{"queue_id", "task_id", "message"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
tid := mcpArgString(args, "task_id")
|
||
msg := strings.TrimSpace(mcpArgString(args, "message"))
|
||
if qid == "" || tid == "" || msg == "" {
|
||
return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil
|
||
}
|
||
if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil {
|
||
return batchMCPTextResult(err.Error(), true), nil
|
||
}
|
||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||
logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid))
|
||
return batchMCPJSONResult(queue)
|
||
})
|
||
|
||
// --- remove task ---
|
||
reg(mcp.Tool{
|
||
Name: builtin.ToolBatchTaskRemove,
|
||
Description: "从 pending 队列中删除仍为 pending 的子任务。",
|
||
ShortDescription: "删除批量子任务",
|
||
InputSchema: map[string]interface{}{
|
||
"type": "object",
|
||
"properties": map[string]interface{}{
|
||
"queue_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "队列 ID",
|
||
},
|
||
"task_id": map[string]interface{}{
|
||
"type": "string",
|
||
"description": "子任务 ID",
|
||
},
|
||
},
|
||
"required": []string{"queue_id", "task_id"},
|
||
},
|
||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||
qid := mcpArgString(args, "queue_id")
|
||
tid := mcpArgString(args, "task_id")
|
||
if qid == "" || tid == "" {
|
||
return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil
|
||
}
|
||
if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil {
|
||
return batchMCPTextResult(err.Error(), true), nil
|
||
}
|
||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||
logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid))
|
||
return batchMCPJSONResult(queue)
|
||
})
|
||
|
||
logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 10))
|
||
}
|
||
|
||
func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult {
|
||
return &mcp.ToolResult{
|
||
Content: []mcp.Content{{Type: "text", Text: text}},
|
||
IsError: isErr,
|
||
}
|
||
}
|
||
|
||
func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) {
|
||
b, err := json.MarshalIndent(v, "", " ")
|
||
if err != nil {
|
||
return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil
|
||
}
|
||
return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil
|
||
}
|
||
|
||
func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) {
|
||
if raw, ok := args["tasks"]; ok && raw != nil {
|
||
switch t := raw.(type) {
|
||
case []interface{}:
|
||
out := make([]string, 0, len(t))
|
||
for _, x := range t {
|
||
if s, ok := x.(string); ok {
|
||
if tr := strings.TrimSpace(s); tr != "" {
|
||
out = append(out, tr)
|
||
}
|
||
}
|
||
}
|
||
if len(out) > 0 {
|
||
return out, ""
|
||
}
|
||
}
|
||
}
|
||
if txt := mcpArgString(args, "tasks_text"); txt != "" {
|
||
lines := strings.Split(txt, "\n")
|
||
out := make([]string, 0, len(lines))
|
||
for _, line := range lines {
|
||
if tr := strings.TrimSpace(line); tr != "" {
|
||
out = append(out, tr)
|
||
}
|
||
}
|
||
if len(out) > 0 {
|
||
return out, ""
|
||
}
|
||
}
|
||
return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)"
|
||
}
|
||
|
||
func mcpArgString(args map[string]interface{}, key string) string {
|
||
v, ok := args[key]
|
||
if !ok || v == nil {
|
||
return ""
|
||
}
|
||
switch t := v.(type) {
|
||
case string:
|
||
return strings.TrimSpace(t)
|
||
case float64:
|
||
return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64))
|
||
case json.Number:
|
||
return strings.TrimSpace(t.String())
|
||
default:
|
||
return strings.TrimSpace(fmt.Sprint(t))
|
||
}
|
||
}
|
||
|
||
func mcpArgFloat(args map[string]interface{}, key string) float64 {
|
||
v, ok := args[key]
|
||
if !ok || v == nil {
|
||
return 0
|
||
}
|
||
switch t := v.(type) {
|
||
case float64:
|
||
return t
|
||
case int:
|
||
return float64(t)
|
||
case int64:
|
||
return float64(t)
|
||
case json.Number:
|
||
f, _ := t.Float64()
|
||
return f
|
||
case string:
|
||
f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64)
|
||
return f
|
||
default:
|
||
return 0
|
||
}
|
||
}
|
||
|
||
func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) {
|
||
v, exists := args[key]
|
||
if !exists {
|
||
return false, false
|
||
}
|
||
switch t := v.(type) {
|
||
case bool:
|
||
return t, true
|
||
case string:
|
||
s := strings.ToLower(strings.TrimSpace(t))
|
||
if s == "true" || s == "1" || s == "yes" {
|
||
return true, true
|
||
}
|
||
if s == "false" || s == "0" || s == "no" {
|
||
return false, true
|
||
}
|
||
case float64:
|
||
return t != 0, true
|
||
}
|
||
return false, false
|
||
}
|