mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-05 22:06:41 +02:00
Add files via upload
This commit is contained in:
+211
-19
@@ -24,6 +24,7 @@ import (
|
||||
"cyberstrike-ai/internal/skills"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/robfig/cron/v3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -81,6 +82,9 @@ type AgentHandler struct {
|
||||
}
|
||||
skillsManager *skills.Manager // Skills管理器
|
||||
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
||||
batchCronParser cron.Parser
|
||||
batchRunnerMu sync.Mutex
|
||||
batchRunning map[string]struct{}
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
@@ -93,14 +97,18 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
logger.Warn("从数据库加载批量任务队列失败", zap.Error(err))
|
||||
}
|
||||
|
||||
return &AgentHandler{
|
||||
handler := &AgentHandler{
|
||||
agent: agent,
|
||||
db: db,
|
||||
logger: logger,
|
||||
tasks: NewAgentTaskManager(),
|
||||
batchTaskManager: batchTaskManager,
|
||||
config: cfg,
|
||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||
batchRunning: make(map[string]struct{}),
|
||||
}
|
||||
go handler.batchQueueSchedulerLoop()
|
||||
return handler
|
||||
}
|
||||
|
||||
// SetKnowledgeManager 设置知识库管理器(用于记录检索日志)
|
||||
@@ -1575,9 +1583,26 @@ func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
|
||||
|
||||
// BatchTaskRequest 批量任务请求
|
||||
type BatchTaskRequest struct {
|
||||
Title string `json:"title"` // 任务标题(可选)
|
||||
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
|
||||
Role string `json:"role,omitempty"` // 角色名称(可选,空字符串表示默认角色)
|
||||
Title string `json:"title"` // 任务标题(可选)
|
||||
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
|
||||
Role string `json:"role,omitempty"` // 角色名称(可选,空字符串表示默认角色)
|
||||
AgentMode string `json:"agentMode,omitempty"` // single | multi
|
||||
ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron
|
||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||
}
|
||||
|
||||
func normalizeBatchQueueAgentMode(mode string) string {
|
||||
if strings.TrimSpace(mode) == "multi" {
|
||||
return "multi"
|
||||
}
|
||||
return "single"
|
||||
}
|
||||
|
||||
func normalizeBatchQueueScheduleMode(mode string) string {
|
||||
if strings.TrimSpace(mode) == "cron" {
|
||||
return "cron"
|
||||
}
|
||||
return "manual"
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
@@ -1606,7 +1631,25 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
queue := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, validTasks)
|
||||
agentMode := normalizeBatchQueueAgentMode(req.AgentMode)
|
||||
scheduleMode := normalizeBatchQueueScheduleMode(req.ScheduleMode)
|
||||
cronExpr := strings.TrimSpace(req.CronExpr)
|
||||
var nextRunAt *time.Time
|
||||
if scheduleMode == "cron" {
|
||||
if cronExpr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "启用 Cron 调度时,调度表达式不能为空"})
|
||||
return
|
||||
}
|
||||
schedule, err := h.batchCronParser.Parse(cronExpr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的 Cron 表达式: " + err.Error()})
|
||||
return
|
||||
}
|
||||
next := schedule.Next(time.Now())
|
||||
nextRunAt = &next
|
||||
}
|
||||
|
||||
queue := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"queueId": queue.ID,
|
||||
"queue": queue,
|
||||
@@ -1699,21 +1742,15 @@ func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
|
||||
// StartBatchQueue 开始执行批量任务队列
|
||||
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
ok, err := h.startBatchQueueExecution(queueID, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if queue.Status != "pending" && queue.Status != "paused" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "队列状态不允许启动"})
|
||||
return
|
||||
}
|
||||
|
||||
// 在后台执行批量任务
|
||||
go h.executeBatchQueue(queueID)
|
||||
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "running")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
@@ -1728,6 +1765,28 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
|
||||
}
|
||||
|
||||
// SetBatchQueueScheduleEnabled 开启/关闭 Cron 自动调度(手工执行不受影响)
|
||||
func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
if _, exists := h.batchTaskManager.GetBatchQueue(queueID); !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
ScheduleEnabled bool `json:"scheduleEnabled"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !h.batchTaskManager.SetScheduleEnabled(queueID, req.ScheduleEnabled) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
c.JSON(http.StatusOK, gin.H{"queue": queue})
|
||||
}
|
||||
|
||||
// DeleteBatchQueue 删除批量任务队列
|
||||
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
@@ -1824,8 +1883,125 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
|
||||
h.batchRunnerMu.Lock()
|
||||
defer h.batchRunnerMu.Unlock()
|
||||
if _, exists := h.batchRunning[queueID]; exists {
|
||||
return false
|
||||
}
|
||||
h.batchRunning[queueID] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
|
||||
h.batchRunnerMu.Lock()
|
||||
defer h.batchRunnerMu.Unlock()
|
||||
delete(h.batchRunning, queueID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
||||
expr := strings.TrimSpace(cronExpr)
|
||||
if expr == "" {
|
||||
return nil, nil
|
||||
}
|
||||
schedule, err := h.batchCronParser.Parse(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
next := schedule.Next(from)
|
||||
return &next, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
if !h.markBatchQueueRunning(queueID) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if scheduled {
|
||||
if queue.ScheduleMode != "cron" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
err := fmt.Errorf("队列未启用 cron 调度")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
err := fmt.Errorf("当前队列状态不允许被调度执行")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
if !h.batchTaskManager.ResetQueueForRerun(queueID) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
err := fmt.Errorf("重置队列失败")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
} else if queue.Status != "pending" && queue.Status != "paused" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
return true, fmt.Errorf("队列状态不允许启动")
|
||||
}
|
||||
|
||||
if queue != nil && queue.AgentMode == "multi" && (h.config == nil || !h.config.MultiAgent.Enabled) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
err := fmt.Errorf("当前队列配置为多代理,但系统未启用多代理")
|
||||
if scheduled {
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
}
|
||||
return true, err
|
||||
}
|
||||
|
||||
if scheduled {
|
||||
h.batchTaskManager.RecordScheduledRunStart(queueID)
|
||||
}
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "running")
|
||||
if queue != nil && queue.ScheduleMode == "cron" {
|
||||
nextRunAt, err := h.nextBatchQueueRunAt(queue.CronExpr, time.Now())
|
||||
if err == nil {
|
||||
h.batchTaskManager.UpdateQueueSchedule(queueID, "cron", queue.CronExpr, nextRunAt)
|
||||
}
|
||||
}
|
||||
|
||||
go h.executeBatchQueue(queueID)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) batchQueueSchedulerLoop() {
|
||||
ticker := time.NewTicker(20 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
queues := h.batchTaskManager.GetAllQueues()
|
||||
now := time.Now()
|
||||
for _, queue := range queues {
|
||||
if queue == nil || queue.ScheduleMode != "cron" || !queue.ScheduleEnabled || queue.Status == "cancelled" || queue.Status == "running" || queue.Status == "paused" {
|
||||
continue
|
||||
}
|
||||
nextRunAt := queue.NextRunAt
|
||||
if nextRunAt == nil {
|
||||
next, err := h.nextBatchQueueRunAt(queue.CronExpr, now)
|
||||
if err != nil {
|
||||
h.logger.Warn("批量任务 cron 表达式无效,跳过调度", zap.String("queueId", queue.ID), zap.String("cronExpr", queue.CronExpr), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
h.batchTaskManager.UpdateQueueSchedule(queue.ID, "cron", queue.CronExpr, next)
|
||||
nextRunAt = next
|
||||
}
|
||||
if nextRunAt != nil && (nextRunAt.Before(now) || nextRunAt.Equal(now)) {
|
||||
if _, err := h.startBatchQueueExecution(queue.ID, true); err != nil {
|
||||
h.logger.Warn("自动调度批量任务失败", zap.String("queueId", queue.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeBatchQueue 执行批量任务队列
|
||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
defer h.unmarkBatchQueueRunning(queueID)
|
||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
|
||||
|
||||
for {
|
||||
@@ -1838,7 +2014,17 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
// 获取下一个任务
|
||||
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
|
||||
if !hasNext {
|
||||
// 所有任务完成
|
||||
// 所有任务完成:汇总子任务失败信息便于排障
|
||||
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
lastRunErr := ""
|
||||
if ok {
|
||||
for _, t := range q.Tasks {
|
||||
if t.Status == "failed" && t.Error != "" {
|
||||
lastRunErr = t.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
|
||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||
break
|
||||
@@ -1918,7 +2104,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
h.batchTaskManager.SetTaskCancel(queueID, cancel)
|
||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||
// 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills
|
||||
useBatchMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent
|
||||
useBatchMulti := false
|
||||
if queue.AgentMode == "multi" {
|
||||
useBatchMulti = h.config != nil && h.config.MultiAgent.Enabled
|
||||
} else if queue.AgentMode == "" {
|
||||
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
|
||||
useBatchMulti = h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent
|
||||
}
|
||||
var result *agent.AgentLoopResult
|
||||
var resultMA *multiagent.RunResult
|
||||
var runErr error
|
||||
|
||||
@@ -27,24 +27,32 @@ type BatchTask struct {
|
||||
|
||||
// BatchTaskQueue 批量任务队列
|
||||
type BatchTaskQueue struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Role string `json:"role,omitempty"` // 角色名称(空字符串表示默认角色)
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
CurrentIndex int `json:"currentIndex"`
|
||||
mu sync.RWMutex
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Role string `json:"role,omitempty"` // 角色名称(空字符串表示默认角色)
|
||||
AgentMode string `json:"agentMode"` // single | multi
|
||||
ScheduleMode string `json:"scheduleMode"` // manual | cron
|
||||
CronExpr string `json:"cronExpr,omitempty"`
|
||||
NextRunAt *time.Time `json:"nextRunAt,omitempty"`
|
||||
ScheduleEnabled bool `json:"scheduleEnabled"`
|
||||
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
|
||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||
LastRunError string `json:"lastRunError,omitempty"`
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
CurrentIndex int `json:"currentIndex"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// BatchTaskManager 批量任务管理器
|
||||
type BatchTaskManager struct {
|
||||
db *database.DB
|
||||
queues map[string]*BatchTaskQueue
|
||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||
mu sync.RWMutex
|
||||
db *database.DB
|
||||
queues map[string]*BatchTaskQueue
|
||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBatchTaskManager 创建批量任务管理器
|
||||
@@ -63,19 +71,32 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(title, role string, tasks []string) *BatchTaskQueue {
|
||||
func (m *BatchTaskManager) CreateBatchQueue(
|
||||
title, role, agentMode, scheduleMode, cronExpr string,
|
||||
nextRunAt *time.Time,
|
||||
tasks []string,
|
||||
) *BatchTaskQueue {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queueID := time.Now().Format("20060102150405") + "-" + generateShortID()
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueID,
|
||||
Title: title,
|
||||
Role: role,
|
||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||
Status: "pending",
|
||||
CreatedAt: time.Now(),
|
||||
CurrentIndex: 0,
|
||||
ID: queueID,
|
||||
Title: title,
|
||||
Role: role,
|
||||
AgentMode: normalizeBatchQueueAgentMode(agentMode),
|
||||
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
|
||||
CronExpr: strings.TrimSpace(cronExpr),
|
||||
NextRunAt: nextRunAt,
|
||||
ScheduleEnabled: true,
|
||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||
Status: "pending",
|
||||
CreatedAt: time.Now(),
|
||||
CurrentIndex: 0,
|
||||
}
|
||||
if queue.ScheduleMode != "cron" {
|
||||
queue.CronExpr = ""
|
||||
queue.NextRunAt = nil
|
||||
}
|
||||
|
||||
// 准备数据库保存的任务数据
|
||||
@@ -100,7 +121,16 @@ func (m *BatchTaskManager) CreateBatchQueue(title, role string, tasks []string)
|
||||
|
||||
// 保存到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.CreateBatchQueue(queueID, title, role, dbTasks); err != nil {
|
||||
if err := m.db.CreateBatchQueue(
|
||||
queueID,
|
||||
title,
|
||||
role,
|
||||
queue.AgentMode,
|
||||
queue.ScheduleMode,
|
||||
queue.CronExpr,
|
||||
queue.NextRunAt,
|
||||
dbTasks,
|
||||
); err != nil {
|
||||
// 如果数据库保存失败,记录错误但继续(使用内存缓存)
|
||||
// 这里可以添加日志记录
|
||||
}
|
||||
@@ -151,6 +181,8 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueRow.ID,
|
||||
AgentMode: "single",
|
||||
ScheduleMode: "manual",
|
||||
Status: queueRow.Status,
|
||||
CreatedAt: queueRow.CreatedAt,
|
||||
CurrentIndex: queueRow.CurrentIndex,
|
||||
@@ -163,6 +195,33 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if queueRow.Role.Valid {
|
||||
queue.Role = queueRow.Role.String
|
||||
}
|
||||
if queueRow.AgentMode.Valid {
|
||||
queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String)
|
||||
}
|
||||
if queueRow.ScheduleMode.Valid {
|
||||
queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String)
|
||||
}
|
||||
if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" {
|
||||
queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String)
|
||||
}
|
||||
if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" {
|
||||
t := queueRow.NextRunAt.Time
|
||||
queue.NextRunAt = &t
|
||||
}
|
||||
queue.ScheduleEnabled = true
|
||||
if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 {
|
||||
queue.ScheduleEnabled = false
|
||||
}
|
||||
if queueRow.LastScheduleTriggerAt.Valid {
|
||||
t := queueRow.LastScheduleTriggerAt.Time
|
||||
queue.LastScheduleTriggerAt = &t
|
||||
}
|
||||
if queueRow.LastScheduleError.Valid {
|
||||
queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String)
|
||||
}
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -347,6 +406,8 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
||||
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueRow.ID,
|
||||
AgentMode: "single",
|
||||
ScheduleMode: "manual",
|
||||
Status: queueRow.Status,
|
||||
CreatedAt: queueRow.CreatedAt,
|
||||
CurrentIndex: queueRow.CurrentIndex,
|
||||
@@ -359,6 +420,33 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if queueRow.Role.Valid {
|
||||
queue.Role = queueRow.Role.String
|
||||
}
|
||||
if queueRow.AgentMode.Valid {
|
||||
queue.AgentMode = normalizeBatchQueueAgentMode(queueRow.AgentMode.String)
|
||||
}
|
||||
if queueRow.ScheduleMode.Valid {
|
||||
queue.ScheduleMode = normalizeBatchQueueScheduleMode(queueRow.ScheduleMode.String)
|
||||
}
|
||||
if queueRow.CronExpr.Valid && queue.ScheduleMode == "cron" {
|
||||
queue.CronExpr = strings.TrimSpace(queueRow.CronExpr.String)
|
||||
}
|
||||
if queueRow.NextRunAt.Valid && queue.ScheduleMode == "cron" {
|
||||
t := queueRow.NextRunAt.Time
|
||||
queue.NextRunAt = &t
|
||||
}
|
||||
queue.ScheduleEnabled = true
|
||||
if queueRow.ScheduleEnabled.Valid && queueRow.ScheduleEnabled.Int64 == 0 {
|
||||
queue.ScheduleEnabled = false
|
||||
}
|
||||
if queueRow.LastScheduleTriggerAt.Valid {
|
||||
t := queueRow.LastScheduleTriggerAt.Time
|
||||
queue.LastScheduleTriggerAt = &t
|
||||
}
|
||||
if queueRow.LastScheduleError.Valid {
|
||||
queue.LastScheduleError = strings.TrimSpace(queueRow.LastScheduleError.String)
|
||||
}
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -469,6 +557,127 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQueueSchedule 更新队列调度配置
|
||||
func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
queue.ScheduleMode = normalizeBatchQueueScheduleMode(scheduleMode)
|
||||
if queue.ScheduleMode == "cron" {
|
||||
queue.CronExpr = strings.TrimSpace(cronExpr)
|
||||
queue.NextRunAt = nextRunAt
|
||||
} else {
|
||||
queue.CronExpr = ""
|
||||
queue.NextRunAt = nil
|
||||
}
|
||||
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueSchedule(queueID, queue.ScheduleMode, queue.CronExpr, queue.NextRunAt); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetScheduleEnabled 暂停/恢复 Cron 自动调度(不影响手工执行)
|
||||
func (m *BatchTaskManager) SetScheduleEnabled(queueID string, enabled bool) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
queue.ScheduleEnabled = enabled
|
||||
if m.db != nil {
|
||||
_ = m.db.UpdateBatchQueueScheduleEnabled(queueID, enabled)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// RecordScheduledRunStart Cron 触发成功、即将执行子任务时调用
|
||||
func (m *BatchTaskManager) RecordScheduledRunStart(queueID string) {
|
||||
now := time.Now()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
queue.LastScheduleTriggerAt = &now
|
||||
queue.LastScheduleError = ""
|
||||
if m.db != nil {
|
||||
_ = m.db.RecordBatchQueueScheduledTriggerStart(queueID, now)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLastScheduleError 调度层失败(未成功开始执行)
|
||||
func (m *BatchTaskManager) SetLastScheduleError(queueID, msg string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
queue.LastScheduleError = strings.TrimSpace(msg)
|
||||
if m.db != nil {
|
||||
_ = m.db.SetBatchQueueLastScheduleError(queueID, queue.LastScheduleError)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLastRunError 最近一轮批量执行中的失败摘要
|
||||
func (m *BatchTaskManager) SetLastRunError(queueID, msg string) {
|
||||
msg = strings.TrimSpace(msg)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
queue.LastRunError = msg
|
||||
if m.db != nil {
|
||||
_ = m.db.SetBatchQueueLastRunError(queueID, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetQueueForRerun 重置队列与子任务状态,供 cron 下一轮执行
|
||||
func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
queue.Status = "pending"
|
||||
queue.CurrentIndex = 0
|
||||
queue.StartedAt = nil
|
||||
queue.CompletedAt = nil
|
||||
queue.NextRunAt = nil
|
||||
for _, task := range queue.Tasks {
|
||||
task.Status = "pending"
|
||||
task.ConversationID = ""
|
||||
task.StartedAt = nil
|
||||
task.CompletedAt = nil
|
||||
task.Error = ""
|
||||
task.Result = ""
|
||||
}
|
||||
|
||||
if m.db != nil {
|
||||
if err := m.db.ResetBatchQueueForRerun(queueID); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// UpdateTaskMessage 更新任务消息(仅限待执行状态)
|
||||
func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error {
|
||||
m.mu.Lock()
|
||||
|
||||
@@ -0,0 +1,533 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler)
|
||||
func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) {
|
||||
if mcpServer == nil || h == nil || logger == nil {
|
||||
return
|
||||
}
|
||||
|
||||
reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) {
|
||||
mcpServer.RegisterTool(tool, fn)
|
||||
}
|
||||
|
||||
// --- list ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskList,
|
||||
Description: "列出批量任务队列,支持按状态筛选与关键字搜索。用于查看队列 id、状态、进度及 Cron 配置等。",
|
||||
ShortDescription: "列出批量任务队列",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"status": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled",
|
||||
"enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"},
|
||||
},
|
||||
"keyword": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "按队列 ID 或标题模糊搜索",
|
||||
},
|
||||
"page": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "页码,从 1 开始,默认 1",
|
||||
},
|
||||
"page_size": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "每页条数,默认 20,最大 100",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
status := mcpArgString(args, "status")
|
||||
if status == "" {
|
||||
status = "all"
|
||||
}
|
||||
keyword := mcpArgString(args, "keyword")
|
||||
page := int(mcpArgFloat(args, "page"))
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := int(mcpArgFloat(args, "page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword)
|
||||
if err != nil {
|
||||
return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil
|
||||
}
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
payload := map[string]interface{}{
|
||||
"queues": queues,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": totalPages,
|
||||
}
|
||||
logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total))
|
||||
return batchMCPJSONResult(payload)
|
||||
})
|
||||
|
||||
// --- get ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskGet,
|
||||
Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。",
|
||||
ShortDescription: "获取批量任务队列详情",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
queue, ok := h.batchTaskManager.GetBatchQueue(qid)
|
||||
if !ok {
|
||||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||||
}
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
// --- create ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskCreate,
|
||||
Description: `创建新的批量任务队列。任务列表使用 tasks(字符串数组)或 tasks_text(多行,每行一条)。
|
||||
agent_mode: single(默认)或 multi(需系统启用多代理)。schedule_mode: manual(默认)或 cron;为 cron 时必须提供 cron_expr(如 "0 */6 * * *")。
|
||||
重要:创建成功后队列处于 pending,不会自动开始跑子任务。若要立即执行或手工开跑,必须再调用工具 batch_task_start(传入返回的 queue_id)。Cron 队列若需按表达式自动触发下一轮,还需保持调度开关开启(可用 batch_task_schedule_enabled)。`,
|
||||
ShortDescription: "创建批量任务队列(创建后需 batch_task_start 才会执行)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选标题",
|
||||
},
|
||||
"role": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "角色名称,空表示默认",
|
||||
},
|
||||
"tasks": map[string]interface{}{
|
||||
"type": "array",
|
||||
"description": "任务指令列表,每项一条",
|
||||
"items": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"tasks_text": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "多行文本,每行一条任务(与 tasks 二选一)",
|
||||
},
|
||||
"agent_mode": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "single 或 multi",
|
||||
"enum": []string{"single", "multi"},
|
||||
},
|
||||
"schedule_mode": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "manual 或 cron",
|
||||
"enum": []string{"manual", "cron"},
|
||||
},
|
||||
"cron_expr": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "schedule_mode 为 cron 时必填",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
tasks, errMsg := batchMCPTasksFromArgs(args)
|
||||
if errMsg != "" {
|
||||
return batchMCPTextResult(errMsg, true), nil
|
||||
}
|
||||
title := mcpArgString(args, "title")
|
||||
role := mcpArgString(args, "role")
|
||||
agentMode := normalizeBatchQueueAgentMode(mcpArgString(args, "agent_mode"))
|
||||
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
|
||||
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
|
||||
var nextRunAt *time.Time
|
||||
if scheduleMode == "cron" {
|
||||
if cronExpr == "" {
|
||||
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
|
||||
}
|
||||
sch, err := h.batchCronParser.Parse(cronExpr)
|
||||
if err != nil {
|
||||
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
|
||||
}
|
||||
n := sch.Next(time.Now())
|
||||
nextRunAt = &n
|
||||
}
|
||||
queue := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
|
||||
logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks)))
|
||||
return batchMCPJSONResult(map[string]interface{}{
|
||||
"queue_id": queue.ID,
|
||||
"queue": queue,
|
||||
"reminder": "队列已创建,当前为 pending。需要开始执行时请调用 MCP工具 batch_task_start(queue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。",
|
||||
})
|
||||
})
|
||||
|
||||
// --- start ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskStart,
|
||||
Description: `启动或继续执行批量任务队列(pending / paused)。
|
||||
与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。`,
|
||||
ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
ok, err := h.startBatchQueueExecution(qid, false)
|
||||
if !ok {
|
||||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||||
}
|
||||
if err != nil {
|
||||
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
|
||||
}
|
||||
logger.Info("MCP batch_task_start", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil
|
||||
})
|
||||
|
||||
// --- pause ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskPause,
|
||||
Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。",
|
||||
ShortDescription: "暂停批量任务队列",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.PauseQueue(qid) {
|
||||
return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil
|
||||
}
|
||||
logger.Info("MCP batch_task_pause", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("队列已暂停。", false), nil
|
||||
})
|
||||
|
||||
// --- delete queue ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskDelete,
|
||||
Description: "删除批量任务队列及其子任务记录。",
|
||||
ShortDescription: "删除批量任务队列",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.DeleteQueue(qid) {
|
||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||
}
|
||||
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("队列已删除。", false), nil
|
||||
})
|
||||
|
||||
// --- schedule enabled ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskScheduleEnabled,
|
||||
Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。
|
||||
仅对 schedule_mode 为 cron 的队列有意义。`,
|
||||
ShortDescription: "开关批量任务 Cron 自动调度",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"schedule_enabled": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "true 允许定时触发,false 仅手工执行",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "schedule_enabled"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
en, ok := mcpArgBool(args, "schedule_enabled")
|
||||
if !ok {
|
||||
return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil
|
||||
}
|
||||
if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists {
|
||||
return batchMCPTextResult("队列不存在", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.SetScheduleEnabled(qid, en) {
|
||||
return batchMCPTextResult("更新失败", true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en))
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
// --- add task ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskAdd,
|
||||
Description: "向处于 pending 状态的队列追加一条子任务。",
|
||||
ShortDescription: "批量队列添加子任务",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "任务指令内容",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "message"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
msg := strings.TrimSpace(mcpArgString(args, "message"))
|
||||
if qid == "" || msg == "" {
|
||||
return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil
|
||||
}
|
||||
task, err := h.batchTaskManager.AddTaskToQueue(qid, msg)
|
||||
if err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID))
|
||||
return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue})
|
||||
})
|
||||
|
||||
// --- update task ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskUpdate,
|
||||
Description: "修改 pending 队列中仍为 pending 的子任务文案。",
|
||||
ShortDescription: "更新批量子任务内容",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"task_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "子任务 ID",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "新的任务指令",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "task_id", "message"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
tid := mcpArgString(args, "task_id")
|
||||
msg := strings.TrimSpace(mcpArgString(args, "message"))
|
||||
if qid == "" || tid == "" || msg == "" {
|
||||
return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil
|
||||
}
|
||||
if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid))
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
// --- remove task ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskRemove,
|
||||
Description: "从 pending 队列中删除仍为 pending 的子任务。",
|
||||
ShortDescription: "删除批量子任务",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"task_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "子任务 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "task_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
tid := mcpArgString(args, "task_id")
|
||||
if qid == "" || tid == "" {
|
||||
return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil
|
||||
}
|
||||
if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid))
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 10))
|
||||
}
|
||||
|
||||
func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{{Type: "text", Text: text}},
|
||||
IsError: isErr,
|
||||
}
|
||||
}
|
||||
|
||||
func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) {
|
||||
b, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil
|
||||
}
|
||||
return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil
|
||||
}
|
||||
|
||||
func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) {
|
||||
if raw, ok := args["tasks"]; ok && raw != nil {
|
||||
switch t := raw.(type) {
|
||||
case []interface{}:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, x := range t {
|
||||
if s, ok := x.(string); ok {
|
||||
if tr := strings.TrimSpace(s); tr != "" {
|
||||
out = append(out, tr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
if txt := mcpArgString(args, "tasks_text"); txt != "" {
|
||||
lines := strings.Split(txt, "\n")
|
||||
out := make([]string, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
if tr := strings.TrimSpace(line); tr != "" {
|
||||
out = append(out, tr)
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out, ""
|
||||
}
|
||||
}
|
||||
return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)"
|
||||
}
|
||||
|
||||
func mcpArgString(args map[string]interface{}, key string) string {
|
||||
v, ok := args[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(t)
|
||||
case float64:
|
||||
return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64))
|
||||
case json.Number:
|
||||
return strings.TrimSpace(t.String())
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(t))
|
||||
}
|
||||
}
|
||||
|
||||
func mcpArgFloat(args map[string]interface{}, key string) float64 {
|
||||
v, ok := args[key]
|
||||
if !ok || v == nil {
|
||||
return 0
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case float64:
|
||||
return t
|
||||
case int:
|
||||
return float64(t)
|
||||
case int64:
|
||||
return float64(t)
|
||||
case json.Number:
|
||||
f, _ := t.Float64()
|
||||
return f
|
||||
case string:
|
||||
f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64)
|
||||
return f
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) {
|
||||
v, exists := args[key]
|
||||
if !exists {
|
||||
return false, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
s := strings.ToLower(strings.TrimSpace(t))
|
||||
if s == "true" || s == "1" || s == "yes" {
|
||||
return true, true
|
||||
}
|
||||
if s == "false" || s == "0" || s == "no" {
|
||||
return false, true
|
||||
}
|
||||
case float64:
|
||||
return t != 0, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
@@ -37,6 +37,9 @@ type WebshellToolRegistrar func() error
|
||||
// SkillsToolRegistrar Skills工具注册器接口
|
||||
type SkillsToolRegistrar func() error
|
||||
|
||||
// BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册)
|
||||
type BatchTaskToolRegistrar func() error
|
||||
|
||||
// RetrieverUpdater 检索器更新接口
|
||||
type RetrieverUpdater interface {
|
||||
UpdateConfig(config *knowledge.RetrievalConfig)
|
||||
@@ -68,6 +71,7 @@ type ConfigHandler struct {
|
||||
vulnerabilityToolRegistrar VulnerabilityToolRegistrar // 漏洞工具注册器(可选)
|
||||
webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选)
|
||||
skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选)
|
||||
batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
@@ -141,6 +145,13 @@ func (h *ConfigHandler) SetSkillsToolRegistrar(registrar SkillsToolRegistrar) {
|
||||
h.skillsToolRegistrar = registrar
|
||||
}
|
||||
|
||||
// SetBatchTaskToolRegistrar 设置批量任务 MCP 工具注册器
|
||||
func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistrar) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.batchTaskToolRegistrar = registrar
|
||||
}
|
||||
|
||||
// SetRetrieverUpdater 设置检索器更新器
|
||||
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
|
||||
h.mu.Lock()
|
||||
@@ -999,6 +1010,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 重新注册批量任务 MCP 工具
|
||||
if h.batchTaskToolRegistrar != nil {
|
||||
h.logger.Info("重新注册批量任务 MCP 工具")
|
||||
if err := h.batchTaskToolRegistrar(); err != nil {
|
||||
h.logger.Error("重新注册批量任务 MCP 工具失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("批量任务 MCP 工具已重新注册")
|
||||
}
|
||||
}
|
||||
|
||||
// 如果知识库启用,重新注册知识库工具
|
||||
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
|
||||
h.logger.Info("重新注册知识库工具")
|
||||
|
||||
Reference in New Issue
Block a user