diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 171bf82f..e415d526 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -292,6 +292,8 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error { type AgentLoopResult struct { Response string MCPExecutionIDs []string + LastReActInput string // 最后一轮ReAct的输入(压缩前的完整messages) + LastReActOutput string // 最终大模型的输出 } // ProgressCallback 进度回调函数类型 @@ -436,14 +438,48 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his MCPExecutionIDs: make([]string, 0), } + // 用于保存当前的messages,以便在异常情况下也能保存ReAct输入 + var currentReActInput string + maxIterations := a.maxIterations for i := 0; i < maxIterations; i++ { + // 在压缩前保存messages,这样即使出现异常也能保存原始数据 + messagesBeforeCompression := make([]ChatMessage, len(messages)) + copy(messagesBeforeCompression, messages) + // 每轮调用前先尝试压缩,防止历史消息持续膨胀 messages = a.applyMemoryCompression(ctx, messages) // 检查是否是最后一次迭代 isLastIteration := (i == maxIterations-1) + // 每次迭代都保存压缩前的messages,以便在异常中断(取消、错误等)时也能保存最新的ReAct输入 + // 这样无论何时中断,都能保存当前的上下文状态 + messagesJSON, err := json.Marshal(messagesBeforeCompression) + if err != nil { + a.logger.Warn("序列化ReAct输入失败", zap.Error(err)) + } else { + currentReActInput = string(messagesJSON) + // 更新result中的值,确保始终保存最新的ReAct输入 + result.LastReActInput = currentReActInput + } + + // 检查上下文是否已取消 + select { + case <-ctx.Done(): + // 上下文被取消(可能是用户主动暂停或其他原因) + a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err())) + result.LastReActInput = currentReActInput + if ctx.Err() == context.Canceled { + result.Response = "任务已被取消。" + } else { + result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err()) + } + result.LastReActOutput = result.Response + return result, ctx.Err() + default: + } + // 获取可用工具 tools := a.getAvailableTools() @@ -511,7 +547,12 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his sendProgress("progress", "正在调用AI模型...", nil) response, err := a.callOpenAI(ctx, messages, tools) if err != nil { - result.Response = "" + // API调用失败,保存当前的ReAct输入和错误信息作为输出 + result.LastReActInput = currentReActInput + errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err) + result.Response = errorMsg + result.LastReActOutput = errorMsg + a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err)) return result, fmt.Errorf("调用OpenAI失败: %w", err) } @@ -535,12 +576,20 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his ) continue } - result.Response = "" + // OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出 + result.LastReActInput = currentReActInput + errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message) + result.Response = errorMsg + result.LastReActOutput = errorMsg return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) } if len(response.Choices) == 0 { - result.Response = "" + // 没有收到响应,保存当前的ReAct输入和错误信息作为输出 + result.LastReActInput = currentReActInput + errorMsg := "没有收到响应" + result.Response = errorMsg + result.LastReActOutput = errorMsg return result, fmt.Errorf("没有收到响应") } @@ -658,18 +707,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", }) messages = a.applyMemoryCompression(ctx, messages) - // 立即调用OpenAI获取总结 - summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 - if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { - summaryChoice := summaryResponse.Choices[0] - if summaryChoice.Message.Content != "" { - result.Response = summaryChoice.Message.Content - sendProgress("progress", "总结生成完成", nil) - return result, nil - } + // 立即调用OpenAI获取总结 + summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 + if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { + summaryChoice := summaryResponse.Choices[0] + if summaryChoice.Message.Content != "" { + result.Response = summaryChoice.Message.Content + result.LastReActOutput = result.Response + sendProgress("progress", "总结生成完成", nil) + return result, nil } - // 如果获取总结失败,跳出循环,让后续逻辑处理 - break + } + // 如果获取总结失败,跳出循环,让后续逻辑处理 + break } continue @@ -703,6 +753,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his summaryChoice := summaryResponse.Choices[0] if summaryChoice.Message.Content != "" { result.Response = summaryChoice.Message.Content + result.LastReActOutput = result.Response sendProgress("progress", "总结生成完成", nil) return result, nil } @@ -710,6 +761,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 如果获取总结失败,使用当前回复作为结果 if choice.Message.Content != "" { result.Response = choice.Message.Content + result.LastReActOutput = result.Response return result, nil } // 如果都没有内容,跳出循环,让后续逻辑处理 @@ -720,6 +772,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his if choice.FinishReason == "stop" { sendProgress("progress", "正在生成最终回复...", nil) result.Response = choice.Message.Content + result.LastReActOutput = result.Response return result, nil } } @@ -739,6 +792,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his summaryChoice := summaryResponse.Choices[0] if summaryChoice.Message.Content != "" { result.Response = summaryChoice.Message.Content + result.LastReActOutput = result.Response sendProgress("progress", "总结生成完成", nil) return result, nil } @@ -746,6 +800,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 如果无法生成总结,返回友好的提示 result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) + result.LastReActOutput = result.Response return result, nil } diff --git a/internal/attackchain/builder.go b/internal/attackchain/builder.go index 25b25627..7d61988c 100644 --- a/internal/attackchain/builder.go +++ b/internal/attackchain/builder.go @@ -2,6 +2,8 @@ package attackchain import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -85,7 +87,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap. func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) { b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID)) - // 1. 获取对话消息 + // 0. 首先检查是否有实际的工具执行记录 messages, err := b.db.GetMessages(conversationID) if err != nil { return nil, fmt.Errorf("获取对话消息失败: %w", err) @@ -96,31 +98,115 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil } - // 2. 提取用户输入(最后一条user消息) - var userInput string - for i := len(messages) - 1; i >= 0; i-- { - if strings.EqualFold(messages[i].Role, "user") { - userInput = messages[i].Content - break - } - } - - // 3. 提取最后一轮ReAct的输入(历史消息+当前用户输入) - // 最后一轮ReAct的输入 = 所有历史消息(包括当前用户输入) - reactInput := b.buildReActInput(messages) - - // 4. 提取大模型最后的输出(最后一条assistant消息) - var modelOutput string + // 检查是否有实际的工具执行(通过检查assistant消息的mcp_execution_ids) + hasToolExecutions := false for i := len(messages) - 1; i >= 0; i-- { if strings.EqualFold(messages[i].Role, "assistant") { - modelOutput = messages[i].Content + if len(messages[i].MCPExecutionIDs) > 0 { + hasToolExecutions = true + break + } + } + } + + // 检查任务是否被取消(通过检查最后一条assistant消息内容或process_details) + taskCancelled := false + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "assistant") { + content := strings.ToLower(messages[i].Content) + if strings.Contains(content, "取消") || strings.Contains(content, "cancelled") { + taskCancelled = true + } break } } - // 5. 构建简化的prompt,一次性传递给大模型 - prompt := b.buildSimplePrompt(userInput, reactInput, modelOutput) + // 如果任务被取消且没有实际工具执行,返回空攻击链 + if taskCancelled && !hasToolExecutions { + b.logger.Info("任务已取消且没有实际工具执行,返回空攻击链", + zap.String("conversationId", conversationID), + zap.Bool("taskCancelled", taskCancelled), + zap.Bool("hasToolExecutions", hasToolExecutions)) + return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil + } + // 如果没有实际工具执行,也返回空攻击链(避免AI编造) + if !hasToolExecutions { + b.logger.Info("没有实际工具执行记录,返回空攻击链", + zap.String("conversationId", conversationID)) + return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil + } + + // 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出 + reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID) + if err != nil { + b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err)) + // 继续使用原来的逻辑 + reactInputJSON = "" + modelOutput = "" + } + + var userInput string + var reactInputFinal string + var dataSource string // 记录数据来源 + + // 如果成功获取到保存的ReAct数据,直接使用 + if reactInputJSON != "" && modelOutput != "" { + // 计算 ReAct 输入的哈希值,用于追踪 + hash := sha256.Sum256([]byte(reactInputJSON)) + reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识 + + // 统计消息数量 + var messageCount int + var tempMessages []interface{} + if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil { + messageCount = len(tempMessages) + } + + dataSource = "database_last_react_input" + b.logger.Info("使用保存的ReAct数据构建攻击链", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("reactInputSize", len(reactInputJSON)), + zap.Int("messageCount", messageCount), + zap.String("reactInputHash", reactInputHash), + zap.Int("modelOutputSize", len(modelOutput))) + + // 从保存的ReAct输入(JSON格式)中提取用户输入 + userInput = b.extractUserInputFromReActInput(reactInputJSON) + + // 将JSON格式的messages转换为可读格式 + reactInputFinal = b.formatReActInputFromJSON(reactInputJSON) + } else { + // 2. 如果没有保存的ReAct数据,从对话消息构建 + dataSource = "messages_table" + b.logger.Info("从消息历史构建ReAct数据", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("messageCount", len(messages))) + + // 提取用户输入(最后一条user消息) + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "user") { + userInput = messages[i].Content + break + } + } + + // 提取最后一轮ReAct的输入(历史消息+当前用户输入) + reactInputFinal = b.buildReActInput(messages) + + // 提取大模型最后的输出(最后一条assistant消息) + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "assistant") { + modelOutput = messages[i].Content + break + } + } + } + + // 3. 构建简化的prompt,一次性传递给大模型 + prompt := b.buildSimplePrompt(userInput, reactInputFinal, modelOutput) // 6. 调用AI生成攻击链(一次性,不做任何处理) chainJSON, err := b.callAIForChainGeneration(ctx, prompt) if err != nil { @@ -140,6 +226,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID b.logger.Info("攻击链构建完成", zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), zap.Int("nodes", len(chainData.Nodes)), zap.Int("edges", len(chainData.Edges))) @@ -162,10 +249,67 @@ func (b *Builder) buildReActInput(messages []database.Message) string { return builder.String() } +// extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入 +func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string { + // reactInputJSON是JSON格式的ChatMessage数组,需要解析 + var messages []map[string]interface{} + if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { + b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) + return "" + } + + // 从后往前查找最后一条user消息 + for i := len(messages) - 1; i >= 0; i-- { + if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") { + if content, ok := messages[i]["content"].(string); ok { + return content + } + } + } + + return "" +} + +// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式 +func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string { + var messages []map[string]interface{} + if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { + b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) + return reactInputJSON // 如果解析失败,返回原始JSON + } + + var builder strings.Builder + for _, msg := range messages { + role, _ := msg["role"].(string) + content, _ := msg["content"].(string) + + // 如果content为空但存在tool_calls,标记为工具调用消息 + if content == "" { + if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { + content = fmt.Sprintf("[工具调用: %d个]", len(toolCalls)) + } + } + + builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) + } + + return builder.String() +} + // buildSimplePrompt 构建简化的prompt func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) string { return fmt.Sprintf(`你是一个专业的安全测试分析师。请根据以下信息生成攻击链图。 +## ⚠️ 重要原则 - 严禁杜撰 + +**严格禁止编造或推测任何内容!** 你必须: +1. **只使用实际发生的信息**:仅基于ReAct输入中实际执行的工具调用和实际返回的结果 +2. **不要推测**:如果没有实际执行工具或发现漏洞,不要编造 +3. **不要假设**:不能仅根据URL、目标名称等推断漏洞类型 +4. **基于事实**:每个节点和边都必须有实际依据,来自工具执行结果或模型的实际输出 + +如果ReAct输入中没有实际的工具执行记录,或者模型输出中明确表示任务未完成/被取消,必须返回空的攻击链(空的nodes和edges数组)。 + ## 用户输入 %s @@ -177,10 +321,15 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s ## 任务要求 -请根据上述信息,生成一个清晰的攻击链图。攻击链应该包含: -1. **target(目标)**:从用户输入中提取的测试目标 -2. **action(行动)**:从ReAct输入和模型输出中提取的关键测试步骤 -3. **vulnerability(漏洞)**:从模型输出中提取的发现的漏洞 +请根据上述信息,**仅基于实际执行的数据**生成一个清晰的攻击链图。攻击链应该包含: +1. **target(目标)**:从用户输入中提取的实际测试目标(必须是用户明确提供的) +2. **action(行动)**:从ReAct输入中提取的**实际执行的**工具调用和测试步骤(必须有tool_calls证据) +3. **vulnerability(漏洞)**:从模型输出中提取的**实际发现的**漏洞(必须在输出中明确提及,不能推测) + +**关键检查点:** +- 如果ReAct输入中没有tool_calls,说明没有实际执行工具 → 只能生成target节点 +- 如果模型输出中没有明确提到发现的漏洞,不要编造vulnerability节点 +- 如果任务被取消或未完成,返回空攻击链 ## 输出格式 @@ -194,8 +343,8 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s "risk_score": 0-100, "metadata": { "target": "目标(target节点)", - "tool_name": "工具名称(action节点)", - "description": "描述(vulnerability节点)" + "tool_name": "工具名称(action节点,必须是实际调用的工具)", + "description": "描述(vulnerability节点,必须是实际发现的漏洞)" } } ], @@ -209,6 +358,8 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s ] } +**再次强调:如果没有实际数据,返回空的nodes和edges数组。严禁杜撰!** + 只返回JSON,不要包含其他解释文字。`, userInput, reactInput, modelOutput) } diff --git a/internal/database/conversation.go b/internal/database/conversation.go index e15785a0..af17fe5f 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -205,6 +205,42 @@ func (db *DB) DeleteConversation(id string) error { return nil } +// SaveReActData 保存最后一轮ReAct的输入和输出 +func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error { + _, err := db.Exec( + "UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?", + reactInput, reactOutput, time.Now(), conversationID, + ) + if err != nil { + return fmt.Errorf("保存ReAct数据失败: %w", err) + } + return nil +} + +// GetReActData 获取最后一轮ReAct的输入和输出 +func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) { + var input, output sql.NullString + err = db.QueryRow( + "SELECT last_react_input, last_react_output FROM conversations WHERE id = ?", + conversationID, + ).Scan(&input, &output) + if err != nil { + if err == sql.ErrNoRows { + return "", "", fmt.Errorf("对话不存在") + } + return "", "", fmt.Errorf("获取ReAct数据失败: %w", err) + } + + if input.Valid { + reactInput = input.String + } + if output.Valid { + reactOutput = output.String + } + + return reactInput, reactOutput, nil +} + // AddMessage 添加消息 func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) { id := uuid.New().String() diff --git a/internal/database/database.go b/internal/database/database.go index b0221e22..514f7e3b 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "strings" _ "github.com/mattn/go-sqlite3" "go.uber.org/zap" @@ -46,7 +47,9 @@ func (db *DB) initTables() error { id TEXT PRIMARY KEY, title TEXT NOT NULL, created_at DATETIME NOT NULL, - updated_at DATETIME NOT NULL + updated_at DATETIME NOT NULL, + last_react_input TEXT, + last_react_output TEXT );` // 创建消息表 @@ -199,10 +202,58 @@ func (db *DB) initTables() error { return fmt.Errorf("创建索引失败: %w", err) } + // 为已有表添加新字段(如果不存在) + if err := db.migrateConversationsTable(); err != nil { + db.logger.Warn("迁移conversations表失败", zap.Error(err)) + // 不返回错误,允许继续运行 + } + db.logger.Info("数据库表初始化完成") return nil } +// migrateConversationsTable 迁移conversations表,添加新字段 +func (db *DB) migrateConversationsTable() error { + // 检查last_react_input字段是否存在 + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_input'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); addErr != nil { + // 如果字段已存在,忽略错误(SQLite错误信息可能不同) + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_react_input字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); err != nil { + db.logger.Warn("添加last_react_input字段失败", zap.Error(err)) + } + } + + // 检查last_react_output字段是否存在 + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_output'").Scan(&count) + if err != nil { + // 如果查询失败,尝试添加字段 + if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); addErr != nil { + // 如果字段已存在,忽略错误 + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加last_react_output字段失败", zap.Error(addErr)) + } + } + } else if count == 0 { + // 字段不存在,添加它 + if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); err != nil { + db.logger.Warn("添加last_react_output字段失败", zap.Error(err)) + } + } + + return nil +} + // NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") diff --git a/internal/handler/agent.go b/internal/handler/agent.go index d43b31af..a21b89e8 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -87,41 +87,30 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { conversationID = conv.ID } - // 获取历史消息(排除当前消息,因为还没保存) - historyMessages, err := h.db.GetMessages(conversationID) + // 优先尝试从保存的ReAct数据恢复历史上下文 + agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) if err != nil { - h.logger.Warn("获取历史消息失败", zap.Error(err)) - historyMessages = []database.Message{} - } - - h.logger.Info("获取历史消息", - zap.String("conversationId", conversationID), - zap.Int("count", len(historyMessages)), - ) - - // 将数据库消息转换为Agent消息格式 - agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages)) - for i, msg := range historyMessages { - agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ - Role: msg.Role, - Content: msg.Content, - }) - contentPreview := msg.Content - if len(contentPreview) > 50 { - contentPreview = contentPreview[:50] + "..." + h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) + // 回退到使用数据库消息表 + historyMessages, err := h.db.GetMessages(conversationID) + if err != nil { + h.logger.Warn("获取历史消息失败", zap.Error(err)) + agentHistoryMessages = []agent.ChatMessage{} + } else { + // 将数据库消息转换为Agent消息格式 + agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) + for _, msg := range historyMessages { + agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ + Role: msg.Role, + Content: msg.Content, + }) + } + h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) } - h.logger.Info("添加历史消息", - zap.Int("index", i), - zap.String("role", msg.Role), - zap.String("content", contentPreview), - ) + } else { + h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) } - h.logger.Info("历史消息转换完成", - zap.Int("originalCount", len(historyMessages)), - zap.Int("convertedCount", len(agentHistoryMessages)), - ) - // 保存用户消息 _, err = h.db.AddMessage(conversationID, "user", req.Message, nil) if err != nil { @@ -132,6 +121,16 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages) if err != nil { h.logger.Error("Agent Loop执行失败", zap.Error(err)) + + // 即使执行失败,也尝试保存ReAct数据(如果result中有) + if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { + if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil { + h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr)) + } else { + h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) + } + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -142,6 +141,15 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { h.logger.Error("保存助手消息失败", zap.Error(err)) } + // 保存最后一轮ReAct的输入和输出 + if result.LastReActInput != "" || result.LastReActOutput != "" { + if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { + h.logger.Warn("保存ReAct数据失败", zap.Error(err)) + } else { + h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) + } + } + c.JSON(http.StatusOK, ChatResponse{ Response: result.Response, MCPExecutionIDs: result.MCPExecutionIDs, @@ -246,20 +254,28 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { "conversationId": conversationID, }) - // 获取历史消息 - historyMessages, err := h.db.GetMessages(conversationID) + // 优先尝试从保存的ReAct数据恢复历史上下文 + agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) if err != nil { - h.logger.Warn("获取历史消息失败", zap.Error(err)) - historyMessages = []database.Message{} - } - - // 将数据库消息转换为Agent消息格式 - agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages)) - for _, msg := range historyMessages { - agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ - Role: msg.Role, - Content: msg.Content, - }) + h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) + // 回退到使用数据库消息表 + historyMessages, err := h.db.GetMessages(conversationID) + if err != nil { + h.logger.Warn("获取历史消息失败", zap.Error(err)) + agentHistoryMessages = []agent.ChatMessage{} + } else { + // 将数据库消息转换为Agent消息格式 + agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) + for _, msg := range historyMessages { + agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ + Role: msg.Role, + Content: msg.Content, + }) + } + h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) + } + } else { + h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) } // 保存用户消息 @@ -499,6 +515,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) } + + // 即使任务被取消,也尝试保存ReAct数据(如果result中有) + if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { + if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { + h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err)) + } else { + h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID)) + } + } + sendEvent("cancelled", cancelMsg, map[string]interface{}{ "conversationId": conversationID, "messageId": assistantMessageID, @@ -524,6 +550,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) } + + // 即使任务超时,也尝试保存ReAct数据(如果result中有) + if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { + if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { + h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err)) + } else { + h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID)) + } + } + sendEvent("error", timeoutMsg, map[string]interface{}{ "conversationId": conversationID, "messageId": assistantMessageID, @@ -549,6 +585,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil) } + + // 即使任务失败,也尝试保存ReAct数据(如果result中有) + if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { + if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { + h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err)) + } else { + h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) + } + } + sendEvent("error", errorMsg, map[string]interface{}{ "conversationId": conversationID, "messageId": assistantMessageID, @@ -585,6 +631,15 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { } } + // 保存最后一轮ReAct的输入和输出 + if result.LastReActInput != "" || result.LastReActOutput != "" { + if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { + h.logger.Warn("保存ReAct数据失败", zap.Error(err)) + } else { + h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) + } + } + // 发送最终响应 sendEvent("response", result.Response, map[string]interface{}{ "mcpExecutionIds": result.MCPExecutionIDs, @@ -632,3 +687,140 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) { "tasks": h.tasks.GetActiveTasks(), }) } + +// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 +func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { + // 获取保存的ReAct输入和输出 + reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID) + if err != nil { + return nil, fmt.Errorf("获取ReAct数据失败: %w", err) + } + + if reactInputJSON == "" { + return nil, fmt.Errorf("ReAct数据为空") + } + + // 解析JSON格式的messages数组 + var messagesArray []map[string]interface{} + if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil { + return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err) + } + + // 转换为Agent消息格式 + agentMessages := make([]agent.ChatMessage, 0, len(messagesArray)) + for _, msgMap := range messagesArray { + msg := agent.ChatMessage{} + + // 解析role + if role, ok := msgMap["role"].(string); ok { + msg.Role = role + } else { + continue // 跳过无效消息 + } + + // 跳过system消息(AgentLoop会重新添加) + if msg.Role == "system" { + continue + } + + // 解析content + if content, ok := msgMap["content"].(string); ok { + msg.Content = content + } + + // 解析tool_calls(如果存在) + if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil { + if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok { + msg.ToolCalls = make([]agent.ToolCall, 0, len(toolCallsArray)) + for _, tcRaw := range toolCallsArray { + if tcMap, ok := tcRaw.(map[string]interface{}); ok { + toolCall := agent.ToolCall{} + + // 解析ID + if id, ok := tcMap["id"].(string); ok { + toolCall.ID = id + } + + // 解析Type + if toolType, ok := tcMap["type"].(string); ok { + toolCall.Type = toolType + } + + // 解析Function + if funcMap, ok := tcMap["function"].(map[string]interface{}); ok { + toolCall.Function = agent.FunctionCall{} + + // 解析函数名 + if name, ok := funcMap["name"].(string); ok { + toolCall.Function.Name = name + } + + // 解析arguments(可能是字符串或对象) + if argsRaw, ok := funcMap["arguments"]; ok { + if argsStr, ok := argsRaw.(string); ok { + // 如果是字符串,解析为JSON + var argsMap map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil { + toolCall.Function.Arguments = argsMap + } + } else if argsMap, ok := argsRaw.(map[string]interface{}); ok { + // 如果已经是对象,直接使用 + toolCall.Function.Arguments = argsMap + } + } + } + + if toolCall.ID != "" { + msg.ToolCalls = append(msg.ToolCalls, toolCall) + } + } + } + } + } + + // 解析tool_call_id(tool角色消息) + if toolCallID, ok := msgMap["tool_call_id"].(string); ok { + msg.ToolCallID = toolCallID + } + + agentMessages = append(agentMessages, msg) + } + + // 如果存在last_react_output,需要将其作为最后一条assistant消息 + // 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出 + if reactOutput != "" { + // 检查最后一条消息是否是assistant消息且没有tool_calls + // 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复 + if len(agentMessages) > 0 { + lastMsg := &agentMessages[len(agentMessages)-1] + if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 { + // 最后一条是assistant消息且没有tool_calls,用最终输出更新其content + lastMsg.Content = reactOutput + } else { + // 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息 + agentMessages = append(agentMessages, agent.ChatMessage{ + Role: "assistant", + Content: reactOutput, + }) + } + } else { + // 如果没有消息,直接添加最终输出 + agentMessages = append(agentMessages, agent.ChatMessage{ + Role: "assistant", + Content: reactOutput, + }) + } + } + + if len(agentMessages) == 0 { + return nil, fmt.Errorf("从ReAct数据解析的消息为空") + } + + h.logger.Info("从ReAct数据恢复历史消息", + zap.String("conversationId", conversationID), + zap.Int("messageCount", len(agentMessages)), + zap.Bool("hasReactOutput", reactOutput != ""), + ) + + return agentMessages, nil +}