mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-22 15:39:47 +02:00
Add files via upload
This commit is contained in:
+296
-151
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -214,141 +215,17 @@ type StreamEvent struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// AgentLoopStream 处理Agent Loop流式请求
|
||||
func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 对于流式请求,也发送SSE格式的错误
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
event := StreamEvent{
|
||||
Type: "error",
|
||||
Message: "请求参数错误: " + err.Error(),
|
||||
}
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
||||
c.Writer.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到Agent Loop流式请求",
|
||||
zap.String("message", req.Message),
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲
|
||||
|
||||
// 发送初始事件
|
||||
// 用于跟踪客户端是否已断开连接
|
||||
clientDisconnected := false
|
||||
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
// 如果客户端已断开,不再发送事件
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求上下文是否被取消(客户端断开)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
clientDisconnected = true
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
event := StreamEvent{
|
||||
Type: eventType,
|
||||
Message: message,
|
||||
Data: data,
|
||||
}
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
|
||||
// 尝试写入事件,如果失败则标记客户端断开
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||
clientDisconnected = true
|
||||
h.logger.Debug("客户端断开连接,停止发送SSE事件", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新响应,如果失败则标记客户端断开
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有对话ID,创建新对话
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
sendEvent("error", "创建对话失败: "+err.Error(), nil)
|
||||
return
|
||||
}
|
||||
conversationID = conv.ID
|
||||
}
|
||||
|
||||
sendEvent("conversation", "会话已创建", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 预先创建助手消息,以便关联过程详情
|
||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
h.logger.Error("创建助手消息失败", zap.Error(err))
|
||||
// 如果创建失败,继续执行但不保存过程详情
|
||||
assistantMsg = nil
|
||||
}
|
||||
|
||||
// 创建进度回调函数,同时保存到数据库
|
||||
var assistantMessageID string
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
|
||||
// createProgressCallback 创建进度回调函数,用于保存processDetails
|
||||
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
|
||||
func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
||||
// 用于保存tool_call事件中的参数,以便在tool_result时使用
|
||||
toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments
|
||||
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
sendEvent(eventType, message, data)
|
||||
return func(eventType, message string, data interface{}) {
|
||||
// 如果提供了sendEventFunc,发送流式事件
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc(eventType, message, data)
|
||||
}
|
||||
|
||||
// 保存tool_call事件中的参数
|
||||
if eventType == "tool_call" {
|
||||
@@ -481,6 +358,140 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AgentLoopStream 处理Agent Loop流式请求
|
||||
func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 对于流式请求,也发送SSE格式的错误
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
event := StreamEvent{
|
||||
Type: "error",
|
||||
Message: "请求参数错误: " + err.Error(),
|
||||
}
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
||||
c.Writer.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到Agent Loop流式请求",
|
||||
zap.String("message", req.Message),
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲
|
||||
|
||||
// 发送初始事件
|
||||
// 用于跟踪客户端是否已断开连接
|
||||
clientDisconnected := false
|
||||
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
// 如果客户端已断开,不再发送事件
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查请求上下文是否被取消(客户端断开)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
clientDisconnected = true
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
event := StreamEvent{
|
||||
Type: eventType,
|
||||
Message: message,
|
||||
Data: data,
|
||||
}
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
|
||||
// 尝试写入事件,如果失败则标记客户端断开
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||
clientDisconnected = true
|
||||
h.logger.Debug("客户端断开连接,停止发送SSE事件", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新响应,如果失败则标记客户端断开
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有对话ID,创建新对话
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
sendEvent("error", "创建对话失败: "+err.Error(), nil)
|
||||
return
|
||||
}
|
||||
conversationID = conv.ID
|
||||
}
|
||||
|
||||
sendEvent("conversation", "会话已创建", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 预先创建助手消息,以便关联过程详情
|
||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
h.logger.Error("创建助手消息失败", zap.Error(err))
|
||||
// 如果创建失败,继续执行但不保存过程详情
|
||||
assistantMsg = nil
|
||||
}
|
||||
|
||||
// 创建进度回调函数,同时保存到数据库
|
||||
var assistantMessageID string
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
|
||||
// 创建进度回调函数,复用统一逻辑
|
||||
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)
|
||||
|
||||
// 创建一个独立的上下文用于任务执行,不随HTTP请求取消
|
||||
// 这样即使客户端断开连接(如刷新页面),任务也能继续执行
|
||||
@@ -795,10 +806,76 @@ func (h *AgentHandler) GetBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"queue": queue})
|
||||
}
|
||||
|
||||
// ListBatchQueues 列出所有批量任务队列
|
||||
// ListBatchQueuesResponse 批量任务队列列表响应
|
||||
type ListBatchQueuesResponse struct {
|
||||
Queues []*BatchTaskQueue `json:"queues"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// ListBatchQueues 列出所有批量任务队列(支持筛选和分页)
|
||||
func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
|
||||
queues := h.batchTaskManager.GetAllQueues()
|
||||
c.JSON(http.StatusOK, gin.H{"queues": queues})
|
||||
limitStr := c.DefaultQuery("limit", "10")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
pageStr := c.Query("page")
|
||||
status := c.Query("status")
|
||||
keyword := c.Query("keyword")
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
page := 1
|
||||
|
||||
// 如果提供了page参数,优先使用page计算offset
|
||||
if pageStr != "" {
|
||||
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
|
||||
page = p
|
||||
offset = (page - 1) * limit
|
||||
}
|
||||
}
|
||||
|
||||
// 限制pageSize范围
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 10
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
// 默认status为"all"
|
||||
if status == "" {
|
||||
status = "all"
|
||||
}
|
||||
|
||||
// 获取队列列表和总数
|
||||
queues, total, err := h.batchTaskManager.ListQueues(limit, offset, status, keyword)
|
||||
if err != nil {
|
||||
h.logger.Error("获取批量任务队列列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算总页数
|
||||
totalPages := (total + limit - 1) / limit
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
// 如果使用offset计算page,需要重新计算
|
||||
if pageStr == "" {
|
||||
page = (offset / limit) + 1
|
||||
}
|
||||
|
||||
response := ListBatchQueuesResponse{
|
||||
Queues: queues,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: limit,
|
||||
TotalPages: totalPages,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// StartBatchQueue 开始执行批量任务队列
|
||||
@@ -822,15 +899,15 @@ func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
// CancelBatchQueue 取消批量任务队列
|
||||
func (h *AgentHandler) CancelBatchQueue(c *gin.Context) {
|
||||
// PauseBatchQueue 暂停批量任务队列
|
||||
func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
success := h.batchTaskManager.CancelQueue(queueID)
|
||||
success := h.batchTaskManager.PauseQueue(queueID)
|
||||
if !success {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法取消"})
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已取消"})
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
|
||||
}
|
||||
|
||||
// DeleteBatchQueue 删除批量任务队列
|
||||
@@ -936,7 +1013,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
for {
|
||||
// 检查队列状态
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists || queue.Status == "cancelled" || queue.Status == "completed" {
|
||||
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -973,13 +1050,28 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 预先创建助手消息,以便关联过程详情
|
||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
// 如果创建失败,继续执行但不保存过程详情
|
||||
assistantMsg = nil
|
||||
}
|
||||
|
||||
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
|
||||
var assistantMessageID string
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, nil)
|
||||
|
||||
// 执行任务
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("conversationId", conversationID))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
// 存储取消函数,以便在取消队列时能够取消当前任务
|
||||
h.batchTaskManager.SetTaskCancel(queueID, cancel)
|
||||
result, err := h.agent.AgentLoopWithConversationID(ctx, task.Message, []agent.ChatMessage{}, conversationID)
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback)
|
||||
// 任务执行完成,清理取消函数
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
cancel()
|
||||
@@ -1002,9 +1094,25 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
if result != nil && result.Response != "" && (strings.Contains(result.Response, "任务已被取消") || strings.Contains(result.Response, "任务执行中断")) {
|
||||
cancelMsg = result.Response
|
||||
}
|
||||
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
|
||||
if errMsg != nil {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
cancelMsg,
|
||||
assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
// 保存取消详情到数据库
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
||||
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
// 如果没有预先创建的助手消息,创建一个新的
|
||||
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
|
||||
if errMsg != nil {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
}
|
||||
// 保存ReAct数据(如果存在)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
@@ -1015,15 +1123,52 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
||||
} else {
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
errorMsg := "执行失败: " + err.Error()
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
// 保存错误详情到数据库
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
||||
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", err.Error())
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
|
||||
// 保存助手回复
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
// 如果更新失败,尝试创建新消息
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果没有预先创建的助手消息,创建一个新的
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存ReAct数据
|
||||
@@ -1042,9 +1187,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
// 移动到下一个任务
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
|
||||
// 检查是否被取消
|
||||
// 检查是否被取消或暂停
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if queue.Status == "cancelled" {
|
||||
if queue.Status == "cancelled" || queue.Status == "paused" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -214,6 +216,100 @@ func (m *BatchTaskManager) GetAllQueues() []*BatchTaskQueue {
|
||||
return result
|
||||
}
|
||||
|
||||
// ListQueues 列出队列(支持筛选和分页)
|
||||
func (m *BatchTaskManager) ListQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueue, int, error) {
|
||||
var queues []*BatchTaskQueue
|
||||
var total int
|
||||
|
||||
// 如果数据库可用,从数据库查询
|
||||
if m.db != nil {
|
||||
// 获取总数
|
||||
count, err := m.db.CountBatchQueues(status, keyword)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("统计队列总数失败: %w", err)
|
||||
}
|
||||
total = count
|
||||
|
||||
// 获取队列列表(只获取ID)
|
||||
queueRows, err := m.db.ListBatchQueues(limit, offset, status, keyword)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询队列列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 加载完整的队列信息(从内存或数据库)
|
||||
m.mu.Lock()
|
||||
for _, queueRow := range queueRows {
|
||||
var queue *BatchTaskQueue
|
||||
// 先从内存查找
|
||||
if cached, exists := m.queues[queueRow.ID]; exists {
|
||||
queue = cached
|
||||
} else {
|
||||
// 从数据库加载
|
||||
queue = m.loadQueueFromDB(queueRow.ID)
|
||||
if queue != nil {
|
||||
m.queues[queueRow.ID] = queue
|
||||
}
|
||||
}
|
||||
if queue != nil {
|
||||
queues = append(queues, queue)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
} else {
|
||||
// 没有数据库,从内存中筛选和分页
|
||||
m.mu.RLock()
|
||||
allQueues := make([]*BatchTaskQueue, 0, len(m.queues))
|
||||
for _, queue := range m.queues {
|
||||
allQueues = append(allQueues, queue)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
// 筛选
|
||||
filtered := make([]*BatchTaskQueue, 0)
|
||||
for _, queue := range allQueues {
|
||||
// 状态筛选
|
||||
if status != "" && status != "all" && queue.Status != status {
|
||||
continue
|
||||
}
|
||||
// 关键字搜索
|
||||
if keyword != "" {
|
||||
keywordLower := strings.ToLower(keyword)
|
||||
queueIDLower := strings.ToLower(queue.ID)
|
||||
if !strings.Contains(queueIDLower, keywordLower) {
|
||||
// 也可以搜索创建时间
|
||||
createdAtStr := queue.CreatedAt.Format("2006-01-02 15:04:05")
|
||||
if !strings.Contains(createdAtStr, keyword) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, queue)
|
||||
}
|
||||
|
||||
// 按创建时间倒序排序
|
||||
sort.Slice(filtered, func(i, j int) bool {
|
||||
return filtered[i].CreatedAt.After(filtered[j].CreatedAt)
|
||||
})
|
||||
|
||||
total = len(filtered)
|
||||
|
||||
// 分页
|
||||
start := offset
|
||||
if start > len(filtered) {
|
||||
start = len(filtered)
|
||||
}
|
||||
end := start + limit
|
||||
if end > len(filtered) {
|
||||
end = len(filtered)
|
||||
}
|
||||
if start < len(filtered) {
|
||||
queues = filtered[start:end]
|
||||
}
|
||||
}
|
||||
|
||||
return queues, total, nil
|
||||
}
|
||||
|
||||
// LoadFromDB 从数据库加载所有队列
|
||||
func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if m.db == nil {
|
||||
@@ -534,7 +630,42 @@ func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFu
|
||||
}
|
||||
}
|
||||
|
||||
// CancelQueue 取消队列
|
||||
// PauseQueue 暂停队列
|
||||
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if queue.Status != "running" {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
queue.Status = "paused"
|
||||
|
||||
// 取消当前正在执行的任务(通过取消context)
|
||||
if cancel, exists := m.taskCancels[queueID]; exists {
|
||||
cancel()
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
|
||||
m.mu.Unlock()
|
||||
|
||||
// 同步队列状态到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, "paused"); err != nil {
|
||||
// 记录错误但继续(使用内存缓存)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user