mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
Add files via upload
This commit is contained in:
@@ -292,6 +292,8 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
|
||||
type AgentLoopResult struct {
|
||||
Response string
|
||||
MCPExecutionIDs []string
|
||||
LastReActInput string // 最后一轮ReAct的输入(压缩前的完整messages)
|
||||
LastReActOutput string // 最终大模型的输出
|
||||
}
|
||||
|
||||
// ProgressCallback 进度回调函数类型
|
||||
@@ -436,14 +438,48 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
MCPExecutionIDs: make([]string, 0),
|
||||
}
|
||||
|
||||
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
|
||||
var currentReActInput string
|
||||
|
||||
maxIterations := a.maxIterations
|
||||
for i := 0; i < maxIterations; i++ {
|
||||
// 在压缩前保存messages,这样即使出现异常也能保存原始数据
|
||||
messagesBeforeCompression := make([]ChatMessage, len(messages))
|
||||
copy(messagesBeforeCompression, messages)
|
||||
|
||||
// 每轮调用前先尝试压缩,防止历史消息持续膨胀
|
||||
messages = a.applyMemoryCompression(ctx, messages)
|
||||
|
||||
// 检查是否是最后一次迭代
|
||||
isLastIteration := (i == maxIterations-1)
|
||||
|
||||
// 每次迭代都保存压缩前的messages,以便在异常中断(取消、错误等)时也能保存最新的ReAct输入
|
||||
// 这样无论何时中断,都能保存当前的上下文状态
|
||||
messagesJSON, err := json.Marshal(messagesBeforeCompression)
|
||||
if err != nil {
|
||||
a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
|
||||
} else {
|
||||
currentReActInput = string(messagesJSON)
|
||||
// 更新result中的值,确保始终保存最新的ReAct输入
|
||||
result.LastReActInput = currentReActInput
|
||||
}
|
||||
|
||||
// 检查上下文是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 上下文被取消(可能是用户主动暂停或其他原因)
|
||||
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
|
||||
result.LastReActInput = currentReActInput
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Response = "任务已被取消。"
|
||||
} else {
|
||||
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
|
||||
}
|
||||
result.LastReActOutput = result.Response
|
||||
return result, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// 获取可用工具
|
||||
tools := a.getAvailableTools()
|
||||
|
||||
@@ -511,7 +547,12 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
sendProgress("progress", "正在调用AI模型...", nil)
|
||||
response, err := a.callOpenAI(ctx, messages, tools)
|
||||
if err != nil {
|
||||
result.Response = ""
|
||||
// API调用失败,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
|
||||
result.Response = errorMsg
|
||||
result.LastReActOutput = errorMsg
|
||||
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
|
||||
return result, fmt.Errorf("调用OpenAI失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -535,12 +576,20 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
)
|
||||
continue
|
||||
}
|
||||
result.Response = ""
|
||||
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
|
||||
result.Response = errorMsg
|
||||
result.LastReActOutput = errorMsg
|
||||
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
|
||||
}
|
||||
|
||||
if len(response.Choices) == 0 {
|
||||
result.Response = ""
|
||||
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
errorMsg := "没有收到响应"
|
||||
result.Response = errorMsg
|
||||
result.LastReActOutput = errorMsg
|
||||
return result, fmt.Errorf("没有收到响应")
|
||||
}
|
||||
|
||||
@@ -658,18 +707,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
|
||||
})
|
||||
messages = a.applyMemoryCompression(ctx, messages)
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 如果获取总结失败,跳出循环,让后续逻辑处理
|
||||
break
|
||||
}
|
||||
// 如果获取总结失败,跳出循环,让后续逻辑处理
|
||||
break
|
||||
}
|
||||
|
||||
continue
|
||||
@@ -703,6 +753,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
@@ -710,6 +761,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
// 如果获取总结失败,使用当前回复作为结果
|
||||
if choice.Message.Content != "" {
|
||||
result.Response = choice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
return result, nil
|
||||
}
|
||||
// 如果都没有内容,跳出循环,让后续逻辑处理
|
||||
@@ -720,6 +772,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
if choice.FinishReason == "stop" {
|
||||
sendProgress("progress", "正在生成最终回复...", nil)
|
||||
result.Response = choice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
@@ -739,6 +792,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
@@ -746,6 +800,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 如果无法生成总结,返回友好的提示
|
||||
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
|
||||
result.LastReActOutput = result.Response
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ package attackchain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -85,7 +87,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
||||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||||
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
||||
|
||||
// 1. 获取对话消息
|
||||
// 0. 首先检查是否有实际的工具执行记录
|
||||
messages, err := b.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取对话消息失败: %w", err)
|
||||
@@ -96,31 +98,115 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||||
}
|
||||
|
||||
// 2. 提取用户输入(最后一条user消息)
|
||||
var userInput string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "user") {
|
||||
userInput = messages[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 提取最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
// 最后一轮ReAct的输入 = 所有历史消息(包括当前用户输入)
|
||||
reactInput := b.buildReActInput(messages)
|
||||
|
||||
// 4. 提取大模型最后的输出(最后一条assistant消息)
|
||||
var modelOutput string
|
||||
// 检查是否有实际的工具执行(通过检查assistant消息的mcp_execution_ids)
|
||||
hasToolExecutions := false
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||||
modelOutput = messages[i].Content
|
||||
if len(messages[i].MCPExecutionIDs) > 0 {
|
||||
hasToolExecutions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查任务是否被取消(通过检查最后一条assistant消息内容或process_details)
|
||||
taskCancelled := false
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||||
content := strings.ToLower(messages[i].Content)
|
||||
if strings.Contains(content, "取消") || strings.Contains(content, "cancelled") {
|
||||
taskCancelled = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 构建简化的prompt,一次性传递给大模型
|
||||
prompt := b.buildSimplePrompt(userInput, reactInput, modelOutput)
|
||||
// 如果任务被取消且没有实际工具执行,返回空攻击链
|
||||
if taskCancelled && !hasToolExecutions {
|
||||
b.logger.Info("任务已取消且没有实际工具执行,返回空攻击链",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Bool("taskCancelled", taskCancelled),
|
||||
zap.Bool("hasToolExecutions", hasToolExecutions))
|
||||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||||
}
|
||||
|
||||
// 如果没有实际工具执行,也返回空攻击链(避免AI编造)
|
||||
if !hasToolExecutions {
|
||||
b.logger.Info("没有实际工具执行记录,返回空攻击链",
|
||||
zap.String("conversationId", conversationID))
|
||||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||||
}
|
||||
|
||||
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
|
||||
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID)
|
||||
if err != nil {
|
||||
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
|
||||
// 继续使用原来的逻辑
|
||||
reactInputJSON = ""
|
||||
modelOutput = ""
|
||||
}
|
||||
|
||||
var userInput string
|
||||
var reactInputFinal string
|
||||
var dataSource string // 记录数据来源
|
||||
|
||||
// 如果成功获取到保存的ReAct数据,直接使用
|
||||
if reactInputJSON != "" && modelOutput != "" {
|
||||
// 计算 ReAct 输入的哈希值,用于追踪
|
||||
hash := sha256.Sum256([]byte(reactInputJSON))
|
||||
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
|
||||
|
||||
// 统计消息数量
|
||||
var messageCount int
|
||||
var tempMessages []interface{}
|
||||
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
|
||||
messageCount = len(tempMessages)
|
||||
}
|
||||
|
||||
dataSource = "database_last_react_input"
|
||||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
||||
zap.Int("messageCount", messageCount),
|
||||
zap.String("reactInputHash", reactInputHash),
|
||||
zap.Int("modelOutputSize", len(modelOutput)))
|
||||
|
||||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
||||
userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||||
|
||||
// 将JSON格式的messages转换为可读格式
|
||||
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
|
||||
} else {
|
||||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||||
dataSource = "messages_table"
|
||||
b.logger.Info("从消息历史构建ReAct数据",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("messageCount", len(messages)))
|
||||
|
||||
// 提取用户输入(最后一条user消息)
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "user") {
|
||||
userInput = messages[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 提取最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
reactInputFinal = b.buildReActInput(messages)
|
||||
|
||||
// 提取大模型最后的输出(最后一条assistant消息)
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||||
modelOutput = messages[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建简化的prompt,一次性传递给大模型
|
||||
prompt := b.buildSimplePrompt(userInput, reactInputFinal, modelOutput)
|
||||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||||
if err != nil {
|
||||
@@ -140,6 +226,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
|
||||
b.logger.Info("攻击链构建完成",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("nodes", len(chainData.Nodes)),
|
||||
zap.Int("edges", len(chainData.Edges)))
|
||||
|
||||
@@ -162,10 +249,67 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入
|
||||
func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
|
||||
// reactInputJSON是JSON格式的ChatMessage数组,需要解析
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
|
||||
// 从后往前查找最后一条user消息
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
|
||||
if content, ok := messages[i]["content"].(string); ok {
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||||
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
return reactInputJSON // 如果解析失败,返回原始JSON
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
role, _ := msg["role"].(string)
|
||||
content, _ := msg["content"].(string)
|
||||
|
||||
// 如果content为空但存在tool_calls,标记为工具调用消息
|
||||
if content == "" {
|
||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
||||
content = fmt.Sprintf("[工具调用: %d个]", len(toolCalls))
|
||||
}
|
||||
}
|
||||
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// buildSimplePrompt 构建简化的prompt
|
||||
func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) string {
|
||||
return fmt.Sprintf(`你是一个专业的安全测试分析师。请根据以下信息生成攻击链图。
|
||||
|
||||
## ⚠️ 重要原则 - 严禁杜撰
|
||||
|
||||
**严格禁止编造或推测任何内容!** 你必须:
|
||||
1. **只使用实际发生的信息**:仅基于ReAct输入中实际执行的工具调用和实际返回的结果
|
||||
2. **不要推测**:如果没有实际执行工具或发现漏洞,不要编造
|
||||
3. **不要假设**:不能仅根据URL、目标名称等推断漏洞类型
|
||||
4. **基于事实**:每个节点和边都必须有实际依据,来自工具执行结果或模型的实际输出
|
||||
|
||||
如果ReAct输入中没有实际的工具执行记录,或者模型输出中明确表示任务未完成/被取消,必须返回空的攻击链(空的nodes和edges数组)。
|
||||
|
||||
## 用户输入
|
||||
%s
|
||||
|
||||
@@ -177,10 +321,15 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
|
||||
|
||||
## 任务要求
|
||||
|
||||
请根据上述信息,生成一个清晰的攻击链图。攻击链应该包含:
|
||||
1. **target(目标)**:从用户输入中提取的测试目标
|
||||
2. **action(行动)**:从ReAct输入和模型输出中提取的关键测试步骤
|
||||
3. **vulnerability(漏洞)**:从模型输出中提取的发现的漏洞
|
||||
请根据上述信息,**仅基于实际执行的数据**生成一个清晰的攻击链图。攻击链应该包含:
|
||||
1. **target(目标)**:从用户输入中提取的实际测试目标(必须是用户明确提供的)
|
||||
2. **action(行动)**:从ReAct输入中提取的**实际执行的**工具调用和测试步骤(必须有tool_calls证据)
|
||||
3. **vulnerability(漏洞)**:从模型输出中提取的**实际发现的**漏洞(必须在输出中明确提及,不能推测)
|
||||
|
||||
**关键检查点:**
|
||||
- 如果ReAct输入中没有tool_calls,说明没有实际执行工具 → 只能生成target节点
|
||||
- 如果模型输出中没有明确提到发现的漏洞,不要编造vulnerability节点
|
||||
- 如果任务被取消或未完成,返回空攻击链
|
||||
|
||||
## 输出格式
|
||||
|
||||
@@ -194,8 +343,8 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
|
||||
"risk_score": 0-100,
|
||||
"metadata": {
|
||||
"target": "目标(target节点)",
|
||||
"tool_name": "工具名称(action节点)",
|
||||
"description": "描述(vulnerability节点)"
|
||||
"tool_name": "工具名称(action节点,必须是实际调用的工具)",
|
||||
"description": "描述(vulnerability节点,必须是实际发现的漏洞)"
|
||||
}
|
||||
}
|
||||
],
|
||||
@@ -209,6 +358,8 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
|
||||
]
|
||||
}
|
||||
|
||||
**再次强调:如果没有实际数据,返回空的nodes和edges数组。严禁杜撰!**
|
||||
|
||||
只返回JSON,不要包含其他解释文字。`, userInput, reactInput, modelOutput)
|
||||
}
|
||||
|
||||
|
||||
@@ -205,6 +205,42 @@ func (db *DB) DeleteConversation(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveReActData 保存最后一轮ReAct的输入和输出
|
||||
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
|
||||
reactInput, reactOutput, time.Now(), conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存ReAct数据失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReActData 获取最后一轮ReAct的输入和输出
|
||||
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) {
|
||||
var input, output sql.NullString
|
||||
err = db.QueryRow(
|
||||
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
|
||||
conversationID,
|
||||
).Scan(&input, &output)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", "", fmt.Errorf("对话不存在")
|
||||
}
|
||||
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err)
|
||||
}
|
||||
|
||||
if input.Valid {
|
||||
reactInput = input.String
|
||||
}
|
||||
if output.Valid {
|
||||
reactOutput = output.String
|
||||
}
|
||||
|
||||
return reactInput, reactOutput, nil
|
||||
}
|
||||
|
||||
// AddMessage 添加消息
|
||||
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -3,6 +3,7 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.uber.org/zap"
|
||||
@@ -46,7 +47,9 @@ func (db *DB) initTables() error {
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
updated_at DATETIME NOT NULL,
|
||||
last_react_input TEXT,
|
||||
last_react_output TEXT
|
||||
);`
|
||||
|
||||
// 创建消息表
|
||||
@@ -199,10 +202,58 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
|
||||
// 为已有表添加新字段(如果不存在)
|
||||
if err := db.migrateConversationsTable(); err != nil {
|
||||
db.logger.Warn("迁移conversations表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
db.logger.Info("数据库表初始化完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateConversationsTable 迁移conversations表,添加新字段
|
||||
func (db *DB) migrateConversationsTable() error {
|
||||
// 检查last_react_input字段是否存在
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_input'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误(SQLite错误信息可能不同)
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加last_react_input字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); err != nil {
|
||||
db.logger.Warn("添加last_react_input字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查last_react_output字段是否存在
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_output'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加last_react_output字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); err != nil {
|
||||
db.logger.Warn("添加last_react_output字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
|
||||
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
|
||||
@@ -87,41 +87,30 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
conversationID = conv.ID
|
||||
}
|
||||
|
||||
// 获取历史消息(排除当前消息,因为还没保存)
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
historyMessages = []database.Message{}
|
||||
}
|
||||
|
||||
h.logger.Info("获取历史消息",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("count", len(historyMessages)),
|
||||
)
|
||||
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for i, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
contentPreview := msg.Content
|
||||
if len(contentPreview) > 50 {
|
||||
contentPreview = contentPreview[:50] + "..."
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
h.logger.Info("添加历史消息",
|
||||
zap.Int("index", i),
|
||||
zap.String("role", msg.Role),
|
||||
zap.String("content", contentPreview),
|
||||
)
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
h.logger.Info("历史消息转换完成",
|
||||
zap.Int("originalCount", len(historyMessages)),
|
||||
zap.Int("convertedCount", len(agentHistoryMessages)),
|
||||
)
|
||||
|
||||
// 保存用户消息
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
@@ -132,6 +121,16 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
// 即使执行失败,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil {
|
||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr))
|
||||
} else {
|
||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -142,6 +141,15 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Error("保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 保存最后一轮ReAct的输入和输出
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ChatResponse{
|
||||
Response: result.Response,
|
||||
MCPExecutionIDs: result.MCPExecutionIDs,
|
||||
@@ -246,20 +254,28 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
|
||||
// 获取历史消息
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
historyMessages = []database.Message{}
|
||||
}
|
||||
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
@@ -499,6 +515,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务被取消,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
@@ -524,6 +550,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务超时,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
@@ -549,6 +585,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务失败,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
@@ -585,6 +631,15 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 保存最后一轮ReAct的输入和输出
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
// 发送最终响应
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
@@ -632,3 +687,140 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
||||
"tasks": h.tasks.GetActiveTasks(),
|
||||
})
|
||||
}
|
||||
|
||||
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
|
||||
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
|
||||
// 获取保存的ReAct输入和输出
|
||||
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
|
||||
}
|
||||
|
||||
if reactInputJSON == "" {
|
||||
return nil, fmt.Errorf("ReAct数据为空")
|
||||
}
|
||||
|
||||
// 解析JSON格式的messages数组
|
||||
var messagesArray []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
|
||||
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为Agent消息格式
|
||||
agentMessages := make([]agent.ChatMessage, 0, len(messagesArray))
|
||||
for _, msgMap := range messagesArray {
|
||||
msg := agent.ChatMessage{}
|
||||
|
||||
// 解析role
|
||||
if role, ok := msgMap["role"].(string); ok {
|
||||
msg.Role = role
|
||||
} else {
|
||||
continue // 跳过无效消息
|
||||
}
|
||||
|
||||
// 跳过system消息(AgentLoop会重新添加)
|
||||
if msg.Role == "system" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析content
|
||||
if content, ok := msgMap["content"].(string); ok {
|
||||
msg.Content = content
|
||||
}
|
||||
|
||||
// 解析tool_calls(如果存在)
|
||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
|
||||
msg.ToolCalls = make([]agent.ToolCall, 0, len(toolCallsArray))
|
||||
for _, tcRaw := range toolCallsArray {
|
||||
if tcMap, ok := tcRaw.(map[string]interface{}); ok {
|
||||
toolCall := agent.ToolCall{}
|
||||
|
||||
// 解析ID
|
||||
if id, ok := tcMap["id"].(string); ok {
|
||||
toolCall.ID = id
|
||||
}
|
||||
|
||||
// 解析Type
|
||||
if toolType, ok := tcMap["type"].(string); ok {
|
||||
toolCall.Type = toolType
|
||||
}
|
||||
|
||||
// 解析Function
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
toolCall.Function = agent.FunctionCall{}
|
||||
|
||||
// 解析函数名
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
toolCall.Function.Name = name
|
||||
}
|
||||
|
||||
// 解析arguments(可能是字符串或对象)
|
||||
if argsRaw, ok := funcMap["arguments"]; ok {
|
||||
if argsStr, ok := argsRaw.(string); ok {
|
||||
// 如果是字符串,解析为JSON
|
||||
var argsMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
|
||||
// 如果已经是对象,直接使用
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolCall.ID != "" {
|
||||
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析tool_call_id(tool角色消息)
|
||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||
msg.ToolCallID = toolCallID
|
||||
}
|
||||
|
||||
agentMessages = append(agentMessages, msg)
|
||||
}
|
||||
|
||||
// 如果存在last_react_output,需要将其作为最后一条assistant消息
|
||||
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出
|
||||
if reactOutput != "" {
|
||||
// 检查最后一条消息是否是assistant消息且没有tool_calls
|
||||
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
|
||||
if len(agentMessages) > 0 {
|
||||
lastMsg := &agentMessages[len(agentMessages)-1]
|
||||
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
|
||||
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content
|
||||
lastMsg.Content = reactOutput
|
||||
} else {
|
||||
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
|
||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: reactOutput,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// 如果没有消息,直接添加最终输出
|
||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: reactOutput,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(agentMessages) == 0 {
|
||||
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
|
||||
}
|
||||
|
||||
h.logger.Info("从ReAct数据恢复历史消息",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("messageCount", len(agentMessages)),
|
||||
zap.Bool("hasReactOutput", reactOutput != ""),
|
||||
)
|
||||
|
||||
return agentMessages, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user