mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-24 08:24:09 +02:00
Add files via upload
This commit is contained in:
+297
-10
@@ -7,8 +7,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/database"
|
||||
@@ -25,16 +25,16 @@ func safeTruncateString(s string, maxLen int) string {
|
||||
if utf8.RuneCountInString(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
|
||||
|
||||
// 将字符串转换为 rune 切片以正确计算字符数
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
|
||||
|
||||
// 截断到最大长度
|
||||
truncated := string(runes[:maxLen])
|
||||
|
||||
|
||||
// 尝试在标点符号或空格处截断,使截断更自然
|
||||
// 在截断点往前查找合适的断点(不超过20%的长度)
|
||||
searchRange := maxLen / 5
|
||||
@@ -43,7 +43,7 @@ func safeTruncateString(s string, maxLen int) string {
|
||||
}
|
||||
breakChars := []rune(",。、 ,.;:!?!?/\\-_")
|
||||
bestBreakPos := len(runes[:maxLen])
|
||||
|
||||
|
||||
for i := bestBreakPos - 1; i >= bestBreakPos-searchRange && i >= 0; i-- {
|
||||
for _, breakChar := range breakChars {
|
||||
if runes[i] == breakChar {
|
||||
@@ -52,7 +52,7 @@ func safeTruncateString(s string, maxLen int) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
found:
|
||||
truncated = string(runes[:bestBreakPos])
|
||||
return truncated + "..."
|
||||
@@ -64,6 +64,7 @@ type AgentHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
tasks *AgentTaskManager
|
||||
batchTaskManager *BatchTaskManager
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
|
||||
}
|
||||
@@ -71,11 +72,20 @@ type AgentHandler struct {
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
|
||||
batchTaskManager := NewBatchTaskManager()
|
||||
batchTaskManager.SetDB(db)
|
||||
|
||||
// 从数据库加载所有批量任务队列
|
||||
if err := batchTaskManager.LoadFromDB(); err != nil {
|
||||
logger.Warn("从数据库加载批量任务队列失败", zap.Error(err))
|
||||
}
|
||||
|
||||
return &AgentHandler{
|
||||
agent: agent,
|
||||
db: db,
|
||||
logger: logger,
|
||||
tasks: NewAgentTaskManager(),
|
||||
agent: agent,
|
||||
db: db,
|
||||
logger: logger,
|
||||
tasks: NewAgentTaskManager(),
|
||||
batchTaskManager: batchTaskManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -724,6 +734,283 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ListCompletedTasks 列出最近完成的任务历史
|
||||
func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tasks": h.tasks.GetCompletedTasks(),
|
||||
})
|
||||
}
|
||||
|
||||
// BatchTaskRequest 批量任务请求
|
||||
type BatchTaskRequest struct {
|
||||
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
var req BatchTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Tasks) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "任务列表不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤空任务
|
||||
validTasks := make([]string, 0, len(req.Tasks))
|
||||
for _, task := range req.Tasks {
|
||||
if task != "" {
|
||||
validTasks = append(validTasks, task)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validTasks) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "没有有效的任务"})
|
||||
return
|
||||
}
|
||||
|
||||
queue := h.batchTaskManager.CreateBatchQueue(validTasks)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"queueId": queue.ID,
|
||||
"queue": queue,
|
||||
})
|
||||
}
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (h *AgentHandler) GetBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"queue": queue})
|
||||
}
|
||||
|
||||
// ListBatchQueues 列出所有批量任务队列
|
||||
func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
|
||||
queues := h.batchTaskManager.GetAllQueues()
|
||||
c.JSON(http.StatusOK, gin.H{"queues": queues})
|
||||
}
|
||||
|
||||
// StartBatchQueue 开始执行批量任务队列
|
||||
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if queue.Status != "pending" && queue.Status != "paused" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "队列状态不允许启动"})
|
||||
return
|
||||
}
|
||||
|
||||
// 在后台执行批量任务
|
||||
go h.executeBatchQueue(queueID)
|
||||
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "running")
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
// CancelBatchQueue 取消批量任务队列
|
||||
func (h *AgentHandler) CancelBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
success := h.batchTaskManager.CancelQueue(queueID)
|
||||
if !success {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法取消"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已取消"})
|
||||
}
|
||||
|
||||
// DeleteBatchQueue 删除批量任务队列
|
||||
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
success := h.batchTaskManager.DeleteQueue(queueID)
|
||||
if !success {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
|
||||
}
|
||||
|
||||
// UpdateBatchTask 更新批量任务消息
|
||||
func (h *AgentHandler) UpdateBatchTask(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
taskID := c.Param("taskId")
|
||||
|
||||
var req struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Message == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
err := h.batchTaskManager.UpdateTaskMessage(queueID, taskID, req.Message)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的队列信息
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已更新", "queue": queue})
|
||||
}
|
||||
|
||||
// AddBatchTask 添加任务到批量任务队列
|
||||
func (h *AgentHandler) AddBatchTask(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
|
||||
var req struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Message == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "任务消息不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
task, err := h.batchTaskManager.AddTaskToQueue(queueID, req.Message)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的队列信息
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
|
||||
}
|
||||
|
||||
// DeleteBatchTask 删除批量任务
|
||||
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
taskID := c.Param("taskId")
|
||||
|
||||
err := h.batchTaskManager.DeleteTask(queueID, taskID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的队列信息
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||
}
|
||||
|
||||
// executeBatchQueue 执行批量任务队列
|
||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
|
||||
|
||||
for {
|
||||
// 检查队列状态
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists || queue.Status == "cancelled" || queue.Status == "completed" {
|
||||
break
|
||||
}
|
||||
|
||||
// 获取下一个任务
|
||||
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
|
||||
if !hasNext {
|
||||
// 所有任务完成
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
|
||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||
break
|
||||
}
|
||||
|
||||
// 更新任务状态为运行中
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
|
||||
|
||||
// 创建新对话
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
var conversationID string
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
continue
|
||||
}
|
||||
conversationID = conv.ID
|
||||
|
||||
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
|
||||
|
||||
// 保存用户消息
|
||||
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 执行任务
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("conversationId", conversationID))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
result, err := h.agent.AgentLoopWithConversationID(ctx, task.Message, []agent.ChatMessage{}, conversationID)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", err.Error())
|
||||
} else {
|
||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
|
||||
// 保存助手回复
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 保存ReAct数据
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存结果
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", result.Response, "", conversationID)
|
||||
}
|
||||
|
||||
// 移动到下一个任务
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
|
||||
// 检查是否被取消
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if queue.Status == "cancelled" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
|
||||
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表
|
||||
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
|
||||
|
||||
@@ -0,0 +1,589 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// BatchTask 批量任务项
|
||||
type BatchTask struct {
|
||||
ID string `json:"id"`
|
||||
Message string `json:"message"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
Status string `json:"status"` // pending, running, completed, failed, cancelled
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// BatchTaskQueue 批量任务队列
|
||||
type BatchTaskQueue struct {
|
||||
ID string `json:"id"`
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
CurrentIndex int `json:"currentIndex"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// BatchTaskManager 批量任务管理器
|
||||
type BatchTaskManager struct {
|
||||
db *database.DB
|
||||
queues map[string]*BatchTaskQueue
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBatchTaskManager 创建批量任务管理器
|
||||
func NewBatchTaskManager() *BatchTaskManager {
|
||||
return &BatchTaskManager{
|
||||
queues: make(map[string]*BatchTaskQueue),
|
||||
}
|
||||
}
|
||||
|
||||
// SetDB 设置数据库连接
|
||||
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.db = db
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(tasks []string) *BatchTaskQueue {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queueID := time.Now().Format("20060102150405") + "-" + generateShortID()
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueID,
|
||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||
Status: "pending",
|
||||
CreatedAt: time.Now(),
|
||||
CurrentIndex: 0,
|
||||
}
|
||||
|
||||
// 准备数据库保存的任务数据
|
||||
dbTasks := make([]map[string]interface{}, 0, len(tasks))
|
||||
|
||||
for _, message := range tasks {
|
||||
if message == "" {
|
||||
continue // 跳过空行
|
||||
}
|
||||
taskID := generateShortID()
|
||||
task := &BatchTask{
|
||||
ID: taskID,
|
||||
Message: message,
|
||||
Status: "pending",
|
||||
}
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
dbTasks = append(dbTasks, map[string]interface{}{
|
||||
"id": taskID,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.CreateBatchQueue(queueID, dbTasks); err != nil {
|
||||
// 如果数据库保存失败,记录错误但继续(使用内存缓存)
|
||||
// 这里可以添加日志记录
|
||||
}
|
||||
}
|
||||
|
||||
m.queues[queueID] = queue
|
||||
return queue
|
||||
}
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (m *BatchTaskManager) GetBatchQueue(queueID string) (*BatchTaskQueue, bool) {
|
||||
m.mu.RLock()
|
||||
queue, exists := m.queues[queueID]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return queue, true
|
||||
}
|
||||
|
||||
// 如果内存中不存在,尝试从数据库加载
|
||||
if m.db != nil {
|
||||
if queue := m.loadQueueFromDB(queueID); queue != nil {
|
||||
m.mu.Lock()
|
||||
m.queues[queueID] = queue
|
||||
m.mu.Unlock()
|
||||
return queue, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// loadQueueFromDB 从数据库加载单个队列
|
||||
func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if m.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
queueRow, err := m.db.GetBatchQueue(queueID)
|
||||
if err != nil || queueRow == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
taskRows, err := m.db.GetBatchTasks(queueID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueRow.ID,
|
||||
Status: queueRow.Status,
|
||||
CreatedAt: queueRow.CreatedAt,
|
||||
CurrentIndex: queueRow.CurrentIndex,
|
||||
Tasks: make([]*BatchTask, 0, len(taskRows)),
|
||||
}
|
||||
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
if queueRow.CompletedAt.Valid {
|
||||
queue.CompletedAt = &queueRow.CompletedAt.Time
|
||||
}
|
||||
|
||||
for _, taskRow := range taskRows {
|
||||
task := &BatchTask{
|
||||
ID: taskRow.ID,
|
||||
Message: taskRow.Message,
|
||||
Status: taskRow.Status,
|
||||
}
|
||||
if taskRow.ConversationID.Valid {
|
||||
task.ConversationID = taskRow.ConversationID.String
|
||||
}
|
||||
if taskRow.StartedAt.Valid {
|
||||
task.StartedAt = &taskRow.StartedAt.Time
|
||||
}
|
||||
if taskRow.CompletedAt.Valid {
|
||||
task.CompletedAt = &taskRow.CompletedAt.Time
|
||||
}
|
||||
if taskRow.Error.Valid {
|
||||
task.Error = taskRow.Error.String
|
||||
}
|
||||
if taskRow.Result.Valid {
|
||||
task.Result = taskRow.Result.String
|
||||
}
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
}
|
||||
|
||||
return queue
|
||||
}
|
||||
|
||||
// GetAllQueues 获取所有队列
|
||||
func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue {
|
||||
m.mu.RLock()
|
||||
result := make([]*BatchTaskQueue, 0, len(m.queues))
|
||||
for _, queue := range m.queues {
|
||||
result = append(result, queue)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// 如果数据库可用,确保所有数据库中的队列都已加载到内存
|
||||
if m.db != nil {
|
||||
dbQueues, err := m.db.GetAllBatchQueues()
|
||||
if err == nil {
|
||||
m.mu.Lock()
|
||||
for _, queueRow := range dbQueues {
|
||||
if _, exists := m.queues[queueRow.ID]; !exists {
|
||||
if queue := m.loadQueueFromDB(queueRow.ID); queue != nil {
|
||||
m.queues[queueRow.ID] = queue
|
||||
result = append(result, queue)
|
||||
}
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// LoadFromDB 从数据库加载所有队列
|
||||
func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if m.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
queueRows, err := m.db.GetAllBatchQueues()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, queueRow := range queueRows {
|
||||
if _, exists := m.queues[queueRow.ID]; exists {
|
||||
continue // 已存在,跳过
|
||||
}
|
||||
|
||||
taskRows, err := m.db.GetBatchTasks(queueRow.ID)
|
||||
if err != nil {
|
||||
continue // 跳过加载失败的任务
|
||||
}
|
||||
|
||||
queue := &BatchTaskQueue{
|
||||
ID: queueRow.ID,
|
||||
Status: queueRow.Status,
|
||||
CreatedAt: queueRow.CreatedAt,
|
||||
CurrentIndex: queueRow.CurrentIndex,
|
||||
Tasks: make([]*BatchTask, 0, len(taskRows)),
|
||||
}
|
||||
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
if queueRow.CompletedAt.Valid {
|
||||
queue.CompletedAt = &queueRow.CompletedAt.Time
|
||||
}
|
||||
|
||||
for _, taskRow := range taskRows {
|
||||
task := &BatchTask{
|
||||
ID: taskRow.ID,
|
||||
Message: taskRow.Message,
|
||||
Status: taskRow.Status,
|
||||
}
|
||||
if taskRow.ConversationID.Valid {
|
||||
task.ConversationID = taskRow.ConversationID.String
|
||||
}
|
||||
if taskRow.StartedAt.Valid {
|
||||
task.StartedAt = &taskRow.StartedAt.Time
|
||||
}
|
||||
if taskRow.CompletedAt.Valid {
|
||||
task.CompletedAt = &taskRow.CompletedAt.Time
|
||||
}
|
||||
if taskRow.Error.Valid {
|
||||
task.Error = taskRow.Error.String
|
||||
}
|
||||
if taskRow.Result.Valid {
|
||||
task.Result = taskRow.Result.String
|
||||
}
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
}
|
||||
|
||||
m.queues[queueRow.ID] = queue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTaskStatus 更新任务状态
|
||||
func (m *BatchTaskManager) UpdateTaskStatus(queueID, taskID, status string, result, errorMsg string) {
|
||||
m.UpdateTaskStatusWithConversationID(queueID, taskID, status, result, errorMsg, "")
|
||||
}
|
||||
|
||||
// UpdateTaskStatusWithConversationID 更新任务状态(包含conversationId)
|
||||
func (m *BatchTaskManager) UpdateTaskStatusWithConversationID(queueID, taskID, status string, result, errorMsg, conversationID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for _, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
task.Status = status
|
||||
if result != "" {
|
||||
task.Result = result
|
||||
}
|
||||
if errorMsg != "" {
|
||||
task.Error = errorMsg
|
||||
}
|
||||
if conversationID != "" {
|
||||
task.ConversationID = conversationID
|
||||
}
|
||||
now := time.Now()
|
||||
if status == "running" && task.StartedAt == nil {
|
||||
task.StartedAt = &now
|
||||
}
|
||||
if status == "completed" || status == "failed" || status == "cancelled" {
|
||||
task.CompletedAt = &now
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchTaskStatus(queueID, taskID, status, conversationID, result, errorMsg); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQueueStatus 更新队列状态
|
||||
func (m *BatchTaskManager) UpdateQueueStatus(queueID, status string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
queue.Status = status
|
||||
now := time.Now()
|
||||
if status == "running" && queue.StartedAt == nil {
|
||||
queue.StartedAt = &now
|
||||
}
|
||||
if status == "completed" || status == "cancelled" {
|
||||
queue.CompletedAt = &now
|
||||
}
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, status); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTaskMessage 更新任务消息(仅限待执行状态)
|
||||
func (m *BatchTaskManager) UpdateTaskMessage(queueID, taskID, message string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return fmt.Errorf("队列不存在")
|
||||
}
|
||||
|
||||
// 检查队列状态,只有待执行状态的队列才能编辑任务
|
||||
if queue.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的队列才能编辑任务")
|
||||
}
|
||||
|
||||
// 查找并更新任务
|
||||
for _, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
// 只有待执行状态的任务才能编辑
|
||||
if task.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的任务才能编辑")
|
||||
}
|
||||
task.Message = message
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchTaskMessage(queueID, taskID, message); err != nil {
|
||||
return fmt.Errorf("更新任务消息失败: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("任务不存在")
|
||||
}
|
||||
|
||||
// AddTaskToQueue 添加任务到队列(仅限待执行状态)
|
||||
func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("队列不存在")
|
||||
}
|
||||
|
||||
// 检查队列状态,只有待执行状态的队列才能添加任务
|
||||
if queue.Status != "pending" {
|
||||
return nil, fmt.Errorf("只有待执行状态的队列才能添加任务")
|
||||
}
|
||||
|
||||
if message == "" {
|
||||
return nil, fmt.Errorf("任务消息不能为空")
|
||||
}
|
||||
|
||||
// 生成任务ID
|
||||
taskID := generateShortID()
|
||||
task := &BatchTask{
|
||||
ID: taskID,
|
||||
Message: message,
|
||||
Status: "pending",
|
||||
}
|
||||
|
||||
// 添加到内存队列
|
||||
queue.Tasks = append(queue.Tasks, task)
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.AddBatchTask(queueID, taskID, message); err != nil {
|
||||
// 如果数据库保存失败,从内存中移除
|
||||
queue.Tasks = queue.Tasks[:len(queue.Tasks)-1]
|
||||
return nil, fmt.Errorf("添加任务失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// DeleteTask 删除任务(仅限待执行状态)
|
||||
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return fmt.Errorf("队列不存在")
|
||||
}
|
||||
|
||||
// 检查队列状态,只有待执行状态的队列才能删除任务
|
||||
if queue.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的队列才能删除任务")
|
||||
}
|
||||
|
||||
// 查找并删除任务
|
||||
taskIndex := -1
|
||||
for i, task := range queue.Tasks {
|
||||
if task.ID == taskID {
|
||||
// 只有待执行状态的任务才能删除
|
||||
if task.Status != "pending" {
|
||||
return fmt.Errorf("只有待执行状态的任务才能删除")
|
||||
}
|
||||
taskIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if taskIndex == -1 {
|
||||
return fmt.Errorf("任务不存在")
|
||||
}
|
||||
|
||||
// 从内存队列中删除
|
||||
queue.Tasks = append(queue.Tasks[:taskIndex], queue.Tasks[taskIndex+1:]...)
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.DeleteBatchTask(queueID, taskID); err != nil {
|
||||
// 如果数据库删除失败,恢复内存中的任务
|
||||
// 这里需要重新插入,但为了简化,我们只记录错误
|
||||
return fmt.Errorf("删除任务失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个待执行的任务
|
||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for i := queue.CurrentIndex; i < len(queue.Tasks); i++ {
|
||||
task := queue.Tasks[i]
|
||||
if task.Status == "pending" {
|
||||
queue.CurrentIndex = i
|
||||
return task, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// MoveToNextTask 移动到下一个任务
|
||||
func (m *BatchTaskManager) MoveToNextTask(queueID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
queue.CurrentIndex++
|
||||
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueCurrentIndex(queueID, queue.CurrentIndex); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CancelQueue 取消队列
|
||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
if queue.Status == "completed" || queue.Status == "cancelled" {
|
||||
return false
|
||||
}
|
||||
|
||||
queue.Status = "cancelled"
|
||||
now := time.Now()
|
||||
queue.CompletedAt = &now
|
||||
|
||||
// 取消所有待执行的任务
|
||||
for _, task := range queue.Tasks {
|
||||
if task.Status == "pending" {
|
||||
task.Status = "cancelled"
|
||||
task.CompletedAt = &now
|
||||
// 同步到数据库
|
||||
if m.db != nil {
|
||||
m.db.UpdateBatchTaskStatus(queueID, task.ID, "cancelled", "", "", "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 同步队列状态到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, "cancelled"); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteQueue 删除队列
|
||||
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// 从数据库删除
|
||||
if m.db != nil {
|
||||
if err := m.db.DeleteBatchQueue(queueID); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
|
||||
delete(m.queues, queueID)
|
||||
return true
|
||||
}
|
||||
|
||||
// generateShortID 生成短ID
|
||||
func generateShortID() string {
|
||||
b := make([]byte, 4)
|
||||
rand.Read(b)
|
||||
return time.Now().Format("150405") + "-" + hex.EncodeToString(b)
|
||||
}
|
||||
@@ -23,16 +23,31 @@ type AgentTask struct {
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
// CompletedTask 已完成的任务(用于历史记录)
|
||||
type CompletedTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
Message string `json:"message,omitempty"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
CompletedAt time.Time `json:"completedAt"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// AgentTaskManager 管理正在运行的Agent任务
|
||||
type AgentTaskManager struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*AgentTask
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*AgentTask
|
||||
completedTasks []*CompletedTask // 最近完成的任务历史
|
||||
maxHistorySize int // 最大历史记录数
|
||||
historyRetention time.Duration // 历史记录保留时间
|
||||
}
|
||||
|
||||
// NewAgentTaskManager 创建任务管理器
|
||||
func NewAgentTaskManager() *AgentTaskManager {
|
||||
return &AgentTaskManager{
|
||||
tasks: make(map[string]*AgentTask),
|
||||
tasks: make(map[string]*AgentTask),
|
||||
completedTasks: make([]*CompletedTask, 0),
|
||||
maxHistorySize: 50, // 最多保留50条历史记录
|
||||
historyRetention: 24 * time.Hour, // 保留24小时
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,9 +133,49 @@ func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string)
|
||||
task.Status = finalStatus
|
||||
}
|
||||
|
||||
// 保存到历史记录
|
||||
completedTask := &CompletedTask{
|
||||
ConversationID: task.ConversationID,
|
||||
Message: task.Message,
|
||||
StartedAt: task.StartedAt,
|
||||
CompletedAt: time.Now(),
|
||||
Status: finalStatus,
|
||||
}
|
||||
|
||||
// 添加到历史记录
|
||||
m.completedTasks = append(m.completedTasks, completedTask)
|
||||
|
||||
// 清理过期和过多的历史记录
|
||||
m.cleanupHistory()
|
||||
|
||||
// 从运行任务中移除
|
||||
delete(m.tasks, conversationID)
|
||||
}
|
||||
|
||||
// cleanupHistory 清理过期的历史记录
|
||||
func (m *AgentTaskManager) cleanupHistory() {
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
// 过滤掉过期的记录
|
||||
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
if task.CompletedAt.After(cutoffTime) {
|
||||
validTasks = append(validTasks, task)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果仍然超过最大数量,只保留最新的
|
||||
if len(validTasks) > m.maxHistorySize {
|
||||
// 按完成时间排序,保留最新的
|
||||
// 由于是追加的,最新的在最后,所以直接取最后N个
|
||||
start := len(validTasks) - m.maxHistorySize
|
||||
validTasks = validTasks[start:]
|
||||
}
|
||||
|
||||
m.completedTasks = validTasks
|
||||
}
|
||||
|
||||
// GetActiveTasks 返回所有正在运行的任务
|
||||
func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
|
||||
m.mu.RLock()
|
||||
@@ -137,3 +192,35 @@ func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetCompletedTasks 返回最近完成的任务历史
|
||||
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// 清理过期记录(只读锁,不影响其他操作)
|
||||
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
|
||||
// 所以返回时过滤过期记录
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
result := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
if task.CompletedAt.After(cutoffTime) {
|
||||
result = append(result, task)
|
||||
}
|
||||
}
|
||||
|
||||
// 按完成时间倒序排序(最新的在前)
|
||||
// 由于是追加的,最新的在最后,需要反转
|
||||
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if len(result) > m.maxHistorySize {
|
||||
result = result[:m.maxHistorySize]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user