Add files via upload

This commit is contained in:
公明
2025-12-24 05:08:26 +08:00
committed by GitHub
parent 0fe6148284
commit 2df9c21d80
5 changed files with 569 additions and 84 deletions

View File

@@ -292,6 +292,8 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
type AgentLoopResult struct {
Response string
MCPExecutionIDs []string
LastReActInput string // 最后一轮ReAct的输入压缩前的完整messages
LastReActOutput string // 最终大模型的输出
}
// ProgressCallback 进度回调函数类型
@@ -436,14 +438,48 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
MCPExecutionIDs: make([]string, 0),
}
// 用于保存当前的messages以便在异常情况下也能保存ReAct输入
var currentReActInput string
maxIterations := a.maxIterations
for i := 0; i < maxIterations; i++ {
// 在压缩前保存messages这样即使出现异常也能保存原始数据
messagesBeforeCompression := make([]ChatMessage, len(messages))
copy(messagesBeforeCompression, messages)
// 每轮调用前先尝试压缩,防止历史消息持续膨胀
messages = a.applyMemoryCompression(ctx, messages)
// 检查是否是最后一次迭代
isLastIteration := (i == maxIterations-1)
// 每次迭代都保存压缩前的messages以便在异常中断取消、错误等时也能保存最新的ReAct输入
// 这样无论何时中断,都能保存当前的上下文状态
messagesJSON, err := json.Marshal(messagesBeforeCompression)
if err != nil {
a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
} else {
currentReActInput = string(messagesJSON)
// 更新result中的值确保始终保存最新的ReAct输入
result.LastReActInput = currentReActInput
}
// 检查上下文是否已取消
select {
case <-ctx.Done():
// 上下文被取消(可能是用户主动暂停或其他原因)
a.logger.Info("检测到上下文取消保存当前ReAct数据", zap.Error(ctx.Err()))
result.LastReActInput = currentReActInput
if ctx.Err() == context.Canceled {
result.Response = "任务已被取消。"
} else {
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
}
result.LastReActOutput = result.Response
return result, ctx.Err()
default:
}
// 获取可用工具
tools := a.getAvailableTools()
@@ -511,7 +547,12 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
sendProgress("progress", "正在调用AI模型...", nil)
response, err := a.callOpenAI(ctx, messages, tools)
if err != nil {
result.Response = ""
// API调用失败保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
result.Response = errorMsg
result.LastReActOutput = errorMsg
a.logger.Warn("OpenAI调用失败已保存ReAct数据", zap.Error(err))
return result, fmt.Errorf("调用OpenAI失败: %w", err)
}
@@ -535,12 +576,20 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
)
continue
}
result.Response = ""
// OpenAI返回错误保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
result.Response = errorMsg
result.LastReActOutput = errorMsg
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
}
if len(response.Choices) == 0 {
result.Response = ""
// 没有收到响应保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
errorMsg := "没有收到响应"
result.Response = errorMsg
result.LastReActOutput = errorMsg
return result, fmt.Errorf("没有收到响应")
}
@@ -658,18 +707,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
})
messages = a.applyMemoryCompression(ctx, messages)
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
summaryChoice := summaryResponse.Choices[0]
if summaryChoice.Message.Content != "" {
result.Response = summaryChoice.Message.Content
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
summaryChoice := summaryResponse.Choices[0]
if summaryChoice.Message.Content != "" {
result.Response = summaryChoice.Message.Content
result.LastReActOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
// 如果获取总结失败,跳出循环,让后续逻辑处理
break
}
// 如果获取总结失败,跳出循环,让后续逻辑处理
break
}
continue
@@ -703,6 +753,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
summaryChoice := summaryResponse.Choices[0]
if summaryChoice.Message.Content != "" {
result.Response = summaryChoice.Message.Content
result.LastReActOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
@@ -710,6 +761,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" {
result.Response = choice.Message.Content
result.LastReActOutput = result.Response
return result, nil
}
// 如果都没有内容,跳出循环,让后续逻辑处理
@@ -720,6 +772,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content
result.LastReActOutput = result.Response
return result, nil
}
}
@@ -739,6 +792,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
summaryChoice := summaryResponse.Choices[0]
if summaryChoice.Message.Content != "" {
result.Response = summaryChoice.Message.Content
result.LastReActOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
@@ -746,6 +800,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮。系统已执行了多轮测试但由于达到迭代上限无法继续自动执行。建议您查看已执行的工具结果或提出新的测试请求以继续测试。", a.maxIterations)
result.LastReActOutput = result.Response
return result, nil
}

