Add files via upload

This commit is contained in:
公明
2026-04-28 11:37:52 +08:00
committed by GitHub
parent 3b3d094dc4
commit b53cae3a02
26 changed files with 374 additions and 374 deletions
+19 -19
View File
@@ -338,8 +338,8 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
type AgentLoopResult struct { type AgentLoopResult struct {
Response string Response string
MCPExecutionIDs []string MCPExecutionIDs []string
LastReActInput string // 最后一轮ReAct的输入(压缩后的messagesJSON格式 LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messagesJSON;与 multiagent.RunResult 字段对齐
LastReActOutput string // 最终大模型的输出 LastAgentTraceOutput string // 最终助手输出文本
} }
// ProgressCallback 进度回调函数类型 // ProgressCallback 进度回调函数类型
@@ -471,7 +471,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
} }
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入 // 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
var currentReActInput string var currentAgentTraceInput string
maxIterations := a.maxIterations maxIterations := a.maxIterations
thinkingStreamSeq := 0 thinkingStreamSeq := 0
@@ -490,9 +490,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if err != nil { if err != nil {
a.logger.Warn("序列化ReAct输入失败", zap.Error(err)) a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
} else { } else {
currentReActInput = string(messagesJSON) currentAgentTraceInput = string(messagesJSON)
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的) // 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
} }
// 检查上下文是否已取消 // 检查上下文是否已取消
@@ -500,13 +500,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
case <-ctx.Done(): case <-ctx.Done():
// 上下文被取消(可能是用户主动暂停或其他原因) // 上下文被取消(可能是用户主动暂停或其他原因)
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err())) a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
if ctx.Err() == context.Canceled { if ctx.Err() == context.Canceled {
result.Response = "任务已被取消。" result.Response = "任务已被取消。"
} else { } else {
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err()) result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
} }
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, ctx.Err() return result, ctx.Err()
default: default:
} }
@@ -600,10 +600,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if err != nil { if err != nil {
// API调用失败,保存当前的ReAct输入和错误信息作为输出 // API调用失败,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err) errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err)) a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
return result, fmt.Errorf("调用OpenAI失败: %w", err) return result, fmt.Errorf("调用OpenAI失败: %w", err)
} }
@@ -629,19 +629,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
continue continue
} }
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出 // OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message) errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
} }
if len(response.Choices) == 0 { if len(response.Choices) == 0 {
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出 // 没有收到响应,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := "没有收到响应" errorMsg := "没有收到响应"
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("没有收到响应") return result, fmt.Errorf("没有收到响应")
} }
@@ -816,7 +816,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
@@ -863,14 +863,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果获取总结失败,使用当前回复作为结果 // 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" { if choice.Message.Content != "" {
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
// 如果都没有内容,跳出循环,让后续逻辑处理 // 如果都没有内容,跳出循环,让后续逻辑处理
@@ -881,7 +881,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if choice.FinishReason == "stop" { if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil) sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
} }
@@ -910,14 +910,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果无法生成总结,返回友好的提示 // 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
-1
View File
@@ -283,4 +283,3 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold) t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
} }
} }
+84 -92
View File
@@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
defer h.hitlManager.DeactivateConversation(conversationID) defer h.hitlManager.DeactivateConversation(conversationID)
} }
// 优先尝试从保存的ReAct数据恢复历史上下文 // 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表 // 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID) historyMessages, err := h.db.GetMessages(conversationID)
if err != nil { if err != nil {
@@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
} }
} else { } else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
} }
// 校验附件数量(非流式) // 校验附件数量(非流式)
@@ -613,12 +613,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
if err != nil { if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err)) h.logger.Error("Agent Loop执行失败", zap.Error(err))
// 即使执行失败,也尝试保存ReAct数据(如果result中有) // 即使执行失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil { if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr)) h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr))
} else { } else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -634,12 +634,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
// 因为AI已经生成了回复,用户应该能看到 // 因为AI已经生成了回复,用户应该能看到
} }
// 保存最后一轮ReAct的输入和输出 // 保存最后一轮代理轨迹与助手输出
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -666,7 +666,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
} }
} }
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID) historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil { if getErr != nil {
@@ -722,6 +722,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
"deep", "deep",
) )
if errMA != nil { if errMA != nil {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
errMsg := "执行失败: " + errMA.Error() errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
@@ -747,8 +748,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
} }
} }
if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" { if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput) _ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
} }
return resultMA.Response, conversationID, nil return resultMA.Response, conversationID, nil
} }
@@ -782,8 +783,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
} }
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput) _ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
} }
return result.Response, conversationID, nil return result.Response, conversationID, nil
} }
@@ -1359,10 +1360,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
} }
ssePublishConversationID = conversationID ssePublishConversationID = conversationID
// 优先尝试从保存的ReAct数据恢复历史上下文 // 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表 // 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID) historyMessages, err := h.db.GetMessages(conversationID)
if err != nil { if err != nil {
@@ -1380,7 +1381,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
} }
} else { } else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
} }
// 校验附件数量 // 校验附件数量
@@ -1579,12 +1580,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
} }
// 即使任务被取消,也尝试保存ReAct数据(如果result中有) // 即使任务被取消,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1614,12 +1615,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
} }
// 即使任务超时,也尝试保存ReAct数据(如果result中有) // 即使任务超时,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1649,12 +1650,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
} }
// 即使任务失败,也尝试保存ReAct数据(如果result中有) // 即使任务失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1694,12 +1695,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
} }
} }
// 保存最后一轮ReAct的输入和输出 // 保存最后一轮代理轨迹与助手输出
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -2499,6 +2500,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
cancel() cancel()
if runErr != nil { if runErr != nil {
if useRunResult {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
// 检查是否是取消错误 // 检查是否是取消错误
// 1. 直接检查是否是 context.Canceled(包括包装后的错误) // 1. 直接检查是否是 context.Canceled(包括包装后的错误)
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字 // 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
@@ -2542,14 +2546,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
} }
} }
// 保存ReAct数据(如果存在) // 保存代理轨迹(如果存在)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} }
} else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") { } else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} }
} }
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID) h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
@@ -2581,13 +2585,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
if useRunResult { if useRunResult {
resText = resultMA.Response resText = resultMA.Response
mcpIDs = resultMA.MCPExecutionIDs mcpIDs = resultMA.MCPExecutionIDs
lastIn = resultMA.LastReActInput lastIn = resultMA.LastAgentTraceInput
lastOut = resultMA.LastReActOutput lastOut = resultMA.LastAgentTraceOutput
} else { } else {
resText = result.Response resText = result.Response
mcpIDs = result.MCPExecutionIDs mcpIDs = result.MCPExecutionIDs
lastIn = result.LastReActInput lastIn = result.LastAgentTraceInput
lastOut = result.LastReActOutput lastOut = result.LastAgentTraceOutput
} }
// 更新助手消息内容 // 更新助手消息内容
@@ -2618,12 +2622,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
} }
} }
// 保存ReAct数据 // 保存代理轨迹
if lastIn != "" || lastOut != "" { if lastIn != "" || lastOut != "" {
if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil { if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
} }
} }
@@ -2642,36 +2646,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
} }
} }
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 // loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退消息表 // 逻辑与攻击链一致:优先用保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
// 获取保存的ReAct输入和输出 traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID)
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取ReAct数据失败: %w", err) return nil, fmt.Errorf("获取代理轨迹失败: %w", err)
} }
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致) if traceInputJSON == "" {
if reactInputJSON == "" { return nil, fmt.Errorf("代理轨迹为空,将使用消息表")
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
} }
dataSource := "database_last_react_input" dataSource := "database_last_agent_trace"
// 解析JSON格式的messages数组
var messagesArray []map[string]interface{} var messagesArray []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil { if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err) return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err)
} }
messageCount := len(messagesArray) messageCount := len(messagesArray)
h.logger.Info("使用保存的ReAct数据恢复历史上下文", h.logger.Info("使用保存的代理轨迹恢复历史上下文",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)), zap.Int("traceInputSize", len(traceInputJSON)),
zap.Int("messageCount", messageCount), zap.Int("messageCount", messageCount),
zap.Int("reactOutputSize", len(reactOutput)), zap.Int("assistantOutSize", len(assistantOut)),
) )
// fmt.Println("messagesArray:", messagesArray)//debug // fmt.Println("messagesArray:", messagesArray)//debug
@@ -2755,53 +2756,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
agentMessages = append(agentMessages, msg) agentMessages = append(agentMessages, msg)
} }
// 如果存在last_react_output,需要将其作为最后一条assistant消息 // 存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致)
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出 if assistantOut != "" {
if reactOutput != "" {
// 检查最后一条消息是否是assistant消息且没有tool_calls
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
if len(agentMessages) > 0 { if len(agentMessages) > 0 {
lastMsg := &agentMessages[len(agentMessages)-1] lastMsg := &agentMessages[len(agentMessages)-1]
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 { if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content lastMsg.Content = assistantOut
lastMsg.Content = reactOutput
} else { } else {
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
agentMessages = append(agentMessages, agent.ChatMessage{ agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant", Role: "assistant",
Content: reactOutput, Content: assistantOut,
}) })
} }
} else { } else {
// 如果没有消息,直接添加最终输出
agentMessages = append(agentMessages, agent.ChatMessage{ agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant", Role: "assistant",
Content: reactOutput, Content: assistantOut,
}) })
} }
} }
if len(agentMessages) == 0 { if len(agentMessages) == 0 {
return nil, fmt.Errorf("从ReAct数据解析的消息为空") return nil, fmt.Errorf("从代理轨迹解析的消息为空")
} }
// 修复可能存在的失配tool消息,避免OpenAI报错
// 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误
if h.agent != nil { if h.agent != nil {
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed { if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息", h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
) )
} }
} }
h.logger.Info("从ReAct数据恢复历史消息完成", h.logger.Info("从代理轨迹恢复历史消息完成",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("originalMessageCount", messageCount), zap.Int("originalMessageCount", messageCount),
zap.Int("finalMessageCount", len(agentMessages)), zap.Int("finalMessageCount", len(agentMessages)),
zap.Bool("hasReactOutput", reactOutput != ""), zap.Bool("hasAssistantOut", assistantOut != ""),
) )
fmt.Println("agentMessages:", agentMessages) //debug
return agentMessages, nil return agentMessages, nil
} }
-1
View File
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
c.JSON(http.StatusOK, chain) c.JSON(http.StatusOK, chain)
} }
-1
View File
@@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
"message": "ok", "message": "ok",
}) })
} }
+7 -5
View File
@@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
@@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
progressCallback, progressCallback,
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
return return
} }
@@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.AssistantMessageID, prep.AssistantMessageID,
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput) _ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
-3
View File
@@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
return stats return stats
} }
// GetExecution 获取特定执行记录 // GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) { func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
@@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
} }
+21 -6
View File
@@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
@@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
strings.TrimSpace(req.Orchestration), strings.TrimSpace(req.Orchestration),
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error() errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" { if prep.AssistantMessageID != "" {
@@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
}) })
} }
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
if h == nil || result == nil {
return
}
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
return
}
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
func multiAgentHTTPErrorStatus(err error) (int, string) { func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error() msg := err.Error()
switch { switch {
+1 -1
View File
@@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
} }
} }
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID) historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil { if getErr != nil {
-1
View File
@@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
<-doneChan <-doneChan
} }
-1
View File
@@ -460,4 +460,3 @@ func sanitizeExportName(raw string) string {
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-") replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
return replacer.Replace(name) return replacer.Replace(name)
} }
+1 -1
View File
@@ -8,8 +8,8 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
) )
+1 -1
View File
@@ -11,9 +11,9 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"go.uber.org/zap" "go.uber.org/zap"
) )
-1
View File
@@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines)) t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
} }
} }