Add files via upload

This commit is contained in:
公明
2025-12-24 06:48:34 +08:00
committed by GitHub
parent f1355037ee
commit 4a183078ea
2 changed files with 83 additions and 35 deletions

View File

@@ -14,7 +14,6 @@ import (
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/openai"
"github.com/google/uuid"
@@ -146,7 +145,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
modelOutput = ""
}
var userInput string
// var userInput string
var reactInputFinal string
var dataSource string // 记录数据来源
@@ -173,7 +172,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
zap.Int("modelOutputSize", len(modelOutput)))
// 从保存的ReAct输入JSON格式中提取用户输入
userInput = b.extractUserInputFromReActInput(reactInputJSON)
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
@@ -188,7 +187,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
// 提取用户输入最后一条user消息
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "user") {
userInput = messages[i].Content
// userInput = messages[i].Content
break
}
}
@@ -206,7 +205,8 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
}
// 3. 构建简化的prompt一次性传递给大模型
prompt := b.buildSimplePrompt(userInput, reactInputFinal, modelOutput)
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
// fmt.Println(prompt)
// 6. 调用AI生成攻击链一次性不做任何处理
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
if err != nil {
@@ -214,7 +214,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
}
// 7. 解析JSON并生成节点/边ID前端需要有效的ID
chainData, err := b.parseChainJSON(chainJSON, nil) // executions为nil因为我们不再使用tool_execution_id
chainData, err := b.parseChainJSON(chainJSON)
if err != nil {
// 如果解析失败,返回空链,让前端处理错误
b.logger.Warn("解析攻击链JSON失败", zap.Error(err), zap.String("raw_json", chainJSON))
@@ -250,25 +250,25 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
}
// extractUserInputFromReActInput 从保存的ReAct输入JSON格式的messages数组中提取最后一条用户输入
func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
// reactInputJSON是JSON格式的ChatMessage数组需要解析
var messages []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
return ""
}
// func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
// // reactInputJSON是JSON格式的ChatMessage数组需要解析
// var messages []map[string]interface{}
// if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
// b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
// return ""
// }
// 从后往前查找最后一条user消息
for i := len(messages) - 1; i >= 0; i-- {
if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
if content, ok := messages[i]["content"].(string); ok {
return content
}
}
}
// // 从后往前查找最后一条user消息
// for i := len(messages) - 1; i >= 0; i-- {
// if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
// if content, ok := messages[i]["content"].(string); ok {
// return content
// }
// }
// }
return ""
}
// return ""
// }
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
@@ -283,13 +283,45 @@ func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
// 如果content为空但存在tool_calls,标记为工具调用消
if content == "" {
// 处理assistant消息提取tool_calls
if role == "assistant" {
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
content = fmt.Sprintf("[工具调用: %d个]", len(toolCalls))
// 如果有文本内容,先显示
if content != "" {
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
}
// 详细显示每个工具调用
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
for i, toolCall := range toolCalls {
if tc, ok := toolCall.(map[string]interface{}); ok {
toolCallID, _ := tc["id"].(string)
if funcData, ok := tc["function"].(map[string]interface{}); ok {
toolName, _ := funcData["name"].(string)
arguments, _ := funcData["arguments"].(string)
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
}
}
}
builder.WriteString("\n")
continue
}
}
// 处理tool消息显示tool_call_id和完整内容
if role == "tool" {
toolCallID, _ := msg["tool_call_id"].(string)
if toolCallID != "" {
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
} else {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
}
continue
}
// 其他消息类型system, user等正常显示
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
}
@@ -297,7 +329,7 @@ func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
}
// buildSimplePrompt 构建简化的prompt
func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) string {
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
return fmt.Sprintf(`你是一个专业的安全测试分析师。请根据以下信息生成攻击链图。
## ⚠️ 重要原则 - 严禁杜撰
@@ -310,8 +342,7 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
如果ReAct输入中没有实际的工具执行记录或者模型输出中明确表示任务未完成/被取消必须返回空的攻击链空的nodes和edges数组
## 用户输入
%s
## 最后一轮ReAct的输入历史对话上下文
%s
@@ -360,7 +391,7 @@ func (b *Builder) buildSimplePrompt(userInput, reactInput, modelOutput string) s
**再次强调如果没有实际数据返回空的nodes和edges数组。严禁杜撰**
只返回JSON不要包含其他解释文字。`, userInput, reactInput, modelOutput)
只返回JSON不要包含其他解释文字。`, reactInput, modelOutput)
}
// saveChain 保存攻击链到数据库简化版本移除tool_execution_id
@@ -481,7 +512,7 @@ type ChainJSON struct {
}
// parseChainJSON 解析攻击链JSON
func (b *Builder) parseChainJSON(chainJSON string, executions []*mcp.ToolExecution) (*Chain, error) {
func (b *Builder) parseChainJSON(chainJSON string) (*Chain, error) {
var chainData ChainJSON
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
return nil, fmt.Errorf("解析JSON失败: %w", err)

View File

@@ -689,6 +689,7 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
}
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
// 采用与攻击链生成类似的拼接逻辑优先使用保存的last_react_input和last_react_output若不存在则回退到消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
// 获取保存的ReAct输入和输出
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
@@ -696,16 +697,30 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
}
// 如果last_react_input为空回退到使用消息表与攻击链生成逻辑一致
if reactInputJSON == "" {
return nil, fmt.Errorf("ReAct数据为空")
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
}
dataSource := "database_last_react_input"
// 解析JSON格式的messages数组
var messagesArray []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
}
messageCount := len(messagesArray)
h.logger.Info("使用保存的ReAct数据恢复历史上下文",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)),
zap.Int("messageCount", messageCount),
zap.Int("reactOutputSize", len(reactOutput)),
)
// fmt.Println("messagesArray:", messagesArray)//debug
// 转换为Agent消息格式
agentMessages := make([]agent.ChatMessage, 0, len(messagesArray))
for _, msgMap := range messagesArray {
@@ -816,11 +831,13 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
}
h.logger.Info("从ReAct数据恢复历史消息",
h.logger.Info("从ReAct数据恢复历史消息完成",
zap.String("conversationId", conversationID),
zap.Int("messageCount", len(agentMessages)),
zap.String("dataSource", dataSource),
zap.Int("originalMessageCount", messageCount),
zap.Int("finalMessageCount", len(agentMessages)),
zap.Bool("hasReactOutput", reactOutput != ""),
)
fmt.Println("agentMessages:", agentMessages) //debug
return agentMessages, nil
}