View File

@@ -2,6 +2,8 @@ package attackchain
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -85,7 +87,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
// 1. 获取对话消息
// 0. 首先检查是否有实际的工具执行记录
messages, err := b.db.GetMessages(conversationID)
if err != nil {
return nil, fmt.Errorf("获取对话消息失败: %w", err)
@@ -96,31 +98,115 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
}
// 2. 提取用户输入最后一条user消息
var userInput string
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "user") {
userInput = messages[i].Content
break
}
}
// 3. 提取最后一轮ReAct的输入历史消息+当前用户输入)
// 最后一轮ReAct的输入 = 所有历史消息(包括当前用户输入)
reactInput := b.buildReActInput(messages)
// 4. 提取大模型最后的输出最后一条assistant消息
var modelOutput string
// 检查是否有实际的工具执行通过检查assistant消息的mcp_execution_ids
hasToolExecutions := false
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "assistant") {
modelOutput = messages[i].Content
if len(messages[i].MCPExecutionIDs) > 0 {
hasToolExecutions = true
break
}
}
}
// 检查任务是否被取消通过检查最后一条assistant消息内容或process_details
taskCancelled := false
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "assistant") {
content := strings.ToLower(messages[i].Content)
if strings.Contains(content, "取消") || strings.Contains(content, "cancelled") {
taskCancelled = true
}
break
}
}
// 5. 构建简化的prompt一次性传递给大模型
prompt := b.buildSimplePrompt(userInput, reactInput, modelOutput)
// 如果任务被取消且没有实际工具执行,返回空攻击链
if taskCancelled && !hasToolExecutions {
b.logger.Info("任务已取消且没有实际工具执行,返回空攻击链",
zap.String("conversationId", conversationID),
zap.Bool("taskCancelled", taskCancelled),
zap.Bool("hasToolExecutions", hasToolExecutions))
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
}
// 如果没有实际工具执行也返回空攻击链避免AI编造
if !hasToolExecutions {
b.logger.Info("没有实际工具执行记录,返回空攻击链",
zap.String("conversationId", conversationID))
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
}
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID)
if err != nil {
b.logger.Warn("获取保存的ReAct数据失败将使用消息历史构建", zap.Error(err))
// 继续使用原来的逻辑
reactInputJSON = ""
modelOutput = ""
}
var userInput string
var reactInputFinal string
var dataSource string // 记录数据来源
// 如果成功获取到保存的ReAct数据直接使用
if reactInputJSON != "" && modelOutput != "" {
// 计算 ReAct 输入的哈希值,用于追踪
hash := sha256.Sum256([]byte(reactInputJSON))
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
// 统计消息数量
var messageCount int
var tempMessages []interface{}
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
messageCount = len(tempMessages)
}
dataSource = "database_last_react_input"
b.logger.Info("使用保存的ReAct数据构建攻击链",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)),
zap.Int("messageCount", messageCount),
zap.String("reactInputHash", reactInputHash),
zap.Int("modelOutputSize", len(modelOutput)))
// 从保存的ReAct输入JSON格式中提取用户输入
userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
} else {
// 2. 如果没有保存的ReAct数据从对话消息构建
dataSource = "messages_table"
b.logger.Info("从消息历史构建ReAct数据",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("messageCount", len(messages)))
// 提取用户输入最后一条user消息
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "user") {
userInput = messages[i].Content
break
}
}
// 提取最后一轮ReAct的输入历史消息+当前用户输入)
reactInputFinal = b.buildReActInput(messages)
// 提取大模型最后的输出最后一条assistant消息
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "assistant") {
modelOutput = messages[i].Content
break
}
}
}
// 3. 构建简化的prompt一次性传递给大模型
prompt := b.buildSimplePrompt(userInput, reactInputFinal, modelOutput)
// 6. 调用AI生成攻击链一次性不做任何处理
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
if err != nil {
@@ -140,6 +226,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
b.logger.Info("攻击链构建完成",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("nodes", len(chainData.Nodes)),
zap.Int("edges", len(chainData.Edges)))
@@ -162,10 +249,67 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
return builder.String()
}
// extractUserInputFromReActInput 从保存的ReAct输入JSON格式的messages数组中提取最后一条用户输入
func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
// reactInputJSON是JSON格式的ChatMessage数组需要解析
var messages []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
return ""
}
// 从后往前查找最后一条user消息
for i := len(messages) - 1; i >= 0; i-- {
if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
if content, ok := messages[i]["content"].(string); ok {
return content
}
}
}
return ""
}
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
var messages []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
return reactInputJSON // 如果解析失败返回原始JSON
}
var builder strings.Builder
for _, msg := range messages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
// 如果content为空但存在tool_calls标记为工具调用消息
if content == "" {
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
content = fmt.Sprintf("[工具调用: %d个]", len(toolCalls))
}
}
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
}
return builder.String()
}
// buildSimplePrompt 构建简化的prompt
func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) string {
return fmt.Sprintf(`你是一个专业的安全测试分析师。请根据以下信息生成攻击链图。
## ⚠️ 重要原则 - 严禁杜撰
**严格禁止编造或推测任何内容!** 你必须:
1. **只使用实际发生的信息**仅基于ReAct输入中实际执行的工具调用和实际返回的结果
2. **不要推测**:如果没有实际执行工具或发现漏洞,不要编造
3. **不要假设**不能仅根据URL、目标名称等推断漏洞类型
4. **基于事实**:每个节点和边都必须有实际依据,来自工具执行结果或模型的实际输出
如果ReAct输入中没有实际的工具执行记录或者模型输出中明确表示任务未完成/被取消必须返回空的攻击链空的nodes和edges数组
## 用户输入
%s
@@ -177,10 +321,15 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
## 任务要求
请根据上述信息,生成一个清晰的攻击链图。攻击链应该包含:
1. **target目标**:从用户输入中提取的测试目标
2. **action行动**从ReAct输入和模型输出中提取的关键测试步骤
3. **vulnerability漏洞**:从模型输出中提取的发现的漏洞
请根据上述信息,**仅基于实际执行的数据**生成一个清晰的攻击链图。攻击链应该包含:
1. **target目标**:从用户输入中提取的实际测试目标(必须是用户明确提供的)
2. **action行动**从ReAct输入中提取的**实际执行的**工具调用和测试步骤必须有tool_calls证据
3. **vulnerability漏洞**:从模型输出中提取的**实际发现的**漏洞(必须在输出中明确提及,不能推测)
**关键检查点:**
- 如果ReAct输入中没有tool_calls说明没有实际执行工具 → 只能生成target节点
- 如果模型输出中没有明确提到发现的漏洞不要编造vulnerability节点
- 如果任务被取消或未完成,返回空攻击链
## 输出格式
@@ -194,8 +343,8 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
"risk_score": 0-100,
"metadata": {
"target": "目标target节点",
"tool_name": "工具名称action节点",
"description": "描述vulnerability节点"
"tool_name": "工具名称action节点,必须是实际调用的工具",
"description": "描述vulnerability节点,必须是实际发现的漏洞"
}
}
],
@@ -209,6 +358,8 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
]
}
**再次强调如果没有实际数据返回空的nodes和edges数组。严禁杜撰**
只返回JSON不要包含其他解释文字。`, userInput, reactInput, modelOutput)
}

