mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
799 lines
30 KiB
Go
799 lines
30 KiB
Go
package attackchain
|
||
|
||
import (
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/agent"
|
||
"cyberstrike-ai/internal/config"
|
||
"cyberstrike-ai/internal/database"
|
||
"cyberstrike-ai/internal/openai"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Builder 攻击链构建器
|
||
type Builder struct {
|
||
db *database.DB
|
||
logger *zap.Logger
|
||
openAIClient *openai.Client
|
||
openAIConfig *config.OpenAIConfig
|
||
tokenCounter agent.TokenCounter
|
||
maxTokens int // 最大tokens限制,默认100000
|
||
}
|
||
|
||
// Node 攻击链节点(使用database包的类型)
|
||
type Node = database.AttackChainNode
|
||
|
||
// Edge 攻击链边(使用database包的类型)
|
||
type Edge = database.AttackChainEdge
|
||
|
||
// Chain 完整的攻击链
|
||
type Chain struct {
|
||
Nodes []Node `json:"nodes"`
|
||
Edges []Edge `json:"edges"`
|
||
}
|
||
|
||
// NewBuilder 创建新的攻击链构建器
|
||
func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder {
|
||
transport := &http.Transport{
|
||
MaxIdleConns: 100,
|
||
MaxIdleConnsPerHost: 10,
|
||
IdleConnTimeout: 90 * time.Second,
|
||
}
|
||
httpClient := &http.Client{Timeout: 5 * time.Minute, Transport: transport}
|
||
|
||
// 优先使用配置文件中的统一 Token 上限(config.yaml -> openai.max_total_tokens)
|
||
maxTokens := 0
|
||
if openAIConfig != nil && openAIConfig.MaxTotalTokens > 0 {
|
||
maxTokens = openAIConfig.MaxTotalTokens
|
||
} else if openAIConfig != nil {
|
||
// 如果未显式配置 max_total_tokens,则根据模型设置一个合理的默认值
|
||
model := strings.ToLower(openAIConfig.Model)
|
||
if strings.Contains(model, "gpt-4") {
|
||
maxTokens = 128000 // gpt-4通常支持128k
|
||
} else if strings.Contains(model, "gpt-3.5") {
|
||
maxTokens = 16000 // gpt-3.5-turbo通常支持16k
|
||
} else if strings.Contains(model, "deepseek") {
|
||
maxTokens = 131072 // deepseek-chat通常支持131k
|
||
} else {
|
||
maxTokens = 100000 // 兜底默认值
|
||
}
|
||
} else {
|
||
// 没有 OpenAI 配置时使用兜底值,避免为 0
|
||
maxTokens = 100000
|
||
}
|
||
|
||
return &Builder{
|
||
db: db,
|
||
logger: logger,
|
||
openAIClient: openai.NewClient(openAIConfig, httpClient, logger),
|
||
openAIConfig: openAIConfig,
|
||
tokenCounter: agent.NewTikTokenCounter(),
|
||
maxTokens: maxTokens,
|
||
}
|
||
}
|
||
|
||
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
|
||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
||
|
||
// 0. 首先检查是否有实际的工具执行记录
|
||
messages, err := b.db.GetMessages(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取对话消息失败: %w", err)
|
||
}
|
||
|
||
if len(messages) == 0 {
|
||
b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID))
|
||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||
}
|
||
|
||
// 检查是否有实际的工具执行(通过检查assistant消息的mcp_execution_ids)
|
||
hasToolExecutions := false
|
||
for i := len(messages) - 1; i >= 0; i-- {
|
||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||
if len(messages[i].MCPExecutionIDs) > 0 {
|
||
hasToolExecutions = true
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查任务是否被取消(通过检查最后一条assistant消息内容或process_details)
|
||
taskCancelled := false
|
||
for i := len(messages) - 1; i >= 0; i-- {
|
||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||
content := strings.ToLower(messages[i].Content)
|
||
if strings.Contains(content, "取消") || strings.Contains(content, "cancelled") {
|
||
taskCancelled = true
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
// 如果任务被取消且没有实际工具执行,返回空攻击链
|
||
if taskCancelled && !hasToolExecutions {
|
||
b.logger.Info("任务已取消且没有实际工具执行,返回空攻击链",
|
||
zap.String("conversationId", conversationID),
|
||
zap.Bool("taskCancelled", taskCancelled),
|
||
zap.Bool("hasToolExecutions", hasToolExecutions))
|
||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||
}
|
||
|
||
// 如果没有实际工具执行,也返回空攻击链(避免AI编造)
|
||
if !hasToolExecutions {
|
||
b.logger.Info("没有实际工具执行记录,返回空攻击链",
|
||
zap.String("conversationId", conversationID))
|
||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||
}
|
||
|
||
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
|
||
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID)
|
||
if err != nil {
|
||
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
|
||
// 继续使用原来的逻辑
|
||
reactInputJSON = ""
|
||
modelOutput = ""
|
||
}
|
||
|
||
// var userInput string
|
||
var reactInputFinal string
|
||
var dataSource string // 记录数据来源
|
||
|
||
// 如果成功获取到保存的ReAct数据,直接使用
|
||
if reactInputJSON != "" && modelOutput != "" {
|
||
// 计算 ReAct 输入的哈希值,用于追踪
|
||
hash := sha256.Sum256([]byte(reactInputJSON))
|
||
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
|
||
|
||
// 统计消息数量
|
||
var messageCount int
|
||
var tempMessages []interface{}
|
||
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
|
||
messageCount = len(tempMessages)
|
||
}
|
||
|
||
dataSource = "database_last_react_input"
|
||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
||
zap.String("conversationId", conversationID),
|
||
zap.String("dataSource", dataSource),
|
||
zap.Int("reactInputSize", len(reactInputJSON)),
|
||
zap.Int("messageCount", messageCount),
|
||
zap.String("reactInputHash", reactInputHash),
|
||
zap.Int("modelOutputSize", len(modelOutput)))
|
||
|
||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||
|
||
// 将JSON格式的messages转换为可读格式
|
||
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
|
||
} else {
|
||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||
dataSource = "messages_table"
|
||
b.logger.Info("从消息历史构建ReAct数据",
|
||
zap.String("conversationId", conversationID),
|
||
zap.String("dataSource", dataSource),
|
||
zap.Int("messageCount", len(messages)))
|
||
|
||
// 提取用户输入(最后一条user消息)
|
||
for i := len(messages) - 1; i >= 0; i-- {
|
||
if strings.EqualFold(messages[i].Role, "user") {
|
||
// userInput = messages[i].Content
|
||
break
|
||
}
|
||
}
|
||
|
||
// 提取最后一轮ReAct的输入(历史消息+当前用户输入)
|
||
reactInputFinal = b.buildReActInput(messages)
|
||
|
||
// 提取大模型最后的输出(最后一条assistant消息)
|
||
for i := len(messages) - 1; i >= 0; i-- {
|
||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||
modelOutput = messages[i].Content
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 3. 构建简化的prompt,一次性传递给大模型
|
||
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
|
||
// fmt.Println(prompt)
|
||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("AI生成失败: %w", err)
|
||
}
|
||
|
||
// 7. 解析JSON并生成节点/边ID(前端需要有效的ID)
|
||
chainData, err := b.parseChainJSON(chainJSON)
|
||
if err != nil {
|
||
// 如果解析失败,返回空链,让前端处理错误
|
||
b.logger.Warn("解析攻击链JSON失败", zap.Error(err), zap.String("raw_json", chainJSON))
|
||
return &Chain{
|
||
Nodes: []Node{},
|
||
Edges: []Edge{},
|
||
}, nil
|
||
}
|
||
|
||
b.logger.Info("攻击链构建完成",
|
||
zap.String("conversationId", conversationID),
|
||
zap.String("dataSource", dataSource),
|
||
zap.Int("nodes", len(chainData.Nodes)),
|
||
zap.Int("edges", len(chainData.Edges)))
|
||
|
||
// 保存到数据库(供后续加载使用)
|
||
if err := b.saveChain(conversationID, chainData.Nodes, chainData.Edges); err != nil {
|
||
b.logger.Warn("保存攻击链到数据库失败", zap.Error(err))
|
||
// 即使保存失败,也返回数据给前端
|
||
}
|
||
|
||
// 直接返回,不做任何处理和校验
|
||
return chainData, nil
|
||
}
|
||
|
||
// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
||
func (b *Builder) buildReActInput(messages []database.Message) string {
|
||
var builder strings.Builder
|
||
for _, msg := range messages {
|
||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
||
}
|
||
return builder.String()
|
||
}
|
||
|
||
// extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入
|
||
// func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
|
||
// // reactInputJSON是JSON格式的ChatMessage数组,需要解析
|
||
// var messages []map[string]interface{}
|
||
// if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||
// b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||
// return ""
|
||
// }
|
||
|
||
// // 从后往前查找最后一条user消息
|
||
// for i := len(messages) - 1; i >= 0; i-- {
|
||
// if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
|
||
// if content, ok := messages[i]["content"].(string); ok {
|
||
// return content
|
||
// }
|
||
// }
|
||
// }
|
||
|
||
// return ""
|
||
// }
|
||
|
||
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
|
||
var messages []map[string]interface{}
|
||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||
return reactInputJSON // 如果解析失败,返回原始JSON
|
||
}
|
||
|
||
var builder strings.Builder
|
||
for _, msg := range messages {
|
||
role, _ := msg["role"].(string)
|
||
content, _ := msg["content"].(string)
|
||
|
||
// 处理assistant消息:提取tool_calls信息
|
||
if role == "assistant" {
|
||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
||
// 如果有文本内容,先显示
|
||
if content != "" {
|
||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||
}
|
||
// 详细显示每个工具调用
|
||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
|
||
for i, toolCall := range toolCalls {
|
||
if tc, ok := toolCall.(map[string]interface{}); ok {
|
||
toolCallID, _ := tc["id"].(string)
|
||
if funcData, ok := tc["function"].(map[string]interface{}); ok {
|
||
toolName, _ := funcData["name"].(string)
|
||
arguments, _ := funcData["arguments"].(string)
|
||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
|
||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
|
||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
|
||
}
|
||
}
|
||
}
|
||
builder.WriteString("\n")
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 处理tool消息:显示tool_call_id和完整内容
|
||
if role == "tool" {
|
||
toolCallID, _ := msg["tool_call_id"].(string)
|
||
if toolCallID != "" {
|
||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
|
||
} else {
|
||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 其他消息类型(system, user等)正常显示
|
||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||
}
|
||
|
||
return builder.String()
|
||
}
|
||
|
||
// buildSimplePrompt 构建简化的prompt
|
||
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。
|
||
|
||
## 核心目标
|
||
|
||
构建一个能够讲述完整攻击故事的攻击链让学习者能够:
|
||
1. 理解渗透测试的完整流程和思维逻辑(从目标识别到漏洞发现的每一步)
|
||
2. 学习如何从失败中获取线索并调整策略
|
||
3. 掌握工具使用的实际效果和局限性
|
||
4. 理解漏洞发现和利用的因果关系
|
||
|
||
**关键原则**:完整性优先。必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而遗漏重要信息。
|
||
|
||
## 构建流程(按此顺序思考)
|
||
|
||
### 第一步:理解上下文
|
||
仔细分析ReAct输入中的工具调用序列和大模型输出,识别:
|
||
- 测试目标(IP、域名、URL等)
|
||
- 实际执行的工具和参数
|
||
- 工具返回的关键信息(成功结果、错误信息、超时等)
|
||
- AI的分析和决策过程
|
||
|
||
### 第二步:提取关键节点
|
||
从工具执行记录中提取有意义的节点,**确保不遗漏任何关键步骤**:
|
||
- **target节点**:每个独立的测试目标创建一个target节点
|
||
- **action节点**:每个有意义的工具执行创建一个action节点(包括提供线索的失败、成功的信息收集、漏洞验证等)
|
||
- **vulnerability节点**:每个真实确认的漏洞创建一个vulnerability节点
|
||
- **完整性检查**:对照ReAct输入中的工具调用序列,确保每个有意义的工具执行都被包含在攻击链中
|
||
|
||
### 第三步:构建逻辑关系(树状结构)
|
||
**重要:必须构建树状结构,而不是简单的线性链。**
|
||
按照因果关系连接节点,形成树状图(因为是单agent执行,所以可以不按照时间顺序):
|
||
- **分支结构**:一个节点可以有多个后续节点(例如:端口扫描发现多个端口后,可以同时进行多个不同的测试)
|
||
- **汇聚结构**:多个节点可以指向同一个节点(例如:多个不同的测试都发现了同一个漏洞)
|
||
- 识别哪些action是基于前面action的结果而执行的
|
||
- 识别哪些vulnerability是由哪些action发现的
|
||
- 识别失败节点如何为后续成功提供线索
|
||
- **避免线性链**:不要将所有节点连成一条线,应该根据实际的并行测试和分支探索构建树状结构
|
||
|
||
### 第四步:优化和精简
|
||
- **完整性检查**:确保所有有意义的工具执行都被包含,不要遗漏关键步骤
|
||
- **合并规则**:只合并真正相似或重复的action节点(如多次相同工具的相似调用)
|
||
- **删除规则**:只删除完全无价值的失败节点(完全无输出、纯系统错误、重复的相同失败)
|
||
- **重要提醒**:宁可保留更多节点,也不要遗漏关键步骤。攻击链必须完整展现渗透测试过程
|
||
- 确保攻击链逻辑连贯,能够讲述完整故事
|
||
|
||
## 节点类型详解
|
||
|
||
### target(目标节点)
|
||
- **用途**:标识测试目标
|
||
- **创建规则**:每个独立目标(不同IP/域名)创建一个target节点
|
||
- **多目标处理**:不同目标的节点不相互连接,各自形成独立的子图
|
||
- **metadata.target**:精确记录目标标识(IP地址、域名、URL等)
|
||
|
||
### action(行动节点)
|
||
- **用途**:记录工具执行和AI分析结果
|
||
- **标签规则**:
|
||
* 15-25个汉字,动宾结构
|
||
* 成功节点:描述执行结果(如"扫描端口发现80/443/8080"、"目录扫描发现/admin路径")
|
||
* 失败节点:描述失败原因(如"尝试SQL注入(被WAF拦截)"、"端口扫描超时(目标不可达)")
|
||
- **ai_analysis要求**:
|
||
* 成功节点:总结工具执行的关键发现,说明这些发现的意义
|
||
* 失败节点:必须说明失败原因、获得的线索、这些线索如何指引后续行动
|
||
* 不超过150字,要具体、有信息量
|
||
- **findings要求**:
|
||
* 提取工具返回结果中的关键信息点
|
||
* 每个finding应该是独立的、有价值的信息片段
|
||
* 成功节点:列出关键发现(如["80端口开放", "443端口开放", "HTTP服务为Apache 2.4"])
|
||
* 失败节点:列出失败线索(如["WAF拦截", "返回403", "检测到Cloudflare"])
|
||
- **status标记**:
|
||
* 成功节点:不设置或设为"success"
|
||
* 提供线索的失败节点:必须设为"failed_insight"
|
||
- **risk_score**:始终为0(action节点不评估风险)
|
||
|
||
### vulnerability(漏洞节点)
|
||
- **用途**:记录真实确认的安全漏洞
|
||
- **创建规则**:
|
||
* 必须是真实确认的漏洞,不是所有发现都是漏洞
|
||
* 需要明确的漏洞证据(如SQL注入返回数据库错误、XSS成功执行等)
|
||
- **risk_score规则**:
|
||
* critical(90-100):可导致系统完全沦陷(RCE、SQL注入导致数据泄露等)
|
||
* high(80-89):可导致敏感信息泄露或权限提升
|
||
* medium(60-79):存在安全风险但影响有限
|
||
* low(40-59):轻微安全问题
|
||
- **metadata要求**:
|
||
* vulnerability_type:漏洞类型(SQL注入、XSS、RCE等)
|
||
* description:详细描述漏洞位置、原理、影响
|
||
* severity:critical/high/medium/low
|
||
* location:精确的漏洞位置(URL、参数、文件路径等)
|
||
|
||
## 节点过滤和合并规则
|
||
|
||
### 必须保留的失败节点
|
||
以下失败情况必须创建节点,因为它们提供了有价值的线索:
|
||
- 工具返回明确的错误信息(权限错误、连接拒绝、认证失败等)
|
||
- 超时或连接失败(可能表明防火墙、网络隔离等)
|
||
- WAF/防火墙拦截(返回403、406等,表明存在防护机制)
|
||
- 工具未安装或配置错误(但执行了调用)
|
||
- 目标不可达(DNS解析失败、网络不通等)
|
||
|
||
### 应该删除的失败节点
|
||
以下情况不应创建节点:
|
||
- 完全无输出的工具调用
|
||
- 纯系统错误(与目标无关,如本地环境问题)
|
||
- 重复的相同失败(多次相同错误只保留第一次)
|
||
|
||
### 节点合并规则
|
||
以下情况应合并节点:
|
||
- 同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点)
|
||
- 同一目标的多个相似探测(如多个目录扫描工具,合并为一个"目录扫描"节点)
|
||
|
||
### 节点数量控制
|
||
- **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制数量而删除重要节点
|
||
- **建议范围**:单目标通常8-15个节点,但如果实际执行步骤较多,可以适当增加(最多20个节点)
|
||
- **优先保留**:关键成功步骤、提供线索的失败、发现的漏洞、重要的信息收集步骤
|
||
- **可以合并**:同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点)
|
||
- **可以删除**:完全无输出的工具调用、纯系统错误、重复的相同失败(多次相同错误只保留第一次)
|
||
- **重要原则**:宁可节点稍多,也不要遗漏关键步骤。攻击链必须能够完整展现渗透测试的完整过程
|
||
|
||
## 边的类型和权重
|
||
|
||
### 边的类型
|
||
- **leads_to**:表示"导致"或"引导到",用于action→action、target→action
|
||
* 例如:端口扫描 → 目录扫描(因为发现了80端口,所以进行目录扫描)
|
||
- **discovers**:表示"发现",**专门用于action→vulnerability**
|
||
* 例如:SQL注入测试 → SQL注入漏洞
|
||
* **重要**:所有action→vulnerability的边都必须使用discovers类型,即使多个action都指向同一个vulnerability,也应该统一使用discovers
|
||
- **enables**:表示"使能"或"促成",**仅用于vulnerability→vulnerability、action→action(当后续行动依赖前面结果时)**
|
||
* 例如:信息泄露漏洞 → 权限提升漏洞(通过信息泄露获得的信息促成了权限提升)
|
||
* **重要**:enables不能用于action→vulnerability,action→vulnerability必须使用discovers
|
||
|
||
### 边的权重
|
||
- **权重1-2**:弱关联(如初步探测到进一步探测)
|
||
- **权重3-4**:中等关联(如发现端口到服务识别)
|
||
- **权重5-7**:强关联(如发现漏洞、关键信息泄露)
|
||
- **权重8-10**:极强关联(如漏洞利用成功、权限提升)
|
||
|
||
### DAG结构要求(树状图)
|
||
- 所有边的source节点id必须小于target节点id(确保无环)
|
||
- 节点id从"node_1"开始递增
|
||
- 确保无孤立节点(每个节点至少有一条边连接)
|
||
- **树状结构要求**:
|
||
* 一个节点可以有多个后续节点(分支),例如:端口扫描节点可以同时连接到"Web服务识别"、"FTP服务识别"、"SSH服务识别"等多个节点
|
||
* 多个节点可以汇聚到一个节点(汇聚),例如:多个不同的测试都指向同一个漏洞节点
|
||
* 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建树状结构
|
||
|
||
## 攻击链逻辑连贯性要求
|
||
|
||
构建的攻击链应该能够回答以下问题:
|
||
1. **起点**:测试从哪里开始?(target节点)
|
||
2. **探索过程**:如何逐步收集信息?(action节点序列)
|
||
3. **失败与调整**:遇到障碍时如何调整策略?(failed_insight节点)
|
||
4. **关键发现**:发现了哪些重要信息?(action的findings)
|
||
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
||
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
||
|
||
## 最后一轮ReAct输入
|
||
|
||
%s
|
||
|
||
## 大模型输出
|
||
|
||
%s
|
||
|
||
## 输出格式
|
||
|
||
严格按照以下JSON格式输出,不要添加任何其他文字:
|
||
|
||
**重要:示例展示的是树状结构,注意node_2(端口扫描)同时连接到多个后续节点(node_3、node_4),形成分支结构。**
|
||
|
||
{
|
||
"nodes": [
|
||
{
|
||
"id": "node_1",
|
||
"type": "target",
|
||
"label": "测试目标: example.com",
|
||
"risk_score": 40,
|
||
"metadata": {
|
||
"target": "example.com"
|
||
}
|
||
},
|
||
{
|
||
"id": "node_2",
|
||
"type": "action",
|
||
"label": "扫描端口发现80/443/8080",
|
||
"risk_score": 0,
|
||
"metadata": {
|
||
"tool_name": "nmap",
|
||
"tool_intent": "端口扫描",
|
||
"ai_analysis": "使用nmap对目标进行端口扫描,发现80、443、8080端口开放。80端口运行HTTP服务,443端口运行HTTPS服务,8080端口可能为管理后台。这些开放端口为后续Web应用测试提供了入口。",
|
||
"findings": ["80端口开放", "443端口开放", "8080端口开放", "HTTP服务为Apache 2.4"]
|
||
}
|
||
},
|
||
{
|
||
"id": "node_3",
|
||
"type": "action",
|
||
"label": "目录扫描发现/admin后台",
|
||
"risk_score": 0,
|
||
"metadata": {
|
||
"tool_name": "dirsearch",
|
||
"tool_intent": "目录扫描",
|
||
"ai_analysis": "使用dirsearch对目标进行目录扫描,发现/admin目录存在且可访问。该目录可能为管理后台,是重要的测试目标。",
|
||
"findings": ["/admin目录存在", "返回200状态码", "疑似管理后台"]
|
||
}
|
||
},
|
||
{
|
||
"id": "node_4",
|
||
"type": "action",
|
||
"label": "识别Web服务为Apache 2.4",
|
||
"risk_score": 0,
|
||
"metadata": {
|
||
"tool_name": "whatweb",
|
||
"tool_intent": "Web服务识别",
|
||
"ai_analysis": "识别出目标运行Apache 2.4服务器,这为后续的漏洞测试提供了重要信息。",
|
||
"findings": ["Apache 2.4", "PHP版本信息"]
|
||
}
|
||
},
|
||
{
|
||
"id": "node_5",
|
||
"type": "action",
|
||
"label": "尝试SQL注入(被WAF拦截)",
|
||
"risk_score": 0,
|
||
"metadata": {
|
||
"tool_name": "sqlmap",
|
||
"tool_intent": "SQL注入检测",
|
||
"ai_analysis": "对/login.php进行SQL注入测试时被WAF拦截,返回403错误。错误信息显示检测到Cloudflare防护。这表明目标部署了WAF,需要调整测试策略。",
|
||
"findings": ["WAF拦截", "返回403", "检测到Cloudflare", "目标部署WAF"],
|
||
"status": "failed_insight"
|
||
}
|
||
},
|
||
{
|
||
"id": "node_6",
|
||
"type": "vulnerability",
|
||
"label": "SQL注入漏洞",
|
||
"risk_score": 85,
|
||
"metadata": {
|
||
"vulnerability_type": "SQL注入",
|
||
"description": "在/admin/login.php的username参数发现SQL注入漏洞,可通过注入payload绕过登录验证,直接获取管理员权限。漏洞返回数据库错误信息,确认存在注入点。",
|
||
"severity": "high",
|
||
"location": "/admin/login.php?username="
|
||
}
|
||
}
|
||
],
|
||
"edges": [
|
||
{
|
||
"source": "node_1",
|
||
"target": "node_2",
|
||
"type": "leads_to",
|
||
"weight": 3
|
||
},
|
||
{
|
||
"source": "node_2",
|
||
"target": "node_3",
|
||
"type": "leads_to",
|
||
"weight": 4
|
||
},
|
||
{
|
||
"source": "node_2",
|
||
"target": "node_4",
|
||
"type": "leads_to",
|
||
"weight": 3
|
||
},
|
||
{
|
||
"source": "node_3",
|
||
"target": "node_5",
|
||
"type": "leads_to",
|
||
"weight": 4
|
||
},
|
||
{
|
||
"source": "node_5",
|
||
"target": "node_6",
|
||
"type": "discovers",
|
||
"weight": 7
|
||
}
|
||
]
|
||
}
|
||
|
||
## 重要提醒
|
||
|
||
1. **严禁杜撰**:只使用ReAct输入中实际执行的工具和实际返回的结果。如无实际数据,返回空的nodes和edges数组。
|
||
2. **树状结构优先**:必须构建树状结构,而不是线性链。一个节点可以有多个后续节点(分支),多个节点可以指向同一个节点(汇聚)。避免将所有节点连成一条线。
|
||
3. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程。
|
||
4. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点。
|
||
5. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程。
|
||
6. **准确性**:所有节点信息必须基于实际数据,不要推测或假设。
|
||
7. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点。
|
||
8. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||
|
||
现在开始分析并构建攻击链:`, reactInput, modelOutput)
|
||
}
|
||
|
||
// saveChain 保存攻击链到数据库
|
||
func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error {
|
||
// 先删除旧的攻击链数据
|
||
if err := b.db.DeleteAttackChain(conversationID); err != nil {
|
||
b.logger.Warn("删除旧攻击链失败", zap.Error(err))
|
||
}
|
||
|
||
for _, node := range nodes {
|
||
metadataJSON, _ := json.Marshal(node.Metadata)
|
||
if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, "", string(metadataJSON), node.RiskScore); err != nil {
|
||
b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 保存边
|
||
for _, edge := range edges {
|
||
if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil {
|
||
b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// LoadChainFromDatabase 从数据库加载攻击链
|
||
func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) {
|
||
nodes, err := b.db.LoadAttackChainNodes(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载攻击链节点失败: %w", err)
|
||
}
|
||
|
||
edges, err := b.db.LoadAttackChainEdges(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载攻击链边失败: %w", err)
|
||
}
|
||
|
||
return &Chain{
|
||
Nodes: nodes,
|
||
Edges: edges,
|
||
}, nil
|
||
}
|
||
|
||
// callAIForChainGeneration 调用AI生成攻击链
|
||
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
|
||
requestBody := map[string]interface{}{
|
||
"model": b.openAIConfig.Model,
|
||
"messages": []map[string]interface{}{
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个专业的安全测试分析师,擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。",
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": prompt,
|
||
},
|
||
},
|
||
"temperature": 0.3,
|
||
"max_tokens": 8000,
|
||
}
|
||
|
||
var apiResponse struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if b.openAIClient == nil {
|
||
return "", fmt.Errorf("OpenAI客户端未初始化")
|
||
}
|
||
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
|
||
var apiErr *openai.APIError
|
||
if errors.As(err, &apiErr) {
|
||
bodyStr := strings.ToLower(apiErr.Body)
|
||
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
|
||
return "", fmt.Errorf("context length exceeded")
|
||
}
|
||
} else if strings.Contains(strings.ToLower(err.Error()), "context") || strings.Contains(strings.ToLower(err.Error()), "length") {
|
||
return "", fmt.Errorf("context length exceeded")
|
||
}
|
||
return "", fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
|
||
if len(apiResponse.Choices) == 0 {
|
||
return "", fmt.Errorf("API未返回有效响应")
|
||
}
|
||
|
||
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
|
||
// 尝试提取JSON(可能包含markdown代码块)
|
||
content = strings.TrimPrefix(content, "```json")
|
||
content = strings.TrimPrefix(content, "```")
|
||
content = strings.TrimSuffix(content, "```")
|
||
content = strings.TrimSpace(content)
|
||
|
||
return content, nil
|
||
}
|
||
|
||
// ChainJSON 攻击链JSON结构
|
||
type ChainJSON struct {
|
||
Nodes []struct {
|
||
ID string `json:"id"`
|
||
Type string `json:"type"`
|
||
Label string `json:"label"`
|
||
RiskScore int `json:"risk_score"`
|
||
Metadata map[string]interface{} `json:"metadata"`
|
||
} `json:"nodes"`
|
||
Edges []struct {
|
||
Source string `json:"source"`
|
||
Target string `json:"target"`
|
||
Type string `json:"type"`
|
||
Weight int `json:"weight"`
|
||
} `json:"edges"`
|
||
}
|
||
|
||
// parseChainJSON 解析攻击链JSON
|
||
func (b *Builder) parseChainJSON(chainJSON string) (*Chain, error) {
|
||
var chainData ChainJSON
|
||
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
|
||
return nil, fmt.Errorf("解析JSON失败: %w", err)
|
||
}
|
||
|
||
// 创建节点ID映射(AI返回的ID -> 新的UUID)
|
||
nodeIDMap := make(map[string]string)
|
||
|
||
// 转换为Chain结构
|
||
nodes := make([]Node, 0, len(chainData.Nodes))
|
||
for _, n := range chainData.Nodes {
|
||
// 生成新的UUID节点ID
|
||
newNodeID := fmt.Sprintf("node_%s", uuid.New().String())
|
||
nodeIDMap[n.ID] = newNodeID
|
||
|
||
node := Node{
|
||
ID: newNodeID,
|
||
Type: n.Type,
|
||
Label: n.Label,
|
||
RiskScore: n.RiskScore,
|
||
Metadata: n.Metadata,
|
||
}
|
||
if node.Metadata == nil {
|
||
node.Metadata = make(map[string]interface{})
|
||
}
|
||
nodes = append(nodes, node)
|
||
}
|
||
|
||
// 转换边
|
||
edges := make([]Edge, 0, len(chainData.Edges))
|
||
for _, e := range chainData.Edges {
|
||
sourceID, ok := nodeIDMap[e.Source]
|
||
if !ok {
|
||
continue
|
||
}
|
||
targetID, ok := nodeIDMap[e.Target]
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// 生成边的ID(前端需要)
|
||
edgeID := fmt.Sprintf("edge_%s", uuid.New().String())
|
||
|
||
edges = append(edges, Edge{
|
||
ID: edgeID,
|
||
Source: sourceID,
|
||
Target: targetID,
|
||
Type: e.Type,
|
||
Weight: e.Weight,
|
||
})
|
||
}
|
||
|
||
return &Chain{
|
||
Nodes: nodes,
|
||
Edges: edges,
|
||
}, nil
|
||
}
|
||
|
||
// 以下所有方法已不再使用,已删除以简化代码
|