From 3127781102829c04a6cfe02128a7ab8deaf7888c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:47:43 +0800 Subject: [PATCH] Add files via upload --- internal/database/batch_task.go | 12 ++++ internal/handler/agent.go | 2 +- internal/handler/batch_task_manager.go | 98 ++++++++++++++++++-------- 3 files changed, 83 insertions(+), 29 deletions(-) diff --git a/internal/database/batch_task.go b/internal/database/batch_task.go index 93d6ef97..2a331617 100644 --- a/internal/database/batch_task.go +++ b/internal/database/batch_task.go @@ -489,6 +489,18 @@ func (db *DB) AddBatchTask(queueID, taskID, message string) error { return nil } +// CancelPendingBatchTasks 批量取消队列中所有 pending 状态的任务(单条 SQL) +func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) error { + _, err := db.Exec( + "UPDATE batch_tasks SET status = ?, completed_at = ? WHERE queue_id = ? AND status = ?", + "cancelled", completedAt, queueID, "pending", + ) + if err != nil { + return fmt.Errorf("批量取消 pending 任务失败: %w", err) + } + return nil +} + // DeleteBatchTask 删除批量任务 func (db *DB) DeleteBatchTask(queueID, taskID string) error { _, err := db.Exec( diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 9b7c4b0e..9885b3cd 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -2062,7 +2062,7 @@ func (h *AgentHandler) batchQueueSchedulerLoop() { ticker := time.NewTicker(20 * time.Second) defer ticker.Stop() for range ticker.C { - queues := h.batchTaskManager.GetAllQueues() + queues := h.batchTaskManager.GetLoadedQueues() now := time.Now() for _, queue := range queues { if queue == nil || queue.ScheduleMode != "cron" || !queue.ScheduleEnabled || queue.Status == "cancelled" || queue.Status == "running" || queue.Status == "paused" { diff --git a/internal/handler/batch_task_manager.go b/internal/handler/batch_task_manager.go index a8855f4b..f6353dc7 100644 --- a/internal/handler/batch_task_manager.go +++ b/internal/handler/batch_task_manager.go @@ -71,7 +71,6 @@ type BatchTaskQueue struct { StartedAt *time.Time `json:"startedAt,omitempty"` CompletedAt *time.Time `json:"completedAt,omitempty"` CurrentIndex int `json:"currentIndex"` - mu sync.RWMutex } // BatchTaskManager 批量任务管理器 @@ -298,6 +297,17 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue { return queue } +// GetLoadedQueues 获取内存中已加载的队列(不触发 DB 加载,仅用 RLock) +func (m *BatchTaskManager) GetLoadedQueues() []*BatchTaskQueue { + m.mu.RLock() + result := make([]*BatchTaskQueue, 0, len(m.queues)) + for _, queue := range m.queues { + result = append(result, queue) + } + m.mu.RUnlock() + return result +} + // GetAllQueues 获取所有队列 func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue { m.mu.RLock() @@ -533,11 +543,13 @@ func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, resu // UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId) func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) { - m.mu.Lock() - defer m.mu.Unlock() + var needDBUpdate bool + // 在锁内只更新内存状态 + m.mu.Lock() queue, exists := m.queues[queueID] if !exists { + m.mu.Unlock() return } @@ -564,8 +576,11 @@ func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, s } } - // 同步到数据库 - if m.db != nil { + needDBUpdate = m.db != nil + m.mu.Unlock() + + // 释放锁后写 DB + if needDBUpdate { if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil { m.logger.Warn("batch task DB status update failed", zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err)) } @@ -574,11 +589,13 @@ func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, s // UpdateQueueStatus 更新队列状态 func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { - m.mu.Lock() - defer m.mu.Unlock() + var needDBUpdate bool + // 在锁内只更新内存状态 + m.mu.Lock() queue, exists := m.queues[queueID] if !exists { + m.mu.Unlock() return } @@ -591,8 +608,11 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { queue.CompletedAt = &now } - // 同步到数据库 - if m.db != nil { + needDBUpdate = m.db != nil + m.mu.Unlock() + + // 释放锁后写 DB + if needDBUpdate { if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil { m.logger.Warn("batch queue DB status update failed", zap.String("queueId", queueID), zap.Error(err)) } @@ -959,28 +979,40 @@ func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFu // PauseQueue 暂停队列 func (m *BatchTaskManager) PauseQueue(queueID string) bool { - m.mu.Lock() - defer m.mu.Unlock() + var cancelFunc context.CancelFunc + var needDBUpdate bool + // 在锁内只更新内存状态 + m.mu.Lock() queue, exists := m.queues[queueID] if !exists { + m.mu.Unlock() return false } if queue.Status != BatchQueueStatusRunning { + m.mu.Unlock() return false } queue.Status = BatchQueueStatusPaused // 取消当前正在执行的任务(通过取消context) - if cancel, exists := m.taskCancels[queueID]; exists { - cancel() + if cancel, ok := m.taskCancels[queueID]; ok { + cancelFunc = cancel delete(m.taskCancels, queueID) } - // 同步队列状态到数据库(在锁内完成,避免竞态) - if m.db != nil { + needDBUpdate = m.db != nil + m.mu.Unlock() + + // 释放锁后执行取消回调 + if cancelFunc != nil { + cancelFunc() + } + + // 释放锁后写 DB + if needDBUpdate { if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil { m.logger.Warn("batch queue DB pause update failed", zap.String("queueId", queueID), zap.Error(err)) } @@ -991,43 +1023,53 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { // CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue) func (m *BatchTaskManager) CancelQueue(queueID string) bool { - m.mu.Lock() - defer m.mu.Unlock() + now := time.Now() + var cancelFunc context.CancelFunc + var needDBUpdate bool + // 在锁内只更新内存状态,不做 DB 操作 + m.mu.Lock() queue, exists := m.queues[queueID] if !exists { + m.mu.Unlock() return false } if queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled { + m.mu.Unlock() return false } queue.Status = BatchQueueStatusCancelled - now := time.Now() queue.CompletedAt = &now - // 取消所有待执行的任务 + // 内存中批量标记所有 pending 任务为 cancelled for _, task := range queue.Tasks { if task.Status == BatchTaskStatusPending { task.Status = BatchTaskStatusCancelled task.CompletedAt = &now - if m.db != nil { - if err := m.db.UpdateBatchTaskStatus(queueID, task.ID, BatchTaskStatusCancelled, "", "", ""); err != nil { - m.logger.Warn("batch task DB cancel update failed", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) - } - } } } // 取消当前正在执行的任务 - if cancel, exists := m.taskCancels[queueID]; exists { - cancel() + if cancel, ok := m.taskCancels[queueID]; ok { + cancelFunc = cancel delete(m.taskCancels, queueID) } - // 同步队列状态到数据库(在锁内完成) - if m.db != nil { + needDBUpdate = m.db != nil + m.mu.Unlock() + + // 释放锁后执行取消回调 + if cancelFunc != nil { + cancelFunc() + } + + // 释放锁后批量写 DB(单条 SQL 取消所有 pending 任务) + if needDBUpdate { + if err := m.db.CancelPendingBatchTasks(queueID, now); err != nil { + m.logger.Warn("batch task DB batch cancel failed", zap.String("queueId", queueID), zap.Error(err)) + } if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil { m.logger.Warn("batch queue DB cancel update failed", zap.String("queueId", queueID), zap.Error(err)) }