Add files via upload

This commit is contained in:
公明
2026-01-02 00:18:39 +08:00
committed by GitHub
parent b90a29fdd7
commit 7b9dee7268
13 changed files with 3761 additions and 28 deletions
+297 -10
View File
@@ -7,8 +7,8 @@ import (
"fmt"
"net/http"
"strings"
"unicode/utf8"
"time"
"unicode/utf8"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/database"
@@ -25,16 +25,16 @@ func safeTruncateString(s string, maxLen int) string {
if utf8.RuneCountInString(s) <= maxLen {
return s
}
// 将字符串转换为 rune 切片以正确计算字符数
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
// 截断到最大长度
truncated := string(runes[:maxLen])
// 尝试在标点符号或空格处截断,使截断更自然
// 在截断点往前查找合适的断点(不超过20%的长度)
searchRange := maxLen / 5
@@ -43,7 +43,7 @@ func safeTruncateString(s string, maxLen int) string {
}
breakChars := []rune(",。、 ,.;:!?!?/\\-_")
bestBreakPos := len(runes[:maxLen])
for i := bestBreakPos - 1; i >= bestBreakPos-searchRange && i >= 0; i-- {
for _, breakChar := range breakChars {
if runes[i] == breakChar {
@@ -52,7 +52,7 @@ func safeTruncateString(s string, maxLen int) string {
}
}
}
found:
truncated = string(runes[:bestBreakPos])
return truncated + "..."
@@ -64,6 +64,7 @@ type AgentHandler struct {
db *database.DB
logger *zap.Logger
tasks *AgentTaskManager
batchTaskManager *BatchTaskManager
knowledgeManager interface { // 知识库管理器接口
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
}
@@ -71,11 +72,20 @@ type AgentHandler struct {
// NewAgentHandler 创建新的Agent处理器
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
batchTaskManager := NewBatchTaskManager()
batchTaskManager.SetDB(db)
// 从数据库加载所有批量任务队列
if err := batchTaskManager.LoadFromDB(); err != nil {
logger.Warn("从数据库加载批量任务队列失败", zap.Error(err))
}
return &AgentHandler{
agent: agent,
db: db,
logger: logger,
tasks: NewAgentTaskManager(),
agent: agent,
db: db,
logger: logger,
tasks: NewAgentTaskManager(),
batchTaskManager: batchTaskManager,
}
}
@@ -724,6 +734,283 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
})
}
// ListCompletedTasks 列出最近完成的任务历史
func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"tasks": h.tasks.GetCompletedTasks(),
})
}
// BatchTaskRequest 批量任务请求
type BatchTaskRequest struct {
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
}
// CreateBatchQueue 创建批量任务队列
func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
var req BatchTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(req.Tasks) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "任务列表不能为空"})
return
}
// 过滤空任务
validTasks := make([]string, 0, len(req.Tasks))
for _, task := range req.Tasks {
if task != "" {
validTasks = append(validTasks, task)
}
}
if len(validTasks) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "没有有效的任务"})
return
}
queue := h.batchTaskManager.CreateBatchQueue(validTasks)
c.JSON(http.StatusOK, gin.H{
"queueId": queue.ID,
"queue": queue,
})
}
// GetBatchQueue 获取批量任务队列
func (h *AgentHandler) GetBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"queue": queue})
}
// ListBatchQueues 列出所有批量任务队列
func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
queues := h.batchTaskManager.GetAllQueues()
c.JSON(http.StatusOK, gin.H{"queues": queues})
}
// StartBatchQueue 开始执行批量任务队列
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
if queue.Status != "pending" && queue.Status != "paused" {
c.JSON(http.StatusBadRequest, gin.H{"error": "队列状态不允许启动"})
return
}
// 在后台执行批量任务
go h.executeBatchQueue(queueID)
h.batchTaskManager.UpdateQueueStatus(queueID, "running")
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
}
// CancelBatchQueue 取消批量任务队列
func (h *AgentHandler) CancelBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
success := h.batchTaskManager.CancelQueue(queueID)
if !success {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法取消"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已取消"})
}
// DeleteBatchQueue 删除批量任务队列
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
queueID := c.Param("queueId")
success := h.batchTaskManager.DeleteQueue(queueID)
if !success {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
}
// UpdateBatchTask 更新批量任务消息
func (h *AgentHandler) UpdateBatchTask(c *gin.Context) {
queueID := c.Param("queueId")
taskID := c.Param("taskId")
var req struct {
Message string `json:"message" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Message == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"})
return
}
err := h.batchTaskManager.UpdateTaskMessage(queueID, taskID, req.Message)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 返回更新后的队列信息
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "任务已更新", "queue": queue})
}
// AddBatchTask 添加任务到批量任务队列
func (h *AgentHandler) AddBatchTask(c *gin.Context) {
queueID := c.Param("queueId")
var req struct {
Message string `json:"message" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Message == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"})
return
}
task, err := h.batchTaskManager.AddTaskToQueue(queueID, req.Message)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 返回更新后的队列信息
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
}
// DeleteBatchTask 删除批量任务
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
queueID := c.Param("queueId")
taskID := c.Param("taskId")
err := h.batchTaskManager.DeleteTask(queueID, taskID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 返回更新后的队列信息
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
}
// executeBatchQueue 执行批量任务队列
func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
for {
// 检查队列状态
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists || queue.Status == "cancelled" || queue.Status == "completed" {
break
}
// 获取下一个任务
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
if !hasNext {
// 所有任务完成
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
break
}
// 更新任务状态为运行中
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
// 创建新对话
title := safeTruncateString(task.Message, 50)
conv, err := h.db.CreateConversation(title)
var conversationID string
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
h.batchTaskManager.MoveToNextTask(queueID)
continue
}
conversationID = conv.ID
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
// 保存用户消息
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
// 执行任务
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
result, err := h.agent.AgentLoopWithConversationID(ctx, task.Message, []agent.ChatMessage{}, conversationID)
cancel()
if err != nil {
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", err.Error())
} else {
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
// 保存助手回复
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
// 保存ReAct数据
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
}
}
// 保存结果
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", result.Response, "", conversationID)
}
// 移动到下一个任务
h.batchTaskManager.MoveToNextTask(queueID)
// 检查是否被取消
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue.Status == "cancelled" {
break
}
}
}
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
+589
View File
@@ -0,0 +1,589 @@
package handler
import (
"crypto/rand"
"encoding/hex"
"fmt"
"sync"
"time"
"cyberstrike-ai/internal/database"
)
// BatchTask 批量任务项
type BatchTask struct {
ID string `json:"id"`
Message string `json:"message"`
ConversationID string `json:"conversationId,omitempty"`
Status string `json:"status"` // pending, running, completed, failed, cancelled
StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"`
Error string `json:"error,omitempty"`
Result string `json:"result,omitempty"`
}
// BatchTaskQueue 批量任务队列
type BatchTaskQueue struct {
ID string `json:"id"`
Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"`
StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"`
CurrentIndex int `json:"currentIndex"`
mu sync.RWMutex
}
// BatchTaskManager 批量任务管理器
type BatchTaskManager struct {
db *database.DB
queues map[string]*BatchTaskQueue
mu sync.RWMutex
}
// NewBatchTaskManager 创建批量任务管理器
func NewBatchTaskManager() *BatchTaskManager {
return &BatchTaskManager{
queues: make(map[string]*BatchTaskQueue),
}
}
// SetDB 设置数据库连接
func (m *BatchTaskManager) SetDB(db *database.DB) {
m.mu.Lock()
defer m.mu.Unlock()
m.db = db
}
// CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue(tasks []string) *BatchTaskQueue {
m.mu.Lock()
defer m.mu.Unlock()
queueID := time.Now().Format("20060102150405") + "-" + generateShortID()
queue := &BatchTaskQueue{
ID: queueID,
Tasks: make([]*BatchTask, 0, len(tasks)),
Status: "pending",
CreatedAt: time.Now(),
CurrentIndex: 0,
}
// 准备数据库保存的任务数据
dbTasks := make([]map[string]interface{}, 0, len(tasks))
for _, message := range tasks {
if message == "" {
continue // 跳过空行
}
taskID := generateShortID()
task := &BatchTask{
ID: taskID,
Message: message,
Status: "pending",
}
queue.Tasks = append(queue.Tasks, task)
dbTasks = append(dbTasks, map[string]interface{}{
"id": taskID,
"message": message,
})
}
// 保存到数据库
if m.db != nil {
if err := m.db.CreateBatchQueue(queueID, dbTasks); err != nil {
// 如果数据库保存失败,记录错误但继续(使用内存缓存)
// 这里可以添加日志记录
}
}
m.queues[queueID] = queue
return queue
}
// GetBatchQueue 获取批量任务队列
func (m *BatchTaskManager) GetBatchQueue(queueID string) (*BatchTaskQueue, bool) {
m.mu.RLock()
queue, exists := m.queues[queueID]
m.mu.RUnlock()
if exists {
return queue, true
}
// 如果内存中不存在,尝试从数据库加载
if m.db != nil {
if queue := m.loadQueueFromDB(queueID); queue != nil {
m.mu.Lock()
m.queues[queueID] = queue
m.mu.Unlock()
return queue, true
}
}
return nil, false
}
// loadQueueFromDB 从数据库加载单个队列
func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
if m.db == nil {
return nil
}
queueRow, err := m.db.GetBatchQueue(queueID)
if err != nil || queueRow == nil {
return nil
}
taskRows, err := m.db.GetBatchTasks(queueID)
if err != nil {
return nil
}
queue := &BatchTaskQueue{
ID: queueRow.ID,
Status: queueRow.Status,
CreatedAt: queueRow.CreatedAt,
CurrentIndex: queueRow.CurrentIndex,
Tasks: make([]*BatchTask, 0, len(taskRows)),
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
if queueRow.CompletedAt.Valid {
queue.CompletedAt = &queueRow.CompletedAt.Time
}
for _, taskRow := range taskRows {
task := &BatchTask{
ID: taskRow.ID,
Message: taskRow.Message,
Status: taskRow.Status,
}
if taskRow.ConversationID.Valid {
task.ConversationID = taskRow.ConversationID.String
}
if taskRow.StartedAt.Valid {
task.StartedAt = &taskRow.StartedAt.Time
}
if taskRow.CompletedAt.Valid {
task.CompletedAt = &taskRow.CompletedAt.Time
}
if taskRow.Error.Valid {
task.Error = taskRow.Error.String
}
if taskRow.Result.Valid {
task.Result = taskRow.Result.String
}
queue.Tasks = append(queue.Tasks, task)
}
return queue
}
// GetAllQueues 获取所有队列
func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue {
m.mu.RLock()
result := make([]*BatchTaskQueue, 0, len(m.queues))
for _, queue := range m.queues {
result = append(result, queue)
}
m.mu.RUnlock()
// 如果数据库可用,确保所有数据库中的队列都已加载到内存
if m.db != nil {
dbQueues, err := m.db.GetAllBatchQueues()
if err == nil {
m.mu.Lock()
for _, queueRow := range dbQueues {
if _, exists := m.queues[queueRow.ID]; !exists {
if queue := m.loadQueueFromDB(queueRow.ID); queue != nil {
m.queues[queueRow.ID] = queue
result = append(result, queue)
}
}
}
m.mu.Unlock()
}
}
return result
}
// LoadFromDB 从数据库加载所有队列
func (m *BatchTaskManager) LoadFromDB() error {
if m.db == nil {
return nil
}
queueRows, err := m.db.GetAllBatchQueues()
if err != nil {
return err
}
m.mu.Lock()
defer m.mu.Unlock()
for _, queueRow := range queueRows {
if _, exists := m.queues[queueRow.ID]; exists {
continue // 已存在,跳过
}
taskRows, err := m.db.GetBatchTasks(queueRow.ID)
if err != nil {
continue // 跳过加载失败的任务
}
queue := &BatchTaskQueue{
ID: queueRow.ID,
Status: queueRow.Status,
CreatedAt: queueRow.CreatedAt,
CurrentIndex: queueRow.CurrentIndex,
Tasks: make([]*BatchTask, 0, len(taskRows)),
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
if queueRow.CompletedAt.Valid {
queue.CompletedAt = &queueRow.CompletedAt.Time
}
for _, taskRow := range taskRows {
task := &BatchTask{
ID: taskRow.ID,
Message: taskRow.Message,
Status: taskRow.Status,
}
if taskRow.ConversationID.Valid {
task.ConversationID = taskRow.ConversationID.String
}
if taskRow.StartedAt.Valid {
task.StartedAt = &taskRow.StartedAt.Time
}
if taskRow.CompletedAt.Valid {
task.CompletedAt = &taskRow.CompletedAt.Time
}
if taskRow.Error.Valid {
task.Error = taskRow.Error.String
}
if taskRow.Result.Valid {
task.Result = taskRow.Result.String
}
queue.Tasks = append(queue.Tasks, task)
}
m.queues[queueRow.ID] = queue
}
return nil
}
// UpdateTaskStatus 更新任务状态
func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, result, errorMsg string) {
m.UpdateTaskStatusWithConversationID(queueID, taskID, status, result, errorMsg, "")
}
// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId
func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return
}
for _, task := range queue.Tasks {
if task.ID == taskID {
task.Status = status
if result != "" {
task.Result = result
}
if errorMsg != "" {
task.Error = errorMsg
}
if conversationID != "" {
task.ConversationID = conversationID
}
now := time.Now()
if status == "running" && task.StartedAt == nil {
task.StartedAt = &now
}
if status == "completed" || status == "failed" || status == "cancelled" {
task.CompletedAt = &now
}
break
}
}
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
}
// UpdateQueueStatus 更新队列状态
func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return
}
queue.Status = status
now := time.Now()
if status == "running" && queue.StartedAt == nil {
queue.StartedAt = &now
}
if status == "completed" || status == "cancelled" {
queue.CompletedAt = &now
}
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
}
// UpdateTaskMessage 更新任务消息(仅限待执行状态)
func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return fmt.Errorf("队列不存在")
}
// 检查队列状态,只有待执行状态的队列才能编辑任务
if queue.Status != "pending" {
return fmt.Errorf("只有待执行状态的队列才能编辑任务")
}
// 查找并更新任务
for _, task := range queue.Tasks {
if task.ID == taskID {
// 只有待执行状态的任务才能编辑
if task.Status != "pending" {
return fmt.Errorf("只有待执行状态的任务才能编辑")
}
task.Message = message
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchTaskMessage(queueID, taskID, message); err != nil {
return fmt.Errorf("更新任务消息失败: %w", err)
}
}
return nil
}
}
return fmt.Errorf("任务不存在")
}
// AddTaskToQueue 添加任务到队列(仅限待执行状态)
func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, error) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return nil, fmt.Errorf("队列不存在")
}
// 检查队列状态,只有待执行状态的队列才能添加任务
if queue.Status != "pending" {
return nil, fmt.Errorf("只有待执行状态的队列才能添加任务")
}
if message == "" {
return nil, fmt.Errorf("任务消息不能为空")
}
// 生成任务ID
taskID := generateShortID()
task := &BatchTask{
ID: taskID,
Message: message,
Status: "pending",
}
// 添加到内存队列
queue.Tasks = append(queue.Tasks, task)
// 同步到数据库
if m.db != nil {
if err := m.db.AddBatchTask(queueID, taskID, message); err != nil {
// 如果数据库保存失败,从内存中移除
queue.Tasks = queue.Tasks[:len(queue.Tasks)-1]
return nil, fmt.Errorf("添加任务失败: %w", err)
}
}
return task, nil
}
// DeleteTask 删除任务(仅限待执行状态)
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return fmt.Errorf("队列不存在")
}
// 检查队列状态,只有待执行状态的队列才能删除任务
if queue.Status != "pending" {
return fmt.Errorf("只有待执行状态的队列才能删除任务")
}
// 查找并删除任务
taskIndex := -1
for i, task := range queue.Tasks {
if task.ID == taskID {
// 只有待执行状态的任务才能删除
if task.Status != "pending" {
return fmt.Errorf("只有待执行状态的任务才能删除")
}
taskIndex = i
break
}
}
if taskIndex == -1 {
return fmt.Errorf("任务不存在")
}
// 从内存队列中删除
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
// 同步到数据库
if m.db != nil {
if err := m.db.DeleteBatchTask(queueID, taskID); err != nil {
// 如果数据库删除失败,恢复内存中的任务
// 这里需要重新插入,但为了简化,我们只记录错误
return fmt.Errorf("删除任务失败: %w", err)
}
}
return nil
}
// GetNextTask 获取下一个待执行的任务
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
queue, exists := m.queues[queueID]
if !exists {
return nil, false
}
for i := queue.CurrentIndex; i < len(queue.Tasks); i++ {
task := queue.Tasks[i]
if task.Status == "pending" {
queue.CurrentIndex = i
return task, true
}
}
return nil, false
}
// MoveToNextTask 移动到下一个任务
func (m *BatchTaskManager) MoveToNextTask(queueID string) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return
}
queue.CurrentIndex++
// 同步到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueCurrentIndex(queueID, queue.CurrentIndex); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
}
// CancelQueue 取消队列
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists {
return false
}
if queue.Status == "completed" || queue.Status == "cancelled" {
return false
}
queue.Status = "cancelled"
now := time.Now()
queue.CompletedAt = &now
// 取消所有待执行的任务
for _, task := range queue.Tasks {
if task.Status == "pending" {
task.Status = "cancelled"
task.CompletedAt = &now
// 同步到数据库
if m.db != nil {
m.db.UpdateBatchTaskStatus(queueID, task.ID, "cancelled", "", "", "")
}
}
}
// 同步队列状态到数据库
if m.db != nil {
if err := m.db.UpdateBatchQueueStatus(queueID, "cancelled"); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
return true
}
// DeleteQueue 删除队列
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
_, exists := m.queues[queueID]
if !exists {
return false
}
// 从数据库删除
if m.db != nil {
if err := m.db.DeleteBatchQueue(queueID); err != nil {
// 记录错误但继续(使用内存缓存)
}
}
delete(m.queues, queueID)
return true
}
// generateShortID 生成短ID
func generateShortID() string {
b := make([]byte, 4)
rand.Read(b)
return time.Now().Format("150405") + "-" + hex.EncodeToString(b)
}
+90 -3
View File
@@ -23,16 +23,31 @@ type AgentTask struct {
cancel func(error)
}
// CompletedTask 已完成的任务(用于历史记录)
type CompletedTask struct {
ConversationID string `json:"conversationId"`
Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"`
CompletedAt time.Time `json:"completedAt"`
Status string `json:"status"`
}
// AgentTaskManager 管理正在运行的Agent任务
type AgentTaskManager struct {
mu sync.RWMutex
tasks map[string]*AgentTask
mu sync.RWMutex
tasks map[string]*AgentTask
completedTasks []*CompletedTask // 最近完成的任务历史
maxHistorySize int // 最大历史记录数
historyRetention time.Duration // 历史记录保留时间
}
// NewAgentTaskManager 创建任务管理器
func NewAgentTaskManager() *AgentTaskManager {
return &AgentTaskManager{
tasks: make(map[string]*AgentTask),
tasks: make(map[string]*AgentTask),
completedTasks: make([]*CompletedTask, 0),
maxHistorySize: 50, // 最多保留50条历史记录
historyRetention: 24 * time.Hour, // 保留24小时
}
}
@@ -118,9 +133,49 @@ func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string)
task.Status = finalStatus
}
// 保存到历史记录
completedTask := &CompletedTask{
ConversationID: task.ConversationID,
Message: task.Message,
StartedAt: task.StartedAt,
CompletedAt: time.Now(),
Status: finalStatus,
}
// 添加到历史记录
m.completedTasks = append(m.completedTasks, completedTask)
// 清理过期和过多的历史记录
m.cleanupHistory()
// 从运行任务中移除
delete(m.tasks, conversationID)
}
// cleanupHistory 清理过期的历史记录
func (m *AgentTaskManager) cleanupHistory() {
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
// 过滤掉过期的记录
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
validTasks = append(validTasks, task)
}
}
// 如果仍然超过最大数量,只保留最新的
if len(validTasks) > m.maxHistorySize {
// 按完成时间排序,保留最新的
// 由于是追加的,最新的在最后,所以直接取最后N个
start := len(validTasks) - m.maxHistorySize
validTasks = validTasks[start:]
}
m.completedTasks = validTasks
}
// GetActiveTasks 返回所有正在运行的任务
func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
m.mu.RLock()
@@ -137,3 +192,35 @@ func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
}
return result
}
// GetCompletedTasks 返回最近完成的任务历史
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
m.mu.RLock()
defer m.mu.RUnlock()
// 清理过期记录(只读锁,不影响其他操作)
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
// 所以返回时过滤过期记录
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
result := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
result = append(result, task)
}
}
// 按完成时间倒序排序(最新的在前)
// 由于是追加的,最新的在最后,需要反转
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}
// 限制返回数量
if len(result) > m.maxHistorySize {
result = result[:m.maxHistorySize]
}
return result
}