From fb0724a862deb23e67186a1885709c2f1c7fb845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 17 Apr 2026 11:53:20 +0800 Subject: [PATCH] Add files via upload --- internal/handler/batch_task_manager.go | 63 ++++++++++++++------------ internal/handler/batch_task_mcp.go | 3 ++ 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/internal/handler/batch_task_manager.go b/internal/handler/batch_task_manager.go index dc9139e8..a8855f4b 100644 --- a/internal/handler/batch_task_manager.go +++ b/internal/handler/batch_task_manager.go @@ -9,6 +9,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" "cyberstrike-ai/internal/database" @@ -108,10 +109,10 @@ func (m *BatchTaskManager) CreateBatchQueue( tasks []string, ) (*BatchTaskQueue, error) { // 输入校验 - if len(title) > MaxBatchQueueTitleLen { + if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { return nil, fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) } - if len(role) > MaxBatchQueueRoleLen { + if utf8.RuneCountInString(role) > MaxBatchQueueRoleLen { return nil, fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen) } if len(tasks) > MaxBatchTasksPerQueue { @@ -132,7 +133,7 @@ func (m *BatchTaskManager) CreateBatchQueue( NextRunAt: nextRunAt, ScheduleEnabled: true, Tasks: make([]*BatchTask, 0, len(tasks)), - Status: "pending", + Status: BatchQueueStatusPending, CreatedAt: time.Now(), CurrentIndex: 0, } @@ -152,7 +153,7 @@ func (m *BatchTaskManager) CreateBatchQueue( task := &BatchTask{ ID: taskID, Message: message, - Status: "pending", + Status: BatchTaskStatusPending, } queue.Tasks = append(queue.Tasks, task) dbTasks = append(dbTasks, map[string]interface{}{ @@ -553,10 +554,10 @@ func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, s task.ConversationID = conversationID } now := time.Now() - if status == "running" && task.StartedAt == nil { + if status == BatchTaskStatusRunning && task.StartedAt == nil { task.StartedAt = &now } - if status == "completed" || status == "failed" || status == "cancelled" { + if status == BatchTaskStatusCompleted || status == BatchTaskStatusFailed || status == BatchTaskStatusCancelled { task.CompletedAt = &now } break @@ -583,10 +584,10 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) { queue.Status = status now := time.Now() - if status == "running" && queue.StartedAt == nil { + if status == BatchQueueStatusRunning && queue.StartedAt == nil { queue.StartedAt = &now } - if status == "completed" || status == "cancelled" { + if status == BatchQueueStatusCompleted || status == BatchQueueStatusCancelled { queue.CompletedAt = &now } @@ -626,10 +627,10 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s // UpdateQueueMetadata 更新队列标题和角色(非 running 时可用) func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role string) error { - if len(title) > MaxBatchQueueTitleLen { + if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) } - if len(role) > MaxBatchQueueRoleLen { + if utf8.RuneCountInString(role) > MaxBatchQueueRoleLen { return fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen) } @@ -640,7 +641,7 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role string) erro if !exists { return fmt.Errorf("队列不存在") } - if queue.Status == "running" { + if queue.Status == BatchQueueStatusRunning { return fmt.Errorf("队列正在运行中,无法修改") } @@ -728,7 +729,7 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool { if !exists { return false } - queue.Status = "pending" + queue.Status = BatchQueueStatusPending queue.CurrentIndex = 0 queue.StartedAt = nil queue.CompletedAt = nil @@ -736,7 +737,7 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool { queue.LastRunError = "" queue.LastScheduleError = "" for _, task := range queue.Tasks { - task.Status = "pending" + task.Status = BatchTaskStatusPending task.ConversationID = "" task.StartedAt = nil task.CompletedAt = nil @@ -769,7 +770,7 @@ func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) er // 查找并更新任务 for _, task := range queue.Tasks { if task.ID == taskID { - if task.Status == "running" { + if task.Status == BatchTaskStatusRunning { return fmt.Errorf("执行中的任务不能编辑") } task.Message = message @@ -810,7 +811,7 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, task := &BatchTask{ ID: taskID, Message: message, - Status: "pending", + Status: BatchTaskStatusPending, } // 添加到内存队列 @@ -846,7 +847,7 @@ func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error { taskIndex := -1 for i, task := range queue.Tasks { if task.ID == taskID { - if task.Status == "running" { + if task.Status == BatchTaskStatusRunning { return fmt.Errorf("执行中的任务不能删除") } taskIndex = i @@ -878,7 +879,7 @@ func queueHasRunningTaskLocked(queue *BatchTaskQueue) bool { return false } for _, t := range queue.Tasks { - if t != nil && t.Status == "running" { + if t != nil && t.Status == BatchTaskStatusRunning { return true } } @@ -890,14 +891,14 @@ func queueAllowsTaskListMutationLocked(queue *BatchTaskQueue) bool { if queue == nil { return false } - if queue.Status == "running" { + if queue.Status == BatchQueueStatusRunning { return false } if queueHasRunningTaskLocked(queue) { return false } switch queue.Status { - case "pending", "paused", "completed", "cancelled": + case BatchQueueStatusPending, BatchQueueStatusPaused, BatchQueueStatusCompleted, BatchQueueStatusCancelled: return true default: return false @@ -916,7 +917,7 @@ func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { for i := queue.CurrentIndex; i < len(queue.Tasks); i++ { task := queue.Tasks[i] - if task.Status == "pending" { + if task.Status == BatchTaskStatusPending { queue.CurrentIndex = i return task, true } @@ -966,11 +967,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { return false } - if queue.Status != "running" { + if queue.Status != BatchQueueStatusRunning { return false } - queue.Status = "paused" + queue.Status = BatchQueueStatusPaused // 取消当前正在执行的任务(通过取消context) if cancel, exists := m.taskCancels[queueID]; exists { @@ -980,7 +981,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool { // 同步队列状态到数据库(在锁内完成,避免竞态) if m.db != nil { - if err := m.db.UpdateBatchQueueStatus(queueID, "paused"); err != nil { + if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusPaused); err != nil { m.logger.Warn("batch queue DB pause update failed", zap.String("queueId", queueID), zap.Error(err)) } } @@ -998,21 +999,23 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool { return false } - if queue.Status == "completed" || queue.Status == "cancelled" { + if queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled { return false } - queue.Status = "cancelled" + queue.Status = BatchQueueStatusCancelled now := time.Now() queue.CompletedAt = &now // 取消所有待执行的任务 for _, task := range queue.Tasks { - if task.Status == "pending" { - task.Status = "cancelled" + if task.Status == BatchTaskStatusPending { + task.Status = BatchTaskStatusCancelled task.CompletedAt = &now if m.db != nil { - m.db.UpdateBatchTaskStatus(queueID, task.ID, "cancelled", "", "", "") + if err := m.db.UpdateBatchTaskStatus(queueID, task.ID, BatchTaskStatusCancelled, "", "", ""); err != nil { + m.logger.Warn("batch task DB cancel update failed", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) + } } } } @@ -1025,7 +1028,7 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool { // 同步队列状态到数据库(在锁内完成) if m.db != nil { - if err := m.db.UpdateBatchQueueStatus(queueID, "cancelled"); err != nil { + if err := m.db.UpdateBatchQueueStatus(queueID, BatchQueueStatusCancelled); err != nil { m.logger.Warn("batch queue DB cancel update failed", zap.String("queueId", queueID), zap.Error(err)) } } @@ -1044,7 +1047,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool { } // 运行中的队列不允许删除,防止孤儿协程和数据丢失 - if queue.Status == "running" { + if queue.Status == BatchQueueStatusRunning { return false } diff --git a/internal/handler/batch_task_mcp.go b/internal/handler/batch_task_mcp.go index e2b23595..6d093fb1 100644 --- a/internal/handler/batch_task_mcp.go +++ b/internal/handler/batch_task_mcp.go @@ -69,6 +69,9 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z pageSize = 100 } offset := (page - 1) * pageSize + if offset > 100000 { + offset = 100000 + } queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword) if err != nil { return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil