From 4a183078eaab51371e64db6a0f66c19b60ef51c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Wed, 24 Dec 2025 06:48:34 +0800 Subject: [PATCH] Add files via upload --- internal/attackchain/builder.go | 93 ++++++++++++++++++++++----------- internal/handler/agent.go | 25 +++++++-- 2 files changed, 83 insertions(+), 35 deletions(-) diff --git a/internal/attackchain/builder.go b/internal/attackchain/builder.go index 7d61988c..0e939117 100644 --- a/internal/attackchain/builder.go +++ b/internal/attackchain/builder.go @@ -14,7 +14,6 @@ import ( "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" - "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/openai" "github.com/google/uuid" @@ -146,7 +145,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID modelOutput = "" } - var userInput string + // var userInput string var reactInputFinal string var dataSource string // 记录数据来源 @@ -173,7 +172,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID zap.Int("modelOutputSize", len(modelOutput))) // 从保存的ReAct输入(JSON格式)中提取用户输入 - userInput = b.extractUserInputFromReActInput(reactInputJSON) + // userInput = b.extractUserInputFromReActInput(reactInputJSON) // 将JSON格式的messages转换为可读格式 reactInputFinal = b.formatReActInputFromJSON(reactInputJSON) @@ -188,7 +187,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID // 提取用户输入(最后一条user消息) for i := len(messages) - 1; i >= 0; i-- { if strings.EqualFold(messages[i].Role, "user") { - userInput = messages[i].Content + // userInput = messages[i].Content break } } @@ -206,7 +205,8 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID } // 3. 构建简化的prompt,一次性传递给大模型 - prompt := b.buildSimplePrompt(userInput, reactInputFinal, modelOutput) + prompt := b.buildSimplePrompt(reactInputFinal, modelOutput) + // fmt.Println(prompt) // 6. 调用AI生成攻击链(一次性,不做任何处理) chainJSON, err := b.callAIForChainGeneration(ctx, prompt) if err != nil { @@ -214,7 +214,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID } // 7. 解析JSON并生成节点/边ID(前端需要有效的ID) - chainData, err := b.parseChainJSON(chainJSON, nil) // executions为nil,因为我们不再使用tool_execution_id + chainData, err := b.parseChainJSON(chainJSON) if err != nil { // 如果解析失败,返回空链,让前端处理错误 b.logger.Warn("解析攻击链JSON失败", zap.Error(err), zap.String("raw_json", chainJSON)) @@ -250,25 +250,25 @@ func (b *Builder) buildReActInput(messages []database.Message) string { } // extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入 -func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string { - // reactInputJSON是JSON格式的ChatMessage数组,需要解析 - var messages []map[string]interface{} - if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { - b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) - return "" - } +// func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string { +// // reactInputJSON是JSON格式的ChatMessage数组,需要解析 +// var messages []map[string]interface{} +// if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { +// b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) +// return "" +// } - // 从后往前查找最后一条user消息 - for i := len(messages) - 1; i >= 0; i-- { - if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") { - if content, ok := messages[i]["content"].(string); ok { - return content - } - } - } +// // 从后往前查找最后一条user消息 +// for i := len(messages) - 1; i >= 0; i-- { +// if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") { +// if content, ok := messages[i]["content"].(string); ok { +// return content +// } +// } +// } - return "" -} +// return "" +// } // formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式 func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string { @@ -283,13 +283,45 @@ func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string { role, _ := msg["role"].(string) content, _ := msg["content"].(string) - // 如果content为空但存在tool_calls,标记为工具调用消息 - if content == "" { + // 处理assistant消息:提取tool_calls信息 + if role == "assistant" { if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { - content = fmt.Sprintf("[工具调用: %d个]", len(toolCalls)) + // 如果有文本内容,先显示 + if content != "" { + builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content)) + } + // 详细显示每个工具调用 + builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls))) + for i, toolCall := range toolCalls { + if tc, ok := toolCall.(map[string]interface{}); ok { + toolCallID, _ := tc["id"].(string) + if funcData, ok := tc["function"].(map[string]interface{}); ok { + toolName, _ := funcData["name"].(string) + arguments, _ := funcData["arguments"].(string) + builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1)) + builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID)) + builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName)) + builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments)) + } + } + } + builder.WriteString("\n") + continue } } + // 处理tool消息:显示tool_call_id和完整内容 + if role == "tool" { + toolCallID, _ := msg["tool_call_id"].(string) + if toolCallID != "" { + builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content)) + } else { + builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) + } + continue + } + + // 其他消息类型(system, user等)正常显示 builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) } @@ -297,7 +329,7 @@ func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string { } // buildSimplePrompt 构建简化的prompt -func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) string { +func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string { return fmt.Sprintf(`你是一个专业的安全测试分析师。请根据以下信息生成攻击链图。 ## ⚠️ 重要原则 - 严禁杜撰 @@ -310,8 +342,7 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s 如果ReAct输入中没有实际的工具执行记录,或者模型输出中明确表示任务未完成/被取消,必须返回空的攻击链(空的nodes和edges数组)。 -## 用户输入 -%s + ## 最后一轮ReAct的输入(历史对话上下文) %s @@ -360,7 +391,7 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s **再次强调:如果没有实际数据,返回空的nodes和edges数组。严禁杜撰!** -只返回JSON,不要包含其他解释文字。`, userInput, reactInput, modelOutput) +只返回JSON,不要包含其他解释文字。`, reactInput, modelOutput) } // saveChain 保存攻击链到数据库(简化版本,移除tool_execution_id) @@ -481,7 +512,7 @@ type ChainJSON struct { } // parseChainJSON 解析攻击链JSON -func (b *Builder) parseChainJSON(chainJSON string, executions []*mcp.ToolExecution) (*Chain, error) { +func (b *Builder) parseChainJSON(chainJSON string) (*Chain, error) { var chainData ChainJSON if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil { return nil, fmt.Errorf("解析JSON失败: %w", err) diff --git a/internal/handler/agent.go b/internal/handler/agent.go index a21b89e8..36a8eb67 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -689,6 +689,7 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) { } // loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 +// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表 func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { // 获取保存的ReAct输入和输出 reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID) @@ -696,16 +697,30 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent. return nil, fmt.Errorf("获取ReAct数据失败: %w", err) } + // 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致) if reactInputJSON == "" { - return nil, fmt.Errorf("ReAct数据为空") + return nil, fmt.Errorf("ReAct数据为空,将使用消息表") } + dataSource := "database_last_react_input" + // 解析JSON格式的messages数组 var messagesArray []map[string]interface{} if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil { return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err) } + messageCount := len(messagesArray) + + h.logger.Info("使用保存的ReAct数据恢复历史上下文", + zap.String("conversationId", conversationID), + zap.String("dataSource", dataSource), + zap.Int("reactInputSize", len(reactInputJSON)), + zap.Int("messageCount", messageCount), + zap.Int("reactOutputSize", len(reactOutput)), + ) + // fmt.Println("messagesArray:", messagesArray)//debug + // 转换为Agent消息格式 agentMessages := make([]agent.ChatMessage, 0, len(messagesArray)) for _, msgMap := range messagesArray { @@ -816,11 +831,13 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent. return nil, fmt.Errorf("从ReAct数据解析的消息为空") } - h.logger.Info("从ReAct数据恢复历史消息", + h.logger.Info("从ReAct数据恢复历史消息完成", zap.String("conversationId", conversationID), - zap.Int("messageCount", len(agentMessages)), + zap.String("dataSource", dataSource), + zap.Int("originalMessageCount", messageCount), + zap.Int("finalMessageCount", len(agentMessages)), zap.Bool("hasReactOutput", reactOutput != ""), ) - + fmt.Println("agentMessages:", agentMessages) //debug return agentMessages, nil }