From 9306303d9911a43d76c451147a19600ce374a184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Wed, 24 Jun 2026 01:46:30 +0800 Subject: [PATCH] Add files via upload --- internal/handler/agent.go | 385 ++------------------ internal/handler/batch_queue_executor.go | 352 ++++++++++++++++++ internal/handler/batch_task_manager.go | 259 ++++++++++--- internal/handler/batch_task_manager_test.go | 121 ++++++ internal/handler/batch_task_mcp.go | 34 +- 5 files changed, 745 insertions(+), 406 deletions(-) create mode 100644 internal/handler/batch_queue_executor.go create mode 100644 internal/handler/batch_task_manager_test.go diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 143b1f16..25b3895b 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -21,7 +21,6 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/reasoning" - "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/openai" @@ -178,8 +177,6 @@ type AgentHandler struct { } agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并) batchCronParser cron.Parser - batchRunnerMu sync.Mutex - batchRunning map[string]struct{} // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) hitlWhitelistSaver HitlToolWhitelistSaver audit *audit.Service @@ -233,7 +230,6 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo config: cfg, hitlManager: NewHITLManager(db, logger), batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), - batchRunning: make(map[string]struct{}), } if err := handler.hitlManager.EnsureSchema(); err != nil { logger.Warn("初始化 HITL 表失败", zap.Error(err)) @@ -1470,6 +1466,7 @@ type BatchTaskRequest struct { CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选) + Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8 } // batchQueueWantsEino 队列是否配置为走 Eino 多代理。 @@ -1529,7 +1526,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) { nextRunAt = &next } - queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks) + queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks) if createErr != nil { c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) return @@ -1719,15 +1716,16 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) { func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) { queueID := c.Param("queueId") var req struct { - Title string `json:"title"` - Role string `json:"role"` - AgentMode string `json:"agentMode"` + Title string `json:"title"` + Role string `json:"role"` + AgentMode string `json:"agentMode"` + Concurrency *int `json:"concurrency"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil { + if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -1802,9 +1800,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) { // DeleteBatchQueue 删除批量任务队列 func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) { queueID := c.Param("queueId") - success := h.batchTaskManager.DeleteQueue(queueID) - if !success { - c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + if err := h.batchTaskManager.DeleteQueue(queueID); err != nil { + switch { + case errors.Is(err, ErrBatchQueueNotFound): + c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) + case errors.Is(err, ErrBatchQueueExecutorActive): + c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"}) + case errors.Is(err, ErrBatchQueueStillRunning): + c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } return } if h.audit != nil { @@ -1898,7 +1904,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) { // 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动 if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused { - h.forceUnmarkBatchQueueRunning(queueID) + h.batchTaskManager.ForceUnmarkQueueExecutor(queueID) } autoStarted := true @@ -1957,26 +1963,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) } -func (h *AgentHandler) markBatchQueueRunning(queueID string) bool { - h.batchRunnerMu.Lock() - defer h.batchRunnerMu.Unlock() - if _, exists := h.batchRunning[queueID]; exists { - return false - } - h.batchRunning[queueID] = struct{}{} - return true -} - -func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) { - h.batchRunnerMu.Lock() - defer h.batchRunnerMu.Unlock() - delete(h.batchRunning, queueID) -} - -func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) { - h.unmarkBatchQueueRunning(queueID) -} - func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { expr := strings.TrimSpace(cronExpr) if expr == "" { @@ -1992,43 +1978,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { // 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断 - if !h.markBatchQueueRunning(queueID) { + if !h.batchTaskManager.TryMarkQueueExecutor(queueID) { return true, nil } queue, exists := h.batchTaskManager.GetBatchQueue(queueID) if !exists { - h.unmarkBatchQueueRunning(queueID) + h.batchTaskManager.UnmarkQueueExecutor(queueID) return false, nil } if scheduled { if queue.ScheduleMode != "cron" { - h.unmarkBatchQueueRunning(queueID) + h.batchTaskManager.UnmarkQueueExecutor(queueID) err := fmt.Errorf("队列未启用 cron 调度") h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) return true, err } if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" { - h.unmarkBatchQueueRunning(queueID) + h.batchTaskManager.UnmarkQueueExecutor(queueID) err := fmt.Errorf("当前队列状态不允许被调度执行") h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) return true, err } if !h.batchTaskManager.ResetQueueForRerun(queueID) { - h.unmarkBatchQueueRunning(queueID) + h.batchTaskManager.UnmarkQueueExecutor(queueID) err := fmt.Errorf("重置队列失败") h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) return true, err } queue, _ = h.batchTaskManager.GetBatchQueue(queueID) } else if queue.Status != "pending" && queue.Status != "paused" { - h.unmarkBatchQueueRunning(queueID) + h.batchTaskManager.UnmarkQueueExecutor(queueID) return true, fmt.Errorf("队列状态不允许启动") } if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) { - h.unmarkBatchQueueRunning(queueID) + h.batchTaskManager.UnmarkQueueExecutor(queueID) err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理") if scheduled { h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) @@ -2080,325 +2066,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() { } } -// executeBatchQueue 执行批量任务队列 -func (h *AgentHandler) executeBatchQueue(queueID string) { - defer h.unmarkBatchQueueRunning(queueID) - h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID)) - - for { - // 检查队列状态 - queue, exists := h.batchTaskManager.GetBatchQueue(queueID) - if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" { - break - } - - // 获取下一个任务 - task, hasNext := h.batchTaskManager.GetNextTask(queueID) - if !hasNext { - // 所有任务完成:汇总子任务失败信息便于排障 - q, ok := h.batchTaskManager.GetBatchQueue(queueID) - lastRunErr := "" - if ok { - for _, t := range q.Tasks { - if t.Status == "failed" && t.Error != "" { - lastRunErr = t.Error - } - } - } - h.batchTaskManager.SetLastRunError(queueID, lastRunErr) - h.batchTaskManager.UpdateQueueStatus(queueID, "completed") - h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID)) - break - } - - // 更新任务状态为运行中 - h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "") - - // 创建新对话 - title := safeTruncateString(task.Message, 50) - batchMeta := audit.ConversationCreateMeta("batch_task") - batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID) - conv, err := h.db.CreateConversation(title, batchMeta) - var conversationID string - if err != nil { - h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error()) - h.batchTaskManager.MoveToNextTask(queueID) - if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) { - h.batchTaskManager.UpdateQueueStatus(queueID, "paused") - break - } - continue - } - conversationID = conv.ID - - // 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话) - h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID) - - // 应用角色用户提示词和工具配置 - finalMessage := task.Message - var roleTools []string // 角色配置的工具列表 - if queue.Role != "" && queue.Role != "默认" { - if h.config.Roles != nil { - if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled { - // 应用用户提示词 - if role.UserPrompt != "" { - finalMessage = role.UserPrompt + "\n\n" + task.Message - h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role)) - } - // 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段) - if len(role.Tools) > 0 { - roleTools = role.Tools - h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools))) - } - } - } - } - - // 保存用户消息(保存原始消息,不包含角色提示词) - _, err = h.db.AddMessage(conversationID, "user", task.Message, nil) - if err != nil { - h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - } - - // 预先创建助手消息,以便关联过程详情 - assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) - if err != nil { - h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - // 如果创建失败,继续执行但不保存过程详情 - assistantMsg = nil - } - - // 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil) - var assistantMessageID string - if assistantMsg != nil { - assistantMessageID = assistantMsg.ID - } - // 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」, - // 需要把进度事件镜像到 TaskEventBus(GET /api/agent-loop/task-events 会订阅这里)。 - // progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。 - var progressCallback func(eventType, message string, data interface{}) - - // 执行任务(使用包含角色提示词的finalMessage和角色工具列表) - h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID)) - - func() { - // 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。 - baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) - // 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致) - taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour) - - registered := false - finishStatus := "completed" - - defer func() { - h.batchTaskManager.SetTaskCancel(queueID, nil) - timeoutCancel() - if registered { - // 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。 - if h.taskEventBus != nil { - ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}} - if b, err := json.Marshal(ev); err == nil { - h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n')) - } - } - h.tasks.FinishTask(conversationID, finishStatus) - } - cancelWithCause(nil) - }() - - // 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。 - sendEvent := func(eventType, message string, data interface{}) { - if h.taskEventBus == nil { - return - } - ev := StreamEvent{Type: eventType, Message: message, Data: data} - b, err := json.Marshal(ev) - if err != nil { - b = []byte(`{"type":"error","message":"marshal failed"}`) - } - line := make([]byte, 0, len(b)+8) - line = append(line, []byte("data: ")...) - line = append(line, b...) - line = append(line, '\n', '\n') - h.taskEventBus.Publish(conversationID, line) - } - - if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil { - h.logger.Warn("批量队列子任务注册会话运行状态失败", - zap.String("queueId", queueID), - zap.String("taskId", task.ID), - zap.String("conversationId", conversationID), - zap.Error(err)) - failMsg := err.Error() - if errors.Is(err, ErrTaskAlreadyRunning) { - failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务" - } - h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg) - return - } - registered = true - // 存储取消函数:暂停队列时取消子任务 context(与原先语义一致) - h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel) - - // 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。 - progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) - taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) - taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) - taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks) - - // 使用队列配置的角色工具列表(如果为空,表示使用所有工具) - useBatchMulti := false - batchOrch := "deep" - am := strings.TrimSpace(strings.ToLower(queue.AgentMode)) - if am == "multi" { - am = "deep" - } - if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled { - useBatchMulti = true - batchOrch = config.NormalizeMultiAgentOrchestration(am) - } else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent { - // 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关 - useBatchMulti = true - batchOrch = "deep" - } - var resultMA *multiagent.RunResult - var runErr error - switch { - case useBatchMulti: - resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID)) - default: - if h.config == nil { - runErr = fmt.Errorf("服务器配置未加载") - } else { - resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID)) - } - } - - if runErr != nil { - if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { - h.persistEinoAgentTraceForResume(conversationID, resultMA) - } - errStr := runErr.Error() - partialResp := "" - if resultMA != nil { - partialResp = resultMA.Response - } - isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) || - errors.Is(runErr, context.Canceled) || - strings.Contains(strings.ToLower(errStr), "context canceled") || - strings.Contains(strings.ToLower(errStr), "context cancelled") || - (partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断"))) - isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) - - if isTimeout { - finishStatus = "timeout" - } else if isCancelled { - finishStatus = "cancelled" - } else { - finishStatus = "failed" - } - - if isCancelled { - h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) - cancelMsg := "任务已被用户取消,后续操作已停止。" - // 如果执行结果中有更具体的取消消息,使用它 - if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) { - cancelMsg = partialResp - } - // 更新助手消息内容 - if assistantMessageID != "" { - if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil { - h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) - } - // 保存取消详情到数据库 - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil { - h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } - } else { - // 如果没有预先创建的助手消息,创建一个新的 - _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil) - if errMsg != nil { - h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) - } - } - h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID) - } else { - h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr)) - errorMsg := "执行失败: " + runErr.Error() - // 更新助手消息内容 - if assistantMessageID != "" { - if _, updateErr := h.db.Exec( - "UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", - errorMsg, - time.Now(), assistantMessageID, - ); updateErr != nil { - h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) - } - // 保存错误详情到数据库 - if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil { - h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } - } - h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error()) - } - } else { - h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) - - resText := resultMA.Response - mcpIDs := resultMA.MCPExecutionIDs - lastIn := resultMA.LastAgentTraceInput - lastOut := resultMA.LastAgentTraceOutput - - // 更新助手消息内容 - if assistantMessageID != "" { - if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil { - h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) - // 如果更新失败,尝试创建新消息 - _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs) - if err != nil { - h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - } - } - } else { - // 如果没有预先创建的助手消息,创建一个新的 - _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs) - if err != nil { - h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) - } - } - - // 保存代理轨迹 - if lastIn != "" || lastOut != "" { - if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil { - h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } else { - h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) - } - } - - // 保存结果 - h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID) - } - }() - - // 移动到下一个任务 - h.batchTaskManager.MoveToNextTask(queueID) - - if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) { - h.batchTaskManager.UpdateQueueStatus(queueID, "paused") - h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID)) - break - } - - // 检查是否被取消或暂停 - queue, _ = h.batchTaskManager.GetBatchQueue(queueID) - if queue.Status == "cancelled" || queue.Status == "paused" { - break - } - } -} - // loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。 // 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。 func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) { diff --git a/internal/handler/batch_queue_executor.go b/internal/handler/batch_queue_executor.go new file mode 100644 index 00000000..06640be9 --- /dev/null +++ b/internal/handler/batch_queue_executor.go @@ -0,0 +1,352 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/audit" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/multiagent" + + "go.uber.org/zap" +) + +const batchQueueWorkerIdlePoll = 200 * time.Millisecond + +// executeBatchQueue 使用并发 worker 池执行批量任务队列。 +func (h *AgentHandler) executeBatchQueue(queueID string) { + defer h.batchTaskManager.UnmarkQueueExecutor(queueID) + + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists { + return + } + concurrency := normalizeBatchQueueConcurrency(queue.Concurrency) + h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", concurrency)) + + var wg sync.WaitGroup + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + h.runBatchQueueWorker(queueID) + }() + } + wg.Wait() + + h.tryFinalizeBatchQueue(queueID) +} + +func (h *AgentHandler) runBatchQueueWorker(queueID string) { + for { + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if batchQueueExecutionShouldStop(queue, exists) { + return + } + + task, ok := h.batchTaskManager.ClaimNextPendingTask(queueID) + if !ok { + if !h.batchTaskManager.HasRunningTasks(queueID) { + return + } + time.Sleep(batchQueueWorkerIdlePoll) + continue + } + + queue, _ = h.batchTaskManager.GetBatchQueue(queueID) + if queue == nil { + return + } + + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusRunning, "", "") + h.executeOneBatchSubTask(queueID, queue, task) + + if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) { + h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused) + h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID)) + return + } + + queue, exists = h.batchTaskManager.GetBatchQueue(queueID) + if batchQueueExecutionShouldStop(queue, exists) { + if !exists { + h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID)) + } + return + } + } +} + +func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) { + queue, exists := h.batchTaskManager.GetBatchQueue(queueID) + if !exists || queue == nil { + return + } + if queue.Status != BatchQueueStatusRunning { + return + } + if h.batchTaskManager.HasPendingOrRunningTasks(queueID) { + return + } + + lastRunErr := "" + for _, t := range queue.Tasks { + if t != nil && t.Status == BatchTaskStatusFailed && t.Error != "" { + lastRunErr = t.Error + } + } + h.batchTaskManager.SetLastRunError(queueID, lastRunErr) + h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted) + h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID)) +} + +// executeOneBatchSubTask 执行单条批量子任务(各自独立会话)。 +func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) { + title := safeTruncateString(task.Message, 50) + batchMeta := audit.ConversationCreateMeta("batch_task") + batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID) + conv, err := h.db.CreateConversation(title, batchMeta) + if err != nil { + h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error()) + return + } + conversationID := conv.ID + + h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID) + + finalMessage := task.Message + var roleTools []string + if queue.Role != "" && queue.Role != "默认" { + if h.config.Roles != nil { + if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled { + if role.UserPrompt != "" { + finalMessage = role.UserPrompt + "\n\n" + task.Message + h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role)) + } + if len(role.Tools) > 0 { + roleTools = role.Tools + h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools))) + } + } + } + } + + if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil { + h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + } + + assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) + if err != nil { + h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + assistantMsg = nil + } + + var assistantMessageID string + if assistantMsg != nil { + assistantMessageID = assistantMsg.ID + } + + h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID)) + + baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) + taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour) + + registered := false + finishStatus := "completed" + + defer func() { + h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil) + timeoutCancel() + if registered { + if h.taskEventBus != nil { + ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}} + if b, err := json.Marshal(ev); err == nil { + h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n')) + } + } + h.tasks.FinishTask(conversationID, finishStatus) + } + cancelWithCause(nil) + }() + + sendEvent := func(eventType, message string, data interface{}) { + if h.taskEventBus == nil { + return + } + ev := StreamEvent{Type: eventType, Message: message, Data: data} + b, err := json.Marshal(ev) + if err != nil { + b = []byte(`{"type":"error","message":"marshal failed"}`) + } + line := make([]byte, 0, len(b)+8) + line = append(line, []byte("data: ")...) + line = append(line, b...) + line = append(line, '\n', '\n') + h.taskEventBus.Publish(conversationID, line) + } + + if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil { + h.logger.Warn("批量队列子任务注册会话运行状态失败", + zap.String("queueId", queueID), + zap.String("taskId", task.ID), + zap.String("conversationId", conversationID), + zap.Error(err)) + failMsg := err.Error() + if errors.Is(err, ErrTaskAlreadyRunning) { + failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务" + } + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg) + return + } + registered = true + h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel) + + progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) + taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID) + taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks) + taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks) + + useBatchMulti := false + batchOrch := "deep" + am := strings.TrimSpace(strings.ToLower(queue.AgentMode)) + if am == "multi" { + am = "deep" + } + if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled { + useBatchMulti = true + batchOrch = config.NormalizeMultiAgentOrchestration(am) + } else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent { + useBatchMulti = true + batchOrch = "deep" + } + + var resultMA *multiagent.RunResult + var runErr error + switch { + case useBatchMulti: + resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID)) + default: + if h.config == nil { + runErr = fmt.Errorf("服务器配置未加载") + } else { + resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID)) + } + } + + if runErr != nil { + h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus) + return + } + + if resultMA == nil { + h.logger.Error("批量任务执行成功但无结果对象", + zap.String("queueId", queueID), + zap.String("taskId", task.ID), + zap.String("conversationId", conversationID)) + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果") + return + } + + h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) + + resText := resultMA.Response + mcpIDs := resultMA.MCPExecutionIDs + lastIn := resultMA.LastAgentTraceInput + lastOut := resultMA.LastAgentTraceOutput + + if assistantMessageID != "" { + if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil { + h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) + if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil { + h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + } + } + } else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil { + h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err)) + } + + if lastIn != "" || lastOut != "" { + if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil { + h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } + } + + h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID) +} + +func (h *AgentHandler) handleBatchSubTaskRunError( + queueID string, + task *BatchTask, + conversationID, assistantMessageID string, + baseCtx, taskCtx context.Context, + resultMA *multiagent.RunResult, + runErr error, + finishStatus *string, +) { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, resultMA) + } + errStr := runErr.Error() + partialResp := "" + if resultMA != nil { + partialResp = resultMA.Response + } + isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) || + errors.Is(runErr, context.Canceled) || + strings.Contains(strings.ToLower(errStr), "context canceled") || + strings.Contains(strings.ToLower(errStr), "context cancelled") || + (partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断"))) + isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) + + if isTimeout { + *finishStatus = "timeout" + } else if isCancelled { + *finishStatus = "cancelled" + } else { + *finishStatus = "failed" + } + + if isCancelled { + h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) + cancelMsg := "任务已被用户取消,后续操作已停止。" + if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) { + cancelMsg = partialResp + } + if assistantMessageID != "" { + if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil { + h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) + } + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil { + h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } + } else if _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); errMsg != nil { + h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) + } + h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID) + return + } + + h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr)) + errorMsg := "执行失败: " + runErr.Error() + if assistantMessageID != "" { + if _, updateErr := h.db.Exec( + "UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", + errorMsg, + time.Now(), assistantMessageID, + ); updateErr != nil { + h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr)) + } + if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil { + h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } + } + h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error()) +} diff --git a/internal/handler/batch_task_manager.go b/internal/handler/batch_task_manager.go index 9a53d20b..b99822d0 100644 --- a/internal/handler/batch_task_manager.go +++ b/internal/handler/batch_task_manager.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" "fmt" "sort" "strings" @@ -17,6 +18,15 @@ import ( "go.uber.org/zap" ) +var ( + // ErrBatchQueueNotFound 队列不存在或已从内存卸载。 + ErrBatchQueueNotFound = errors.New("batch queue not found") + // ErrBatchQueueExecutorActive executeBatchQueue 协程仍在收尾,禁止删除。 + ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active") + // ErrBatchQueueStillRunning 队列状态仍为 running(无活跃执行器时的兜底保护)。 + ErrBatchQueueStillRunning = errors.New("batch queue is still running") +) + // 批量任务状态常量 const ( BatchQueueStatusPending = "pending" @@ -39,6 +49,12 @@ const ( // MaxBatchQueueRoleLen 角色名最大长度 MaxBatchQueueRoleLen = 100 + + // DefaultBatchQueueConcurrency 批量队列默认并发数(串行) + DefaultBatchQueueConcurrency = 1 + + // MaxBatchQueueConcurrency 批量队列最大并发数 + MaxBatchQueueConcurrency = 8 ) // BatchTask 批量任务项 @@ -67,6 +83,7 @@ type BatchTaskQueue struct { LastScheduleError string `json:"lastScheduleError,omitempty"` LastRunError string `json:"lastRunError,omitempty"` ProjectID string `json:"projectId,omitempty"` + Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1 Tasks []*BatchTask `json:"tasks"` Status string `json:"status"` // pending, running, paused, completed, cancelled CreatedAt time.Time `json:"createdAt"` @@ -80,8 +97,9 @@ type BatchTaskManager struct { db *database.DB logger *zap.Logger queues map[string]*BatchTaskQueue - taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 + taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数 singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列 + queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦) mu sync.RWMutex } @@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager { return &BatchTaskManager{ logger: logger, queues: make(map[string]*BatchTaskQueue), - taskCancels: make(map[string]context.CancelFunc), + taskCancels: make(map[string]map[string]context.CancelFunc), singleRunTasks: make(map[string]string), + queueExecutors: make(map[string]struct{}), } } +// batchQueueExecutionShouldStop 判断 executeBatchQueue 主循环是否应退出。 +func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool { + if !exists || queue == nil { + return true + } + switch queue.Status { + case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused: + return true + default: + return false + } +} + +// TryMarkQueueExecutor 标记队列执行协程已启动;若已有执行协程则返回 false。 +func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool { + m.mu.Lock() + defer m.mu.Unlock() + if _, exists := m.queueExecutors[queueID]; exists { + return false + } + m.queueExecutors[queueID] = struct{}{} + return true +} + +// UnmarkQueueExecutor 清除队列执行协程标记(executeBatchQueue defer 调用)。 +func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.queueExecutors, queueID) +} + +// ForceUnmarkQueueExecutor 强制清除执行协程标记(暂停态单条重跑等场景回收陈旧槽位)。 +func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) { + m.UnmarkQueueExecutor(queueID) +} + +// IsQueueExecutorActive 队列 executeBatchQueue 协程是否仍在运行。 +func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + _, ok := m.queueExecutors[queueID] + return ok +} + // SetDB 设置数据库连接 func (m *BatchTaskManager) SetDB(db *database.DB) { m.mu.Lock() @@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) { m.db = db } +// normalizeBatchQueueConcurrency 规范化队列并发数。 +func normalizeBatchQueueConcurrency(n int) int { + if n < 1 { + return DefaultBatchQueueConcurrency + } + if n > MaxBatchQueueConcurrency { + return MaxBatchQueueConcurrency + } + return n +} + // CreateBatchQueue 创建批量任务队列 func (m *BatchTaskManager) CreateBatchQueue( title, role, agentMode, scheduleMode, cronExpr, projectID string, nextRunAt *time.Time, + concurrency int, tasks []string, ) (*BatchTaskQueue, error) { // 输入校验 @@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue( CronExpr: strings.TrimSpace(cronExpr), NextRunAt: nextRunAt, ScheduleEnabled: true, + Concurrency: normalizeBatchQueueConcurrency(concurrency), Tasks: make([]*BatchTask, 0, len(tasks)), Status: BatchQueueStatusPending, CreatedAt: time.Now(), @@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue( queue.CronExpr, queue.NextRunAt, queue.ProjectID, + queue.Concurrency, dbTasks, ); err != nil { m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) @@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue { if queueRow.ProjectID.Valid { queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) } + queue.Concurrency = batchQueueConcurrencyFromRow(queueRow) if queueRow.StartedAt.Valid { queue.StartedAt = &queueRow.StartedAt.Time } @@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error { if queueRow.ProjectID.Valid { queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) } + queue.Concurrency = batchQueueConcurrencyFromRow(queueRow) if queueRow.StartedAt.Valid { queue.StartedAt = &queueRow.StartedAt.Time } @@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s } } -// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用) -func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error { +// batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。 +func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int { + if row == nil || !row.Concurrency.Valid { + return DefaultBatchQueueConcurrency + } + return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64)) +} + +// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用) +func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error { if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) } @@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s queue.Title = title queue.Role = role queue.AgentMode = agentMode + if concurrency != nil { + queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency) + } if m.db != nil { - if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil { + if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil { m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err)) } } @@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, // PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引 func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error { - var cancelFunc context.CancelFunc var siblingRunningIDs []string m.mu.Lock() @@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error { } // 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项 + var cancelFuncs []context.CancelFunc if queue.Status == BatchQueueStatusPaused { - if c, ok := m.taskCancels[queueID]; ok { - cancelFunc = c - delete(m.taskCancels, queueID) - } + cancelFuncs = m.drainTaskCancelsLocked(queueID) for _, t := range queue.Tasks { if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning { siblingRunningIDs = append(siblingRunningIDs, t.ID) @@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error { resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled m.mu.Unlock() - if cancelFunc != nil { - cancelFunc() + for _, c := range cancelFuncs { + if c != nil { + c() + } } const staleRunMsg = "为单条执行其它任务,已中止" for _, sid := range siblingRunningIDs { @@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool } } -// GetNextTask 获取下一个待执行的任务 +// ClaimNextPendingTask 原子领取下一个待执行子任务(并发 worker 安全)。 +func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + queue, exists := m.queues[queueID] + if !exists || queue == nil { + return nil, false + } + if queue.Status == BatchQueueStatusCancelled || queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusPaused { + return nil, false + } + + onlyTaskID := "" + if m.singleRunTasks != nil { + onlyTaskID = m.singleRunTasks[queueID] + } + + for i, task := range queue.Tasks { + if task == nil || task.Status != BatchTaskStatusPending { + continue + } + if onlyTaskID != "" && task.ID != onlyTaskID { + continue + } + task.Status = BatchTaskStatusRunning + queue.CurrentIndex = i + return task, true + } + return nil, false +} + +// HasRunningTasks 队列是否仍有 running 状态的子任务。 +func (m *BatchTaskManager) HasRunningTasks(queueID string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + queue, exists := m.queues[queueID] + if !exists || queue == nil { + return false + } + for _, task := range queue.Tasks { + if task != nil && task.Status == BatchTaskStatusRunning { + return true + } + } + return false +} + +// HasPendingOrRunningTasks 队列是否仍有未完成的子任务。 +func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + queue, exists := m.queues[queueID] + if !exists || queue == nil { + return false + } + for _, task := range queue.Tasks { + if task == nil { + continue + } + if task.Status == BatchTaskStatusPending || task.Status == BatchTaskStatusRunning { + return true + } + } + return false +} + +// drainTaskCancelsLocked 取出并清空队列下所有子任务取消函数(调用方须已持 m.mu)。 +func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc { + taskMap, ok := m.taskCancels[queueID] + if !ok || len(taskMap) == 0 { + return nil + } + cancels := make([]context.CancelFunc, 0, len(taskMap)) + for _, c := range taskMap { + if c != nil { + cancels = append(cancels, c) + } + } + delete(m.taskCancels, queueID) + return cancels +} + +// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask) func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { m.mu.Lock() defer m.mu.Unlock() @@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) { } } -// SetTaskCancel 设置当前任务的取消函数 -func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) { +// SetTaskCancel 设置子任务的取消函数 +func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) { m.mu.Lock() defer m.mu.Unlock() - if cancel != nil { - m.taskCancels[queueID] = cancel - } else { - delete(m.taskCancels, queueID) + if cancel == nil { + if taskMap, ok := m.taskCancels[queueID]; ok { + delete(taskMap, taskID) + if len(taskMap) == 0 { + delete(m.taskCancels, queueID) + } + } + return } + if m.taskCancels[queueID] == nil { + m.taskCancels[queueID] = make(map[string]context.CancelFunc) + } + m.taskCancels[queueID][taskID] = cancel } // PauseQueue 暂停队列 func (m *BatchTaskManager) PauseQueue(queueID string) bool { - var cancelFunc context.CancelFunc + var cancelFuncs []context.CancelFunc m.mu.Lock() queue, exists := m.queues[queueID] @@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { } queue.Status = BatchQueueStatusPaused - - // 取消当前正在执行的任务(通过取消context) - if cancel, ok := m.taskCancels[queueID]; ok { - cancelFunc = cancel - delete(m.taskCancels, queueID) - } + cancelFuncs = m.drainTaskCancelsLocked(queueID) m.mu.Unlock() - // 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) - if cancelFunc != nil { - cancelFunc() + for _, c := range cancelFuncs { + c() } return true @@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { // CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue) func (m *BatchTaskManager) CancelQueue(queueID string) bool { now := time.Now() - var cancelFunc context.CancelFunc + var cancelFuncs []context.CancelFunc m.mu.Lock() queue, exists := m.queues[queueID] @@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool { } } - // 取消当前正在执行的任务 - if cancel, ok := m.taskCancels[queueID]; ok { - cancelFunc = cancel - delete(m.taskCancels, queueID) - } + cancelFuncs = m.drainTaskCancelsLocked(queueID) m.mu.Unlock() - // 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) - if cancelFunc != nil { - cancelFunc() + for _, c := range cancelFuncs { + c() } return true } -// DeleteQueue 删除队列(运行中的队列不允许删除) -func (m *BatchTaskManager) DeleteQueue(queueID string) bool { +// DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。 +func (m *BatchTaskManager) DeleteQueue(queueID string) error { m.mu.Lock() defer m.mu.Unlock() queue, exists := m.queues[queueID] if !exists { - return false + return ErrBatchQueueNotFound + } + + if _, exec := m.queueExecutors[queueID]; exec { + return ErrBatchQueueExecutorActive } // 运行中的队列不允许删除,防止孤儿协程和数据丢失 if queue.Status == BatchQueueStatusRunning { - return false + return ErrBatchQueueStillRunning } // 清理取消函数 @@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool { } delete(m.queues, queueID) - return true + return nil } // generateShortID 生成短ID diff --git a/internal/handler/batch_task_manager_test.go b/internal/handler/batch_task_manager_test.go new file mode 100644 index 00000000..998d18f1 --- /dev/null +++ b/internal/handler/batch_task_manager_test.go @@ -0,0 +1,121 @@ +package handler + +import ( + "errors" + "testing" + + "go.uber.org/zap" +) + +func TestNormalizeBatchQueueConcurrency(t *testing.T) { + if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency { + t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got) + } + if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency { + t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got) + } +} + +func TestClaimNextPendingTaskParallel(t *testing.T) { + m := NewBatchTaskManager(zap.NewNop()) + queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"}) + if err != nil { + t.Fatalf("CreateBatchQueue: %v", err) + } + m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning) + + t1, ok1 := m.ClaimNextPendingTask(queue.ID) + t2, ok2 := m.ClaimNextPendingTask(queue.ID) + if !ok1 || !ok2 || t1.ID == t2.ID { + t.Fatalf("expected two distinct claims, got ok1=%v ok2=%v t1=%v t2=%v", ok1, ok2, t1, t2) + } + if t1.Status != BatchTaskStatusRunning || t2.Status != BatchTaskStatusRunning { + t.Fatalf("claimed tasks should be running") + } + t3, ok3 := m.ClaimNextPendingTask(queue.ID) + if !ok3 { + t.Fatal("expected third claim") + } + _, ok4 := m.ClaimNextPendingTask(queue.ID) + if ok4 { + t.Fatal("expected no fourth pending task") + } + _ = t3 +} + +func TestBatchQueueExecutionShouldStop(t *testing.T) { + t.Parallel() + if !batchQueueExecutionShouldStop(nil, false) { + t.Fatal("expected stop when queue missing") + } + if !batchQueueExecutionShouldStop(nil, true) { + t.Fatal("expected stop when queue is nil but exists=true") + } + q := &BatchTaskQueue{Status: BatchQueueStatusRunning} + if batchQueueExecutionShouldStop(q, true) { + t.Fatal("expected continue when running") + } + q.Status = BatchQueueStatusCancelled + if !batchQueueExecutionShouldStop(q, true) { + t.Fatal("expected stop when cancelled") + } +} + +func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) { + t.Parallel() + m := NewBatchTaskManager(zap.NewNop()) + queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"}) + if err != nil { + t.Fatalf("CreateBatchQueue: %v", err) + } + if !m.TryMarkQueueExecutor(queue.ID) { + t.Fatal("expected to mark executor") + } + m.UpdateQueueStatus(queue.ID, BatchQueueStatusCancelled) + + err = m.DeleteQueue(queue.ID) + if !errors.Is(err, ErrBatchQueueExecutorActive) { + t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", err) + } + if _, ok := m.GetBatchQueue(queue.ID); !ok { + t.Fatal("queue should still exist while executor active") + } + + m.UnmarkQueueExecutor(queue.ID) + if err := m.DeleteQueue(queue.ID); err != nil { + t.Fatalf("expected delete after executor unmarked, got %v", err) + } + if _, ok := m.GetBatchQueue(queue.ID); ok { + t.Fatal("queue should be deleted") + } +} + +func TestDeleteQueueBlockedWhileRunning(t *testing.T) { + t.Parallel() + m := NewBatchTaskManager(zap.NewNop()) + queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"}) + if err != nil { + t.Fatalf("CreateBatchQueue: %v", err) + } + m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning) + + err = m.DeleteQueue(queue.ID) + if !errors.Is(err, ErrBatchQueueStillRunning) { + t.Fatalf("expected ErrBatchQueueStillRunning, got %v", err) + } +} + +func TestTryMarkQueueExecutorDedupes(t *testing.T) { + t.Parallel() + m := NewBatchTaskManager(zap.NewNop()) + if !m.TryMarkQueueExecutor("q-1") { + t.Fatal("first mark should succeed") + } + if m.TryMarkQueueExecutor("q-1") { + t.Fatal("second mark should fail") + } + m.UnmarkQueueExecutor("q-1") + if !m.TryMarkQueueExecutor("q-1") { + t.Fatal("mark after unmark should succeed") + } +} diff --git a/internal/handler/batch_task_mcp.go b/internal/handler/batch_task_mcp.go index bba9ece1..e024903b 100644 --- a/internal/handler/batch_task_mcp.go +++ b/internal/handler/batch_task_mcp.go @@ -3,6 +3,7 @@ package handler import ( "context" "encoding/json" + "errors" "fmt" "strconv" "strings" @@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z "type": "string", "description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)", }, + "concurrency": map[string]interface{}{ + "type": "integer", + "description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。", + }, }, }, }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { @@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z executeNow = false } projectID := strings.TrimSpace(mcpArgString(args, "project_id")) - queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks) + concurrency := int(mcpArgFloat(args, "concurrency")) + queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks) if createErr != nil { return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil } @@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z if qid == "" { return batchMCPTextResult("queue_id 不能为空", true), nil } - if !h.batchTaskManager.DeleteQueue(qid) { - return batchMCPTextResult("删除失败:队列不存在", true), nil + if err := h.batchTaskManager.DeleteQueue(qid); err != nil { + switch { + case errors.Is(err, ErrBatchQueueNotFound): + return batchMCPTextResult("删除失败:队列不存在", true), nil + case errors.Is(err, ErrBatchQueueExecutorActive): + return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil + case errors.Is(err, ErrBatchQueueStillRunning): + return batchMCPTextResult("删除失败:队列正在运行中", true), nil + default: + return batchMCPTextResult("删除失败:"+err.Error(), true), nil + } } logger.Info("MCP batch_task_delete", zap.String("queueId", qid)) return batchMCPTextResult("队列已删除。", false), nil @@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z "description": "代理模式:eino_single、deep、plan_execute、supervisor", "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}, }, + "concurrency": map[string]interface{}{ + "type": "integer", + "description": "同时执行的子任务数,默认 1,最大 8", + }, }, "required": []string{"queue_id"}, }, @@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z title := mcpArgString(args, "title") role := mcpArgString(args, "role") agentMode := mcpArgString(args, "agent_mode") - if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil { + var concurrency *int + if raw, ok := args["concurrency"]; ok && raw != nil { + v := int(mcpArgFloat(args, "concurrency")) + concurrency = &v + } + if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil { return batchMCPTextResult(err.Error(), true), nil } updated, _ := h.batchTaskManager.GetBatchQueue(qid) @@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct { StartedAt *time.Time `json:"startedAt,omitempty"` CompletedAt *time.Time `json:"completedAt,omitempty"` CurrentIndex int `json:"currentIndex"` + Concurrency int `json:"concurrency"` TaskTotal int `json:"task_total"` TaskCounts map[string]int `json:"task_counts"` Tasks []batchTaskMCPListSummary `json:"tasks"` @@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem { StartedAt: q.StartedAt, CompletedAt: q.CompletedAt, CurrentIndex: q.CurrentIndex, + Concurrency: q.Concurrency, TaskTotal: len(tasks), TaskCounts: counts, Tasks: tasks,