Files
CyberStrikeAI/internal/attackchain/builder.go
2025-11-18 00:07:15 +08:00

1085 lines
35 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package attackchain
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Builder 攻击链构建器
type Builder struct {
db *database.DB
logger *zap.Logger
openAIClient *http.Client
openAIConfig *config.OpenAIConfig
}
// Node 攻击链节点使用database包的类型
type Node = database.AttackChainNode
// Edge 攻击链边使用database包的类型
type Edge = database.AttackChainEdge
// Chain 完整的攻击链
type Chain struct {
Nodes []Node `json:"nodes"`
Edges []Edge `json:"edges"`
}
// NewBuilder 创建新的攻击链构建器
func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder {
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
}
return &Builder{
db: db,
logger: logger,
openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport},
openAIConfig: openAIConfig,
}
}
// BuildChainFromConversation 从对话构建攻击链(一次性生成整个图)
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
b.logger.Info("开始构建攻击链(一次性生成)", zap.String("conversationId", conversationID))
// 1. 获取对话消息和工具执行记录
messages, err := b.db.GetMessages(conversationID)
if err != nil {
return nil, fmt.Errorf("获取对话消息失败: %w", err)
}
executions, err := b.getToolExecutionsByConversation(conversationID)
if err != nil {
return nil, fmt.Errorf("获取工具执行记录失败: %w", err)
}
// 获取过程详情
processDetailsMap, err := b.db.GetProcessDetailsByConversation(conversationID)
if err != nil {
b.logger.Warn("获取过程详情失败", zap.Error(err))
processDetailsMap = make(map[string][]database.ProcessDetail)
}
if len(executions) == 0 && len(messages) == 0 {
b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID))
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
}
// 2. 准备上下文数据
contextData, err := b.prepareContextData(messages, executions, processDetailsMap)
if err != nil {
return nil, fmt.Errorf("准备上下文数据失败: %w", err)
}
// 3. 一次性生成攻击链(带重试和压缩机制)
chain, err := b.generateChainWithRetry(ctx, contextData, 5)
if err != nil {
return nil, fmt.Errorf("生成攻击链失败: %w", err)
}
// 4. 保存到数据库
if err := b.saveChain(conversationID, chain.Nodes, chain.Edges); err != nil {
b.logger.Warn("保存攻击链失败", zap.Error(err))
// 不返回错误,继续返回结果
}
b.logger.Info("攻击链构建完成",
zap.String("conversationId", conversationID),
zap.Int("nodes", len(chain.Nodes)),
zap.Int("edges", len(chain.Edges)))
return chain, nil
}
// getToolExecutionsByConversation 获取对话的工具执行记录
func (b *Builder) getToolExecutionsByConversation(conversationID string) ([]*mcp.ToolExecution, error) {
// 通过conversation_id关联messages再通过mcp_execution_ids关联tool_executions
// 简化实现:直接查询所有工具执行记录,然后过滤(实际应该优化查询)
allExecutions, err := b.db.LoadToolExecutions()
if err != nil {
return nil, err
}
// 获取对话的消息提取mcp_execution_ids
messages, err := b.db.GetMessages(conversationID)
if err != nil {
return nil, err
}
// 收集所有execution IDs
executionIDSet := make(map[string]bool)
for _, msg := range messages {
if len(msg.MCPExecutionIDs) > 0 {
for _, id := range msg.MCPExecutionIDs {
executionIDSet[id] = true
}
}
}
// 过滤执行记录
var filteredExecutions []*mcp.ToolExecution
for _, exec := range allExecutions {
if executionIDSet[exec.ID] {
filteredExecutions = append(filteredExecutions, exec)
}
}
// 按时间排序
sort.Slice(filteredExecutions, func(i, j int) bool {
return filteredExecutions[i].StartTime.Before(filteredExecutions[j].StartTime)
})
return filteredExecutions, nil
}
// saveChain 保存攻击链到数据库
func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error {
// 先删除旧的攻击链数据
if err := b.db.DeleteAttackChain(conversationID); err != nil {
b.logger.Warn("删除旧攻击链失败", zap.Error(err))
}
// 保存节点
for _, node := range nodes {
metadataJSON, _ := json.Marshal(node.Metadata)
if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, node.ToolExecutionID, string(metadataJSON), node.RiskScore); err != nil {
b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err))
}
}
// 保存边
for _, edge := range edges {
if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil {
b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err))
}
}
return nil
}
// LoadChainFromDatabase 从数据库加载攻击链
func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) {
nodes, err := b.db.LoadAttackChainNodes(conversationID)
if err != nil {
return nil, fmt.Errorf("加载攻击链节点失败: %w", err)
}
edges, err := b.db.LoadAttackChainEdges(conversationID)
if err != nil {
return nil, fmt.Errorf("加载攻击链边失败: %w", err)
}
return &Chain{
Nodes: nodes,
Edges: edges,
}, nil
}
// ContextData 上下文数据(用于一次性生成攻击链)
type ContextData struct {
Messages []database.Message `json:"messages"`
Executions []*mcp.ToolExecution `json:"executions"`
ProcessDetails map[string][]database.ProcessDetail `json:"process_details"`
SummarizedItems map[string]string `json:"summarized_items"` // 已总结的项目key: 原始ID, value: 总结内容)
}
// prepareContextData 准备上下文数据
func (b *Builder) prepareContextData(messages []database.Message, executions []*mcp.ToolExecution, processDetails map[string][]database.ProcessDetail) (*ContextData, error) {
return &ContextData{
Messages: messages,
Executions: executions,
ProcessDetails: processDetails,
SummarizedItems: make(map[string]string),
}, nil
}
// generateChainWithRetry 生成攻击链(带重试和压缩机制)
func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *ContextData, maxRetries int) (*Chain, error) {
for attempt := 0; attempt < maxRetries; attempt++ {
b.logger.Info("尝试生成攻击链",
zap.Int("attempt", attempt+1),
zap.Int("maxRetries", maxRetries))
// 构建提示词
prompt, err := b.buildChainGenerationPrompt(contextData)
if err != nil {
return nil, fmt.Errorf("构建提示词失败: %w", err)
}
// 调用AI生成攻击链
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
if err != nil {
// 检查是否是上下文过长错误
if strings.Contains(err.Error(), "context length") || strings.Contains(err.Error(), "too long") || strings.Contains(err.Error(), "context length exceeded") {
b.logger.Warn("上下文过长,尝试压缩",
zap.Int("attempt", attempt+1),
zap.Error(err))
// 压缩最长的子节点
if err := b.compressLongestItem(ctx, contextData); err != nil {
return nil, fmt.Errorf("压缩上下文失败: %w", err)
}
// 重试
continue
}
return nil, fmt.Errorf("AI生成失败: %w", err)
}
// 解析JSON传入executions用于ID映射
chain, err := b.parseChainJSON(chainJSON, contextData.Executions)
if err != nil {
return nil, fmt.Errorf("解析攻击链JSON失败: %w", err)
}
return chain, nil
}
return nil, fmt.Errorf("生成攻击链失败:超过最大重试次数 %d", maxRetries)
}
// buildChainGenerationPrompt 构建攻击链生成提示词
func (b *Builder) buildChainGenerationPrompt(contextData *ContextData) (string, error) {
var promptBuilder strings.Builder
promptBuilder.WriteString(`你是一个专业的安全测试分析师。请根据以下对话和工具执行记录,生成清晰、有教育意义的攻击链图。
## 核心原则
**目标:让不懂渗透测试的同学可以通过这个攻击链路学习到知识,而不是无数个节点看花眼。**
**即便某些工具执行或漏洞挖掘没有成功,只要它们提供了关键线索、错误提示或下一步思路,也要被保留下来。**
## 任务要求
1. **节点类型简化只保留3种**
- **target目标**从用户输入中提取测试目标IP、域名、URL等
- **重要如果对话中测试了多个不同的目标如先测试A网页后测试B网页必须为每个不同的目标创建独立的target节点**
- 每个target节点只关联属于它的action节点通过工具执行参数中的目标来判断
- 不同目标的action节点之间**不应该**建立关联关系
- **action行动****工具执行 + AI分析结果 = 一个action节点**
- 将每个工具执行和AI对该工具结果的分析合并为一个action节点
- 节点标签应该清晰描述"做了什么"、"得到了什么结果或线索"(例如:"使用Nmap扫描端口发现22、80、443端口开放" 或 "尝试SQLmap虽然失败但提示存在WAF拦截"
- 默认关注成功的执行;但如果执行失败却提供了有价值的线索(错误信息、资产指纹、下一步建议等),也要保留,记为"带线索的失败"行动
- **重要action节点必须关联到正确的target节点通过工具执行参数判断目标**
- **vulnerability漏洞**从工具执行结果和AI分析中提取的**真实漏洞**(不是所有发现都是漏洞)。若验证失败但能明确表明某个漏洞利用方向不可行,可作为行动节点的线索描述,而不是漏洞节点。
2. **过滤规则(重要!)**
- **默认忽略**彻底无效的信息:完全没有输出、没有任何线索的失败执行仍需过滤
- **必须保留**下列失败执行:
- 错误信息里包含了潜在线索、受限条件、可复现的报错
- 虽未找到漏洞,但收集到了资产信息、技术栈或后续测试方向
- 用户特别关注的失败尝试
- **保留策略**:只要行动节点能给后续测试提供启发,就保留;否则忽略
3. **建立清晰的关联关系**
- target → action目标指向属于它的所有行动通过工具执行参数判断目标
- action → action行动之间的逻辑顺序按时间顺序但只连接有逻辑关系的
- **重要只连接属于同一目标的action节点不同目标的action节点之间不应该连接**
- action → vulnerability行动发现的漏洞
- vulnerability → vulnerability漏洞间的因果关系如SQL注入 → 信息泄露)
- **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接**
4. **节点属性**
- 每个节点需要id, type, label, risk_score, metadata
- action节点需要
- tool_name: 工具名称
- tool_intent: 工具调用意图(如"端口扫描"、"漏洞扫描"
- ai_analysis: AI对工具结果的分析总结简洁不超过100字失败节点需解释线索价值
- findings: 关键发现(列表)
- status: "success" | "failed_insight"(失败但有价值的线索)
- hints: ["下一步建议1", "限制条件2"](失败节点可提供的线索列表)
- vulnerability节点需要type, description, severity, location
## 对话数据
`)
// 添加消息
promptBuilder.WriteString("\n### 对话消息\n\n")
for i, msg := range contextData.Messages {
promptBuilder.WriteString(fmt.Sprintf("消息%d [%s]:\n", i+1, msg.Role))
isUserMessage := strings.EqualFold(msg.Role, "user")
// 用户输入必须原样提供给攻击链模型
if isUserMessage {
promptBuilder.WriteString(fmt.Sprintf("%s\n\n", msg.Content))
} else if summary, ok := contextData.SummarizedItems[msg.ID]; ok {
promptBuilder.WriteString(fmt.Sprintf("[已总结] %s\n\n", summary))
} else {
content := msg.Content
if len(content) > 5000 {
content = content[:5000] + "..."
}
promptBuilder.WriteString(fmt.Sprintf("%s\n\n", content))
}
// 添加过程详情
if details, ok := contextData.ProcessDetails[msg.ID]; ok {
for _, detail := range details {
if detail.EventType == "thinking" {
thinkingText := detail.Message
if summary, ok := contextData.SummarizedItems[detail.ID]; ok {
thinkingText = "[已总结] " + summary
} else if len(thinkingText) > 2000 {
thinkingText = thinkingText[:2000] + "..."
}
promptBuilder.WriteString(fmt.Sprintf("思考过程: %s\n", thinkingText))
}
}
}
promptBuilder.WriteString("\n")
}
// 添加工具执行记录关联对应的AI回复
promptBuilder.WriteString("\n### 工具执行记录包含对应的AI分析\n\n")
// 构建工具执行ID到消息的映射找到工具执行后AI的回复
execToMessageMap := b.buildExecutionToMessageMap(contextData)
for i, exec := range contextData.Executions {
// 检查是否是错误/失败的执行
isError := exec.Error != "" || (exec.Result != nil && exec.Result.IsError)
statusText := "成功"
if isError {
statusText = "失败(可能包含线索)"
}
promptBuilder.WriteString(fmt.Sprintf("执行%d [%s] (ID: %s) - 状态: %s\n", i+1, exec.ToolName, exec.ID, statusText))
promptBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments)))
if isError && exec.Error != "" {
promptBuilder.WriteString(fmt.Sprintf("错误信息: %s\n", exec.Error))
}
// 检查是否已总结
var resultText string
if exec.Result != nil {
for _, content := range exec.Result.Content {
if content.Type == "text" {
resultText += content.Text + "\n"
}
}
}
// 检查结果是否为空或无效
if strings.TrimSpace(resultText) == "" {
if isError {
promptBuilder.WriteString("工具执行结果: [失败但未返回正文]\n")
} else {
promptBuilder.WriteString("工具执行结果: **已忽略(结果为空)**\n\n")
continue
}
} else {
if summary, ok := contextData.SummarizedItems[exec.ID]; ok {
promptBuilder.WriteString(fmt.Sprintf("工具执行结果: [已总结] %s\n", summary))
} else {
if len(resultText) > 5000 {
resultText = resultText[:5000] + "..."
}
promptBuilder.WriteString(fmt.Sprintf("工具执行结果: %s\n", resultText))
}
}
// 添加对应的AI分析工具执行后AI的回复
if aiMessage, ok := execToMessageMap[exec.ID]; ok {
aiContent := aiMessage.Content
if len(aiContent) > 2000 {
aiContent = aiContent[:2000] + "..."
}
promptBuilder.WriteString(fmt.Sprintf("AI分析: %s\n", aiContent))
}
promptBuilder.WriteString("\n")
}
promptBuilder.WriteString(`
## 输出格式
请以JSON格式返回攻击链格式如下
{
"nodes": [
{
"id": "node_1",
"type": "target|action|vulnerability",
"label": "节点标签清晰、简洁action节点要描述"做了什么"和"发现了什么"",
"risk_score": 0-100,
"tool_execution_id": "执行记录的真实IDaction节点必须使用上面执行记录中的ID字段",
"metadata": {
"target": "目标target节点",
"tool_name": "工具名称action节点",
"tool_intent": "工具调用意图action节点如"端口扫描"、"漏洞扫描"",
"ai_analysis": "AI对工具结果的分析总结action节点不超过100字",
"findings": ["发现1", "发现2"]action节点关键发现列表,
"vulnerability_type": "漏洞类型vulnerability节点",
"description": "描述vulnerability节点",
"severity": "critical|high|medium|lowvulnerability节点",
"location": "漏洞位置vulnerability节点"
}
}
],
"edges": [
{
"source": "node_1",
"target": "node_2",
"type": "leads_to|discovers|enables",
"weight": 1-5
}
]
}
## 重要要求
1. **节点合并**
- 每个工具执行和对应的AI分析必须合并为一个action节点
- action节点的label要清晰描述"做了什么"、"结果/线索是什么"
- 例如:"使用Nmap扫描192.168.1.1发现22、80、443端口开放" 或 "执行Sqlmap被WAF拦截提示403并暴露防护厂商"
- 若为失败但有线索的行动请在metadata.status中标记为"failed_insight"并在findings/hints里写清线索价值
2. **过滤无效节点**
- **必须忽略**没有任何输出、没有线索的失败执行
- **必须保留**失败但提供关键线索的执行确保metadata里解释清楚
- 只保留对学习或溯源有帮助的节点
3. **简化结构**
- 只创建target、action、vulnerability三种节点
- 不要创建discovery、decision等节点
- 让攻击链清晰、有教育意义
4. **关联关系**
- target → action目标指向属于它的所有行动通过工具执行参数判断目标
- action → action按时间顺序连接但只连接有逻辑关系的
- **重要只连接属于同一目标的action节点不同目标的action节点之间不应该连接**
- action → vulnerability行动发现的漏洞
- vulnerability → vulnerability漏洞间的因果关系
- **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接**
5. **多目标处理(重要!)**
- 如果对话中测试了多个不同的目标如先测试A网页后测试B网页必须
- 为每个不同的目标创建独立的target节点
- 每个target节点只关联属于它的action和vulnerability节点
- 不同目标的节点之间**不应该**建立任何关联关系
- 这样会形成多个独立的攻击链分支,每个分支对应一个测试目标
6. **节点数量控制**
- 如果节点太多(>20个优先保留最重要的节点
- 合并相似的action节点如同一工具的连续调用如果结果相似
只返回JSON不要包含其他解释文字。`)
return promptBuilder.String(), nil
}
// buildExecutionToMessageMap 构建工具执行ID到AI消息的映射
// 找到每个工具执行后AI的回复消息
func (b *Builder) buildExecutionToMessageMap(contextData *ContextData) map[string]database.Message {
execToMessageMap := make(map[string]database.Message)
// 遍历消息找到包含工具执行ID的消息通常是assistant消息
for _, msg := range contextData.Messages {
if msg.Role != "assistant" {
continue
}
// 检查消息中是否引用了工具执行ID
// 通常工具执行后AI会在回复中引用这些执行ID
for _, execID := range msg.MCPExecutionIDs {
// 找到对应的工具执行
for _, exec := range contextData.Executions {
if exec.ID == execID {
// 如果这个执行还没有关联的消息,或者当前消息时间更晚,则更新
if existingMsg, exists := execToMessageMap[execID]; !exists || msg.CreatedAt.After(existingMsg.CreatedAt) {
execToMessageMap[execID] = msg
}
break
}
}
}
}
// 如果通过MCPExecutionIDs找不到尝试按时间顺序匹配
// 找到每个工具执行后最近的assistant消息
for _, exec := range contextData.Executions {
if _, exists := execToMessageMap[exec.ID]; exists {
continue
}
// 找到执行时间之后最近的assistant消息
var closestMsg *database.Message
for i := range contextData.Messages {
msg := &contextData.Messages[i]
if msg.Role == "assistant" && msg.CreatedAt.After(exec.StartTime) {
if closestMsg == nil || msg.CreatedAt.Before(closestMsg.CreatedAt) {
closestMsg = msg
}
}
}
if closestMsg != nil {
execToMessageMap[exec.ID] = *closestMsg
}
}
return execToMessageMap
}
// formatArguments 格式化工具参数
func (b *Builder) formatArguments(args map[string]interface{}) string {
if args == nil {
return "{}"
}
jsonData, _ := json.Marshal(args)
return string(jsonData)
}
// compressLongestItem 压缩最长的子节点
func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error {
var longestID string
var longestType string
var longestContent string
maxLength := 0
// 查找最长的消息
for _, msg := range contextData.Messages {
if strings.EqualFold(msg.Role, "user") {
continue
}
if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized {
continue
}
length := len(msg.Content)
if length > maxLength {
maxLength = length
longestID = msg.ID
longestType = "message"
longestContent = msg.Content
}
}
// 查找最长的工具执行结果
for _, exec := range contextData.Executions {
if _, alreadySummarized := contextData.SummarizedItems[exec.ID]; alreadySummarized {
continue
}
if exec.Result != nil {
var resultText string
for _, content := range exec.Result.Content {
if content.Type == "text" {
resultText += content.Text + "\n"
}
}
length := len(resultText)
if length > maxLength {
maxLength = length
longestID = exec.ID
longestType = "execution"
longestContent = resultText
}
}
}
// 查找最长的思考过程
for _, details := range contextData.ProcessDetails {
for _, detail := range details {
if detail.EventType == "thinking" {
if _, alreadySummarized := contextData.SummarizedItems[detail.ID]; alreadySummarized {
continue
}
length := len(detail.Message)
if length > maxLength {
maxLength = length
longestID = detail.ID
longestType = "thinking"
longestContent = detail.Message
}
}
}
}
if longestID == "" {
return fmt.Errorf("没有找到需要压缩的内容")
}
b.logger.Info("压缩最长子节点",
zap.String("id", longestID),
zap.String("type", longestType),
zap.Int("length", maxLength))
// 使用AI总结
summary, err := b.summarizeContent(ctx, longestType, longestContent)
if err != nil {
return fmt.Errorf("总结内容失败: %w", err)
}
// 保存总结
contextData.SummarizedItems[longestID] = summary
b.logger.Info("压缩完成",
zap.String("id", longestID),
zap.Int("originalLength", maxLength),
zap.Int("summaryLength", len(summary)))
return nil
}
// summarizeContent 总结内容
func (b *Builder) summarizeContent(ctx context.Context, contentType, content string) (string, error) {
var prompt string
switch contentType {
case "message":
prompt = fmt.Sprintf(`请总结以下AI回复的关键信息保留所有重要的安全发现、漏洞信息和测试结果。用简洁的中文总结不超过500字。
AI回复
%s
总结:`, content)
case "execution":
prompt = fmt.Sprintf(`请总结以下工具执行结果的关键信息保留所有发现的漏洞、重要发现和测试结果。用简洁的中文总结不超过500字。
工具执行结果:
%s
总结:`, content)
case "thinking":
prompt = fmt.Sprintf(`请总结以下AI思考过程的关键决策和思路保留所有重要的决策点和测试策略。用简洁的中文总结不超过300字。
思考过程:
%s
总结:`, content)
default:
return "", fmt.Errorf("未知的内容类型: %s", contentType)
}
requestBody := map[string]interface{}{
"model": b.openAIConfig.Model,
"messages": []map[string]interface{}{
{
"role": "system",
"content": "你是一个专业的安全测试分析师,擅长总结安全测试相关的信息。请用简洁的中文总结关键信息。",
},
{
"role": "user",
"content": prompt,
},
},
"temperature": 0.3,
"max_tokens": 1000,
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
resp, err := b.openAIClient.Do(req)
if err != nil {
return "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
}
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
}
if len(apiResponse.Choices) == 0 {
return "", fmt.Errorf("API未返回有效响应")
}
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
}
// callAIForChainGeneration 调用AI生成攻击链
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
requestBody := map[string]interface{}{
"model": b.openAIConfig.Model,
"messages": []map[string]interface{}{
{
"role": "system",
"content": "你是一个专业的安全测试分析师擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。",
},
{
"role": "user",
"content": prompt,
},
},
"temperature": 0.3,
"max_tokens": 8000,
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
resp, err := b.openAIClient.Do(req)
if err != nil {
// 检查是否是上下文过长错误
if strings.Contains(err.Error(), "context") || strings.Contains(err.Error(), "length") {
return "", fmt.Errorf("context length exceeded")
}
return "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
bodyStr := string(body)
// 检查是否是上下文过长错误
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
return "", fmt.Errorf("context length exceeded")
}
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, bodyStr)
}
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
}
if len(apiResponse.Choices) == 0 {
return "", fmt.Errorf("API未返回有效响应")
}
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
// 尝试提取JSON可能包含markdown代码块
content = strings.TrimPrefix(content, "```json")
content = strings.TrimPrefix(content, "```")
content = strings.TrimSuffix(content, "```")
content = strings.TrimSpace(content)
return content, nil
}
// ChainJSON 攻击链JSON结构
type ChainJSON struct {
Nodes []struct {
ID string `json:"id"`
Type string `json:"type"`
Label string `json:"label"`
RiskScore int `json:"risk_score"`
ToolExecutionID string `json:"tool_execution_id,omitempty"`
Metadata map[string]interface{} `json:"metadata"`
} `json:"nodes"`
Edges []struct {
Source string `json:"source"`
Target string `json:"target"`
Type string `json:"type"`
Weight int `json:"weight"`
} `json:"edges"`
}
// parseChainJSON 解析攻击链JSON
func (b *Builder) parseChainJSON(chainJSON string, executions []*mcp.ToolExecution) (*Chain, error) {
var chainData ChainJSON
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
return nil, fmt.Errorf("解析JSON失败: %w", err)
}
// 创建execution ID映射AI可能返回简单的索引或ID需要映射到真实的execution ID
executionMap := make(map[string]string) // AI返回的ID -> 真实execution ID
for i, exec := range executions {
// 支持多种可能的AI返回格式
executionMap[fmt.Sprintf("exec_%d", i+1)] = exec.ID
executionMap[fmt.Sprintf("execution_%d", i+1)] = exec.ID
executionMap[exec.ID] = exec.ID // 如果AI直接返回真实ID
executionMap[fmt.Sprintf("tool_%d", i+1)] = exec.ID // AI可能用tool_1格式
executionMap[fmt.Sprintf("执行%d", i+1)] = exec.ID // 中文格式
executionMap[fmt.Sprintf("执行_%d", i+1)] = exec.ID
}
// 创建节点ID映射AI返回的ID -> 新的UUID
nodeIDMap := make(map[string]string)
// 转换为Chain结构并过滤无效节点
nodes := make([]Node, 0, len(chainData.Nodes))
for _, n := range chainData.Nodes {
// 过滤无效节点
if b.shouldFilterNode(n, executions) {
b.logger.Info("过滤无效节点",
zap.String("nodeID", n.ID),
zap.String("nodeType", n.Type),
zap.String("label", n.Label))
continue
}
// 生成新的UUID节点ID
newNodeID := fmt.Sprintf("node_%s", uuid.New().String())
nodeIDMap[n.ID] = newNodeID
node := Node{
ID: newNodeID,
Type: n.Type,
Label: n.Label,
RiskScore: n.RiskScore,
Metadata: n.Metadata,
}
if node.Metadata == nil {
node.Metadata = make(map[string]interface{})
}
// 处理tool_execution_id如果是action或vulnerability节点需要映射到真实的execution ID
if n.ToolExecutionID != "" {
if realExecID, ok := executionMap[n.ToolExecutionID]; ok {
node.ToolExecutionID = realExecID
} else {
// 检查是否是真实的execution IDUUID格式
// 如果是直接使用如果不是尝试从节点ID推断
if len(n.ToolExecutionID) > 20 { // UUID通常很长
node.ToolExecutionID = n.ToolExecutionID
} else {
// 可能是简单的ID尝试从节点ID推断
if realExecID, ok := executionMap[n.ID]; ok {
node.ToolExecutionID = realExecID
} else {
b.logger.Warn("无法映射tool_execution_id",
zap.String("nodeID", n.ID),
zap.String("toolExecutionID", n.ToolExecutionID))
// 对于action节点如果没有有效的execution ID清空它避免外键约束失败
if n.Type == "action" {
node.ToolExecutionID = ""
}
}
}
}
} else if n.Type == "action" || n.Type == "vulnerability" {
// 如果AI没有提供tool_execution_id尝试从节点ID推断
// 例如tool_1 -> 查找exec_1
if realExecID, ok := executionMap[n.ID]; ok {
node.ToolExecutionID = realExecID
} else {
b.logger.Warn("action/vulnerability节点缺少tool_execution_id",
zap.String("nodeID", n.ID),
zap.String("nodeType", n.Type))
}
}
nodes = append(nodes, node)
}
// 转换边更新source和target为新的节点ID
edges := make([]Edge, 0, len(chainData.Edges))
for _, e := range chainData.Edges {
sourceID, ok := nodeIDMap[e.Source]
if !ok {
b.logger.Warn("边的源节点ID未找到", zap.String("source", e.Source))
continue
}
targetID, ok := nodeIDMap[e.Target]
if !ok {
b.logger.Warn("边的目标节点ID未找到", zap.String("target", e.Target))
continue
}
edge := Edge{
ID: fmt.Sprintf("edge_%s", uuid.New().String()),
Source: sourceID,
Target: targetID,
Type: e.Type,
Weight: e.Weight,
}
edges = append(edges, edge)
}
// 过滤掉指向已删除节点的边
filteredEdges := make([]Edge, 0, len(edges))
for _, edge := range edges {
// 检查source和target节点是否都存在
sourceExists := false
targetExists := false
for _, node := range nodes {
if node.ID == edge.Source {
sourceExists = true
}
if node.ID == edge.Target {
targetExists = true
}
}
if sourceExists && targetExists {
filteredEdges = append(filteredEdges, edge)
} else {
b.logger.Warn("过滤无效边",
zap.String("edgeID", edge.ID),
zap.String("source", edge.Source),
zap.String("target", edge.Target),
zap.Bool("sourceExists", sourceExists),
zap.Bool("targetExists", targetExists))
}
}
return &Chain{
Nodes: nodes,
Edges: filteredEdges,
}, nil
}
// shouldFilterNode 判断是否应该过滤掉这个节点
func (b *Builder) shouldFilterNode(n struct {
ID string `json:"id"`
Type string `json:"type"`
Label string `json:"label"`
RiskScore int `json:"risk_score"`
ToolExecutionID string `json:"tool_execution_id,omitempty"`
Metadata map[string]interface{} `json:"metadata"`
}, executions []*mcp.ToolExecution) bool {
// 只允许target、action、vulnerability三种节点类型
if n.Type != "target" && n.Type != "action" && n.Type != "vulnerability" {
return true
}
// 对于action节点检查对应的工具执行是否有效
if n.Type == "action" {
if n.ToolExecutionID == "" {
// 没有关联工具执行的action节点可能是无效的
return true
}
// 查找对应的工具执行
var exec *mcp.ToolExecution
for _, e := range executions {
if e.ID == n.ToolExecutionID {
exec = e
break
}
}
if exec == nil {
// 找不到对应的工具执行,可能是无效的
return true
}
// 检查工具执行是否错误或失败
if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) {
if !hasInsightfulFailure(n.Metadata) {
return true
}
}
// 检查工具执行结果是否为空
if exec.Result == nil || len(exec.Result.Content) == 0 {
if !hasInsightfulFailure(n.Metadata) {
return true
}
}
// 检查结果文本是否为空
var resultText string
if exec.Result != nil {
for _, content := range exec.Result.Content {
if content.Type == "text" {
resultText += content.Text
}
}
}
if strings.TrimSpace(resultText) == "" {
if !hasInsightfulFailure(n.Metadata) {
return true
}
}
}
// 检查节点标签是否为空或无效
if strings.TrimSpace(n.Label) == "" {
return true
}
// 检查标签中是否包含错误/失败的关键词
labelLower := strings.ToLower(n.Label)
errorKeywords := []string{"错误", "失败", "无效", "error", "failed", "invalid", "empty", "空"}
for _, keyword := range errorKeywords {
if strings.Contains(labelLower, keyword) {
// 如果标签明确表示错误但节点类型不是vulnerability则过滤
if n.Type != "vulnerability" {
return true
}
}
}
return false
}
func hasInsightfulFailure(metadata map[string]interface{}) bool {
if metadata == nil {
return false
}
if status, ok := metadata["status"].(string); ok {
normalized := strings.ToLower(strings.TrimSpace(status))
if normalized == "failed_insight" || normalized == "failed_clue" || normalized == "failed_with_hint" {
return true
}
}
if hint, ok := metadata["hint"].(string); ok && strings.TrimSpace(hint) != "" {
return true
}
if hints, ok := metadata["hints"].([]interface{}); ok && len(hints) > 0 {
return true
}
if insight, ok := metadata["insight"].(string); ok && strings.TrimSpace(insight) != "" {
return true
}
return false
}