Add files via upload

This commit is contained in:
公明
2026-04-17 11:26:32 +08:00
committed by GitHub
parent 6fb96dcc0c
commit 15c7692988
6 changed files with 303 additions and 38 deletions
+99 -27
View File
@@ -11,6 +11,32 @@ import (
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// 批量任务状态常量
const (
BatchQueueStatusPending = "pending"
BatchQueueStatusRunning = "running"
BatchQueueStatusPaused = "paused"
BatchQueueStatusCompleted = "completed"
BatchQueueStatusCancelled = "cancelled"
BatchTaskStatusPending = "pending"
BatchTaskStatusRunning = "running"
BatchTaskStatusCompleted = "completed"
BatchTaskStatusFailed = "failed"
BatchTaskStatusCancelled = "cancelled"
// MaxBatchTasksPerQueue 单个队列最大任务数
MaxBatchTasksPerQueue = 10000
// MaxBatchQueueTitleLen 队列标题最大长度
MaxBatchQueueTitleLen = 200
// MaxBatchQueueRoleLen 角色名最大长度
MaxBatchQueueRoleLen = 100
)
// BatchTask 批量任务项
@@ -50,14 +76,19 @@ type BatchTaskQueue struct {
// BatchTaskManager 批量任务管理器
type BatchTaskManager struct {
db *database.DB
logger *zap.Logger
queues map[string]*BatchTaskQueue
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
mu sync.RWMutex
}
// NewBatchTaskManager 创建批量任务管理器
func NewBatchTaskManager() *BatchTaskManager {
func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
if logger == nil {
logger = zap.NewNop()
}
return &BatchTaskManager{
logger: logger,
queues: make(map[string]*BatchTaskQueue),
taskCancels: make(map[string]context.CancelFunc),
}
@@ -75,7 +106,18 @@ func (m *BatchTaskManager) CreateBatchQueue(
title, role, agentMode, scheduleMode, cronExpr string,
nextRunAt *time.Time,
tasks []string,
) *BatchTaskQueue {
) (*BatchTaskQueue, error) {
// 输入校验
if len(title) > MaxBatchQueueTitleLen {
return nil, fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
}
if len(role) > MaxBatchQueueRoleLen {
return nil, fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen)
}
if len(tasks) > MaxBatchTasksPerQueue {
return nil, fmt.Errorf("单个队列最多 %d 条任务", MaxBatchTasksPerQueue)
}
m.mu.Lock()
defer m.mu.Unlock()
@@ -131,13 +173,12 @@ func (m *BatchTaskManager) CreateBatchQueue(
queue.NextRunAt,
dbTasks,
); err != nil {
// 如果数据库保存失败,记录错误但继续(使用内存缓存)
// 这里可以添加日志记录
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
}
}
m.queues[queueID] = queue
return queue
return queue, nil
}
// GetBatchQueue 获取批量任务队列
@@ -525,7 +566,7 @@ func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, s
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
// 记录错误但继续(使用内存缓存)
m.logger.Warn("batch task DB status update failed", zap.String("queueId", queueID), zap.String("taskId", taskID), zap.Error(err))
}
}
}
@@ -552,7 +593,7 @@ func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
// 记录错误但继续(使用内存缓存)
m.logger.Warn("batch queue DB status update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
}
@@ -578,11 +619,42 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
if m.db != nil {
if err := m.db.UpdateBatchQueueSchedule(queueID, queue.ScheduleMode, queue.CronExpr, queue.NextRunAt); err != nil {
// 记录错误但继续(使用内存缓存)
m.logger.Warn("batch queue DB schedule update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
}
// UpdateQueueMetadata 更新队列标题和角色(非 running 时可用)
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role string) error {
if len(title) > MaxBatchQueueTitleLen {
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
}
if len(role) > MaxBatchQueueRoleLen {
return fmt.Errorf("角色名不能超过 %d 个字符", MaxBatchQueueRoleLen)
}
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return fmt.Errorf("队列不存在")
}
if queue.Status == "running" {
return fmt.Errorf("队列正在运行中,无法修改")
}
queue.Title = title
queue.Role = role
if m.db != nil {
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role); err != nil {
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
return nil
}
// SetScheduleEnabled 暂停/恢复 Cron 自动调度(不影响手工执行)
func (m *BatchTaskManager) SetScheduleEnabled(queueID string, enabled bool) bool {
m.mu.Lock()
@@ -661,6 +733,8 @@ func (m *BatchTaskManager) ResetQueueForRerun(queueID string) bool {
queue.StartedAt = nil
queue.CompletedAt = nil
queue.NextRunAt = nil
queue.LastRunError = ""
queue.LastScheduleError = ""
for _, task := range queue.Tasks {
task.Status = "pending"
task.ConversationID = ""
@@ -832,8 +906,8 @@ func queueAllowsTaskListMutationLocked(queue *BatchTaskQueue) bool {
// GetNextTask 获取下一个待执行的任务
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
@@ -866,7 +940,7 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueCurrentIndex(queueID, queue.CurrentIndex); err != nil {
// 记录错误但继续(使用内存缓存)
m.logger.Warn("batch queue DB index update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
}
@@ -885,15 +959,14 @@ func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFu
// PauseQueue 暂停队列
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
m.mu.Unlock()
return false
}
if queue.Status != "running" {
m.mu.Unlock()
return false
}
@@ -905,12 +978,10 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
delete(m.taskCancels, queueID)
}
m.mu.Unlock()
// 同步队列状态到数据库
// 同步队列状态到数据库(在锁内完成,避免竞态)
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, "paused"); err != nil {
// 记录错误但继续(使用内存缓存)
m.logger.Warn("batch queue DB pause update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
@@ -920,15 +991,14 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
m.mu.Unlock()
return false
}
if queue.Status == "completed" || queue.Status == "cancelled" {
m.mu.Unlock()
return false
}
@@ -941,7 +1011,6 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
if task.Status == "pending" {
task.Status = "cancelled"
task.CompletedAt = &now
// 同步到数据库
if m.db != nil {
m.db.UpdateBatchTaskStatus(queueID, task.ID, "cancelled", "", "", "")
}
@@ -954,35 +1023,38 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
delete(m.taskCancels, queueID)
}
m.mu.Unlock()
// 同步队列状态到数据库
// 同步队列状态到数据库(在锁内完成)
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, "cancelled"); err != nil {
// 记录错误但继续(使用内存缓存)
m.logger.Warn("batch queue DB cancel update failed", zap.String("queueId", queueID), zap.Error(err))
}
}
return true
}
// DeleteQueue 删除队列
// DeleteQueue 删除队列(运行中的队列不允许删除)
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
_, exists := m.queues[queueID]
queue, exists := m.queues[queueID]
if !exists {
return false
}
// 运行中的队列不允许删除,防止孤儿协程和数据丢失
if queue.Status == "running" {
return false
}
// 清理取消函数
delete(m.taskCancels, queueID)
// 从数据库删除
if m.db != nil {
if err := m.db.DeleteBatchQueue(queueID); err != nil {
// 记录错误但继续(使用内存缓存)
m.logger.Warn("batch queue DB delete failed", zap.String("queueId", queueID), zap.Error(err))
}
}