View File

@@ -205,6 +205,42 @@ func (db *DB) DeleteConversation(id string) error {
return nil
}
// SaveReActData 保存最后一轮ReAct的输入和输出
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error {
_, err := db.Exec(
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
reactInput, reactOutput, time.Now(), conversationID,
)
if err != nil {
return fmt.Errorf("保存ReAct数据失败: %w", err)
}
return nil
}
// GetReActData 获取最后一轮ReAct的输入和输出
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) {
var input, output sql.NullString
err = db.QueryRow(
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
conversationID,
).Scan(&input, &output)
if err != nil {
if err == sql.ErrNoRows {
return "", "", fmt.Errorf("对话不存在")
}
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err)
}
if input.Valid {
reactInput = input.String
}
if output.Valid {
reactOutput = output.String
}
return reactInput, reactOutput, nil
}
// AddMessage 添加消息
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
id := uuid.New().String()

View File

@@ -3,6 +3,7 @@ package database
import (
"database/sql"
"fmt"
"strings"
_ "github.com/mattn/go-sqlite3"
"go.uber.org/zap"
@@ -46,7 +47,9 @@ func (db *DB) initTables() error {
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
updated_at DATETIME NOT NULL,
last_react_input TEXT,
last_react_output TEXT
);`
// 创建消息表
@@ -199,10 +202,58 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建索引失败: %w", err)
}
// 为已有表添加新字段(如果不存在)
if err := db.migrateConversationsTable(); err != nil {
db.logger.Warn("迁移conversations表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
db.logger.Info("数据库表初始化完成")
return nil
}
// migrateConversationsTable 迁移conversations表添加新字段
func (db *DB) migrateConversationsTable() error {
// 检查last_react_input字段是否存在
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_input'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); addErr != nil {
// 如果字段已存在忽略错误SQLite错误信息可能不同
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加last_react_input字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); err != nil {
db.logger.Warn("添加last_react_input字段失败", zap.Error(err))
}
}
// 检查last_react_output字段是否存在
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_output'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加last_react_output字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); err != nil {
db.logger.Warn("添加last_react_output字段失败", zap.Error(err))
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")

View File

@@ -87,41 +87,30 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
conversationID = conv.ID
}
// 获取历史消息(排除当前消息,因为还没保存)
historyMessages, err := h.db.GetMessages(conversationID)
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
historyMessages = []database.Message{}
}
h.logger.Info("获取历史消息",
zap.String("conversationId", conversationID),
zap.Int("count", len(historyMessages)),
)
// 将数据库消息转换为Agent消息格式
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
for i, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
contentPreview := msg.Content
if len(contentPreview) > 50 {
contentPreview = contentPreview[:50] + "..."
h.logger.Warn("从ReAct数据加载历史消息失败使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
agentHistoryMessages = []agent.ChatMessage{}
} else {
// 将数据库消息转换为Agent消息格式
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
h.logger.Info("添加历史消息",
zap.Int("index", i),
zap.String("role", msg.Role),
zap.String("content", contentPreview),
)
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
h.logger.Info("历史消息转换完成",
zap.Int("originalCount", len(historyMessages)),
zap.Int("convertedCount", len(agentHistoryMessages)),
)
// 保存用户消息
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
@@ -132,6 +121,16 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
// 即使执行失败也尝试保存ReAct数据如果result中有
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr))
} else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
}
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -142,6 +141,15 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Error("保存助手消息失败", zap.Error(err))
}
// 保存最后一轮ReAct的输入和输出
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
}
}
c.JSON(http.StatusOK, ChatResponse{
Response: result.Response,
MCPExecutionIDs: result.MCPExecutionIDs,
@@ -246,20 +254,28 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
"conversationId": conversationID,
})
// 获取历史消息
historyMessages, err := h.db.GetMessages(conversationID)
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
historyMessages = []database.Message{}
}
// 将数据库消息转换为Agent消息格式
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
h.logger.Warn("从ReAct数据加载历史消息失败使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
agentHistoryMessages = []agent.ChatMessage{}
} else {
// 将数据库消息转换为Agent消息格式
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
@@ -499,6 +515,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
}
// 即使任务被取消也尝试保存ReAct数据如果result中有
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err))
} else {
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID))
}
}
sendEvent("cancelled", cancelMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
@@ -524,6 +550,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
}
// 即使任务超时也尝试保存ReAct数据如果result中有
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err))
} else {
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID))
}
}
sendEvent("error", timeoutMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
@@ -549,6 +585,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
}
// 即使任务失败也尝试保存ReAct数据如果result中有
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err))
} else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
}
}
sendEvent("error", errorMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
@@ -585,6 +631,15 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
}
// 保存最后一轮ReAct的输入和输出
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
}
}
// 发送最终响应
sendEvent("response", result.Response, map[string]interface{}{
"mcpExecutionIds": result.MCPExecutionIDs,
@@ -632,3 +687,140 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
"tasks": h.tasks.GetActiveTasks(),
})
}
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
// 获取保存的ReAct输入和输出
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
if err != nil {
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
}
if reactInputJSON == "" {
return nil, fmt.Errorf("ReAct数据为空")
}
// 解析JSON格式的messages数组
var messagesArray []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
}
// 转换为Agent消息格式
agentMessages := make([]agent.ChatMessage, 0, len(messagesArray))
for _, msgMap := range messagesArray {
msg := agent.ChatMessage{}
// 解析role
if role, ok := msgMap["role"].(string); ok {
msg.Role = role
} else {
continue // 跳过无效消息
}
// 跳过system消息AgentLoop会重新添加
if msg.Role == "system" {
continue
}
// 解析content
if content, ok := msgMap["content"].(string); ok {
msg.Content = content
}
// 解析tool_calls如果存在
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
msg.ToolCalls = make([]agent.ToolCall, 0, len(toolCallsArray))
for _, tcRaw := range toolCallsArray {
if tcMap, ok := tcRaw.(map[string]interface{}); ok {
toolCall := agent.ToolCall{}
// 解析ID
if id, ok := tcMap["id"].(string); ok {
toolCall.ID = id
}
// 解析Type
if toolType, ok := tcMap["type"].(string); ok {
toolCall.Type = toolType
}
// 解析Function
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
toolCall.Function = agent.FunctionCall{}
// 解析函数名
if name, ok := funcMap["name"].(string); ok {
toolCall.Function.Name = name
}
// 解析arguments可能是字符串或对象
if argsRaw, ok := funcMap["arguments"]; ok {
if argsStr, ok := argsRaw.(string); ok {
// 如果是字符串解析为JSON
var argsMap map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
toolCall.Function.Arguments = argsMap
}
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
// 如果已经是对象,直接使用
toolCall.Function.Arguments = argsMap
}
}
}
if toolCall.ID != "" {
msg.ToolCalls = append(msg.ToolCalls, toolCall)
}
}
}
}
}
// 解析tool_call_idtool角色消息
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
msg.ToolCallID = toolCallID
}
agentMessages = append(agentMessages, msg)
}
// 如果存在last_react_output需要将其作为最后一条assistant消息
// 因为last_react_input是在迭代开始前保存的不包含最后一轮的最终输出
if reactOutput != "" {
// 检查最后一条消息是否是assistant消息且没有tool_calls
// 如果有tool_calls说明后面应该还有tool消息和最终的assistant回复
if len(agentMessages) > 0 {
lastMsg := &agentMessages[len(agentMessages)-1]
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
// 最后一条是assistant消息且没有tool_calls用最终输出更新其content
lastMsg.Content = reactOutput
} else {
// 最后一条不是assistant消息或者有tool_calls添加最终输出作为新的assistant消息
agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant",
Content: reactOutput,
})
}
} else {
// 如果没有消息,直接添加最终输出
agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant",
Content: reactOutput,
})
}
}
if len(agentMessages) == 0 {
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
}
h.logger.Info("从ReAct数据恢复历史消息",
zap.String("conversationId", conversationID),
zap.Int("messageCount", len(agentMessages)),
zap.Bool("hasReactOutput", reactOutput != ""),
)
return agentMessages, nil
}