mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-01 16:50:47 +02:00
Add files via upload
This commit is contained in:
@@ -543,8 +543,13 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
cause := context.Cause(baseCtx)
|
||||
|
||||
// 检查是否是用户取消:context的cause是ErrTaskCancelled
|
||||
// 如果cause是ErrTaskCancelled,无论错误是什么类型(包括context.Canceled),都视为用户取消
|
||||
// 这样可以正确处理在API调用过程中被取消的情况
|
||||
isCancelled := errors.Is(cause, ErrTaskCancelled)
|
||||
|
||||
switch {
|
||||
case errors.Is(cause, ErrTaskCancelled):
|
||||
case isCancelled:
|
||||
taskStatus = "cancelled"
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
|
||||
@@ -972,12 +977,46 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("conversationId", conversationID))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
// 存储取消函数,以便在取消队列时能够取消当前任务
|
||||
h.batchTaskManager.SetTaskCancel(queueID, cancel)
|
||||
result, err := h.agent.AgentLoopWithConversationID(ctx, task.Message, []agent.ChatMessage{}, conversationID)
|
||||
// 任务执行完成,清理取消函数
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", err.Error())
|
||||
// 检查是否是取消错误
|
||||
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
|
||||
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
|
||||
// 3. 检查 result.Response 中是否包含取消相关的消息
|
||||
errStr := err.Error()
|
||||
isCancelled := errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||
(result != nil && result.Response != "" && (strings.Contains(result.Response, "任务已被取消") || strings.Contains(result.Response, "任务执行中断")))
|
||||
|
||||
if isCancelled {
|
||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
// 如果result中有更具体的取消消息,使用它
|
||||
if result != nil && result.Response != "" && (strings.Contains(result.Response, "任务已被取消") || strings.Contains(result.Response, "任务执行中断")) {
|
||||
cancelMsg = result.Response
|
||||
}
|
||||
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
|
||||
if errMsg != nil {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
// 保存ReAct数据(如果存在)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
||||
} else {
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", err.Error())
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
@@ -36,15 +37,17 @@ type BatchTaskQueue struct {
|
||||
|
||||
// BatchTaskManager 批量任务管理器
|
||||
type BatchTaskManager struct {
|
||||
db *database.DB
|
||||
queues map[string]*BatchTaskQueue
|
||||
mu sync.RWMutex
|
||||
db *database.DB
|
||||
queues map[string]*BatchTaskQueue
|
||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBatchTaskManager 创建批量任务管理器
|
||||
func NewBatchTaskManager() *BatchTaskManager {
|
||||
return &BatchTaskManager{
|
||||
queues: make(map[string]*BatchTaskQueue),
|
||||
queues: make(map[string]*BatchTaskQueue),
|
||||
taskCancels: make(map[string]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -520,17 +523,29 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetTaskCancel 设置当前任务的取消函数
|
||||
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
m.taskCancels[queueID] = cancel
|
||||
} else {
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
}
|
||||
|
||||
// CancelQueue 取消队列
|
||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
if queue.Status == "completed" || queue.Status == "cancelled" {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -550,6 +565,14 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// 取消当前正在执行的任务
|
||||
if cancel, exists := m.taskCancels[queueID]; exists {
|
||||
cancel()
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
|
||||
m.mu.Unlock()
|
||||
|
||||
// 同步队列状态到数据库
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueStatus(queueID, "cancelled"); err != nil {
|
||||
@@ -570,6 +593,9 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// 清理取消函数
|
||||
delete(m.taskCancels, queueID)
|
||||
|
||||
// 从数据库删除
|
||||
if m.db != nil {
|
||||
if err := m.db.DeleteBatchQueue(queueID); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user