mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
Add files via upload
This commit is contained in:
@@ -14,7 +14,6 @@ import (
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -146,7 +145,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
modelOutput = ""
|
||||
}
|
||||
|
||||
var userInput string
|
||||
// var userInput string
|
||||
var reactInputFinal string
|
||||
var dataSource string // 记录数据来源
|
||||
|
||||
@@ -173,7 +172,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
zap.Int("modelOutputSize", len(modelOutput)))
|
||||
|
||||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
||||
userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||||
|
||||
// 将JSON格式的messages转换为可读格式
|
||||
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
|
||||
@@ -188,7 +187,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
// 提取用户输入(最后一条user消息)
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "user") {
|
||||
userInput = messages[i].Content
|
||||
// userInput = messages[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -206,7 +205,8 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
}
|
||||
|
||||
// 3. 构建简化的prompt,一次性传递给大模型
|
||||
prompt := b.buildSimplePrompt(userInput, reactInputFinal, modelOutput)
|
||||
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
|
||||
// fmt.Println(prompt)
|
||||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||||
if err != nil {
|
||||
@@ -214,7 +214,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
}
|
||||
|
||||
// 7. 解析JSON并生成节点/边ID(前端需要有效的ID)
|
||||
chainData, err := b.parseChainJSON(chainJSON, nil) // executions为nil,因为我们不再使用tool_execution_id
|
||||
chainData, err := b.parseChainJSON(chainJSON)
|
||||
if err != nil {
|
||||
// 如果解析失败,返回空链,让前端处理错误
|
||||
b.logger.Warn("解析攻击链JSON失败", zap.Error(err), zap.String("raw_json", chainJSON))
|
||||
@@ -250,25 +250,25 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
|
||||
}
|
||||
|
||||
// extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入
|
||||
func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
|
||||
// reactInputJSON是JSON格式的ChatMessage数组,需要解析
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
// func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
|
||||
// // reactInputJSON是JSON格式的ChatMessage数组,需要解析
|
||||
// var messages []map[string]interface{}
|
||||
// if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
// b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// 从后往前查找最后一条user消息
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
|
||||
if content, ok := messages[i]["content"].(string); ok {
|
||||
return content
|
||||
}
|
||||
}
|
||||
}
|
||||
// // 从后往前查找最后一条user消息
|
||||
// for i := len(messages) - 1; i >= 0; i-- {
|
||||
// if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
|
||||
// if content, ok := messages[i]["content"].(string); ok {
|
||||
// return content
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
return ""
|
||||
}
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||||
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
|
||||
@@ -283,13 +283,45 @@ func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
|
||||
role, _ := msg["role"].(string)
|
||||
content, _ := msg["content"].(string)
|
||||
|
||||
// 如果content为空但存在tool_calls,标记为工具调用消息
|
||||
if content == "" {
|
||||
// 处理assistant消息:提取tool_calls信息
|
||||
if role == "assistant" {
|
||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
||||
content = fmt.Sprintf("[工具调用: %d个]", len(toolCalls))
|
||||
// 如果有文本内容,先显示
|
||||
if content != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||
}
|
||||
// 详细显示每个工具调用
|
||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
|
||||
for i, toolCall := range toolCalls {
|
||||
if tc, ok := toolCall.(map[string]interface{}); ok {
|
||||
toolCallID, _ := tc["id"].(string)
|
||||
if funcData, ok := tc["function"].(map[string]interface{}); ok {
|
||||
toolName, _ := funcData["name"].(string)
|
||||
arguments, _ := funcData["arguments"].(string)
|
||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
|
||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
|
||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 处理tool消息:显示tool_call_id和完整内容
|
||||
if role == "tool" {
|
||||
toolCallID, _ := msg["tool_call_id"].(string)
|
||||
if toolCallID != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他消息类型(system, user等)正常显示
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
|
||||
@@ -297,7 +329,7 @@ func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
|
||||
}
|
||||
|
||||
// buildSimplePrompt 构建简化的prompt
|
||||
func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) string {
|
||||
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
return fmt.Sprintf(`你是一个专业的安全测试分析师。请根据以下信息生成攻击链图。
|
||||
|
||||
## ⚠️ 重要原则 - 严禁杜撰
|
||||
@@ -310,8 +342,7 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
|
||||
|
||||
如果ReAct输入中没有实际的工具执行记录,或者模型输出中明确表示任务未完成/被取消,必须返回空的攻击链(空的nodes和edges数组)。
|
||||
|
||||
## 用户输入
|
||||
%s
|
||||
|
||||
|
||||
## 最后一轮ReAct的输入(历史对话上下文)
|
||||
%s
|
||||
@@ -360,7 +391,7 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
|
||||
|
||||
**再次强调:如果没有实际数据,返回空的nodes和edges数组。严禁杜撰!**
|
||||
|
||||
只返回JSON,不要包含其他解释文字。`, userInput, reactInput, modelOutput)
|
||||
只返回JSON,不要包含其他解释文字。`, reactInput, modelOutput)
|
||||
}
|
||||
|
||||
// saveChain 保存攻击链到数据库(简化版本,移除tool_execution_id)
|
||||
@@ -481,7 +512,7 @@ type ChainJSON struct {
|
||||
}
|
||||
|
||||
// parseChainJSON 解析攻击链JSON
|
||||
func (b *Builder) parseChainJSON(chainJSON string, executions []*mcp.ToolExecution) (*Chain, error) {
|
||||
func (b *Builder) parseChainJSON(chainJSON string) (*Chain, error) {
|
||||
var chainData ChainJSON
|
||||
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
|
||||
return nil, fmt.Errorf("解析JSON失败: %w", err)
|
||||
|
||||
@@ -689,6 +689,7 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
||||
}
|
||||
|
||||
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
|
||||
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表
|
||||
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
|
||||
// 获取保存的ReAct输入和输出
|
||||
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
|
||||
@@ -696,16 +697,30 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
|
||||
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致)
|
||||
if reactInputJSON == "" {
|
||||
return nil, fmt.Errorf("ReAct数据为空")
|
||||
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
|
||||
}
|
||||
|
||||
dataSource := "database_last_react_input"
|
||||
|
||||
// 解析JSON格式的messages数组
|
||||
var messagesArray []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
|
||||
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
|
||||
}
|
||||
|
||||
messageCount := len(messagesArray)
|
||||
|
||||
h.logger.Info("使用保存的ReAct数据恢复历史上下文",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
||||
zap.Int("messageCount", messageCount),
|
||||
zap.Int("reactOutputSize", len(reactOutput)),
|
||||
)
|
||||
// fmt.Println("messagesArray:", messagesArray)//debug
|
||||
|
||||
// 转换为Agent消息格式
|
||||
agentMessages := make([]agent.ChatMessage, 0, len(messagesArray))
|
||||
for _, msgMap := range messagesArray {
|
||||
@@ -816,11 +831,13 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
|
||||
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
|
||||
}
|
||||
|
||||
h.logger.Info("从ReAct数据恢复历史消息",
|
||||
h.logger.Info("从ReAct数据恢复历史消息完成",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("messageCount", len(agentMessages)),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("originalMessageCount", messageCount),
|
||||
zap.Int("finalMessageCount", len(agentMessages)),
|
||||
zap.Bool("hasReactOutput", reactOutput != ""),
|
||||
)
|
||||
|
||||
fmt.Println("agentMessages:", agentMessages) //debug
|
||||
return agentMessages, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user