mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
1085 lines
35 KiB
Go
1085 lines
35 KiB
Go
package attackchain
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"sort"
|
||
"strings"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/config"
|
||
"cyberstrike-ai/internal/database"
|
||
"cyberstrike-ai/internal/mcp"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Builder 攻击链构建器
|
||
type Builder struct {
|
||
db *database.DB
|
||
logger *zap.Logger
|
||
openAIClient *http.Client
|
||
openAIConfig *config.OpenAIConfig
|
||
}
|
||
|
||
// Node 攻击链节点(使用database包的类型)
|
||
type Node = database.AttackChainNode
|
||
|
||
// Edge 攻击链边(使用database包的类型)
|
||
type Edge = database.AttackChainEdge
|
||
|
||
// Chain 完整的攻击链
|
||
type Chain struct {
|
||
Nodes []Node `json:"nodes"`
|
||
Edges []Edge `json:"edges"`
|
||
}
|
||
|
||
// NewBuilder 创建新的攻击链构建器
|
||
func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder {
|
||
transport := &http.Transport{
|
||
MaxIdleConns: 100,
|
||
MaxIdleConnsPerHost: 10,
|
||
IdleConnTimeout: 90 * time.Second,
|
||
}
|
||
|
||
return &Builder{
|
||
db: db,
|
||
logger: logger,
|
||
openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport},
|
||
openAIConfig: openAIConfig,
|
||
}
|
||
}
|
||
|
||
// BuildChainFromConversation 从对话构建攻击链(一次性生成整个图)
|
||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||
b.logger.Info("开始构建攻击链(一次性生成)", zap.String("conversationId", conversationID))
|
||
|
||
// 1. 获取对话消息和工具执行记录
|
||
messages, err := b.db.GetMessages(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取对话消息失败: %w", err)
|
||
}
|
||
|
||
executions, err := b.getToolExecutionsByConversation(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取工具执行记录失败: %w", err)
|
||
}
|
||
|
||
// 获取过程详情
|
||
processDetailsMap, err := b.db.GetProcessDetailsByConversation(conversationID)
|
||
if err != nil {
|
||
b.logger.Warn("获取过程详情失败", zap.Error(err))
|
||
processDetailsMap = make(map[string][]database.ProcessDetail)
|
||
}
|
||
|
||
if len(executions) == 0 && len(messages) == 0 {
|
||
b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID))
|
||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||
}
|
||
|
||
// 2. 准备上下文数据
|
||
contextData, err := b.prepareContextData(messages, executions, processDetailsMap)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("准备上下文数据失败: %w", err)
|
||
}
|
||
|
||
// 3. 一次性生成攻击链(带重试和压缩机制)
|
||
chain, err := b.generateChainWithRetry(ctx, contextData, 5)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成攻击链失败: %w", err)
|
||
}
|
||
|
||
// 4. 保存到数据库
|
||
if err := b.saveChain(conversationID, chain.Nodes, chain.Edges); err != nil {
|
||
b.logger.Warn("保存攻击链失败", zap.Error(err))
|
||
// 不返回错误,继续返回结果
|
||
}
|
||
|
||
b.logger.Info("攻击链构建完成",
|
||
zap.String("conversationId", conversationID),
|
||
zap.Int("nodes", len(chain.Nodes)),
|
||
zap.Int("edges", len(chain.Edges)))
|
||
|
||
return chain, nil
|
||
}
|
||
|
||
// getToolExecutionsByConversation 获取对话的工具执行记录
|
||
func (b *Builder) getToolExecutionsByConversation(conversationID string) ([]*mcp.ToolExecution, error) {
|
||
// 通过conversation_id关联messages,再通过mcp_execution_ids关联tool_executions
|
||
// 简化实现:直接查询所有工具执行记录,然后过滤(实际应该优化查询)
|
||
allExecutions, err := b.db.LoadToolExecutions()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取对话的消息,提取mcp_execution_ids
|
||
messages, err := b.db.GetMessages(conversationID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 收集所有execution IDs
|
||
executionIDSet := make(map[string]bool)
|
||
for _, msg := range messages {
|
||
if len(msg.MCPExecutionIDs) > 0 {
|
||
for _, id := range msg.MCPExecutionIDs {
|
||
executionIDSet[id] = true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 过滤执行记录
|
||
var filteredExecutions []*mcp.ToolExecution
|
||
for _, exec := range allExecutions {
|
||
if executionIDSet[exec.ID] {
|
||
filteredExecutions = append(filteredExecutions, exec)
|
||
}
|
||
}
|
||
|
||
// 按时间排序
|
||
sort.Slice(filteredExecutions, func(i, j int) bool {
|
||
return filteredExecutions[i].StartTime.Before(filteredExecutions[j].StartTime)
|
||
})
|
||
|
||
return filteredExecutions, nil
|
||
}
|
||
|
||
// saveChain 保存攻击链到数据库
|
||
func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error {
|
||
// 先删除旧的攻击链数据
|
||
if err := b.db.DeleteAttackChain(conversationID); err != nil {
|
||
b.logger.Warn("删除旧攻击链失败", zap.Error(err))
|
||
}
|
||
|
||
// 保存节点
|
||
for _, node := range nodes {
|
||
metadataJSON, _ := json.Marshal(node.Metadata)
|
||
if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, node.ToolExecutionID, string(metadataJSON), node.RiskScore); err != nil {
|
||
b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 保存边
|
||
for _, edge := range edges {
|
||
if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil {
|
||
b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// LoadChainFromDatabase 从数据库加载攻击链
|
||
func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) {
|
||
nodes, err := b.db.LoadAttackChainNodes(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载攻击链节点失败: %w", err)
|
||
}
|
||
|
||
edges, err := b.db.LoadAttackChainEdges(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载攻击链边失败: %w", err)
|
||
}
|
||
|
||
return &Chain{
|
||
Nodes: nodes,
|
||
Edges: edges,
|
||
}, nil
|
||
}
|
||
|
||
// ContextData 上下文数据(用于一次性生成攻击链)
|
||
type ContextData struct {
|
||
Messages []database.Message `json:"messages"`
|
||
Executions []*mcp.ToolExecution `json:"executions"`
|
||
ProcessDetails map[string][]database.ProcessDetail `json:"process_details"`
|
||
SummarizedItems map[string]string `json:"summarized_items"` // 已总结的项目(key: 原始ID, value: 总结内容)
|
||
}
|
||
|
||
// prepareContextData 准备上下文数据
|
||
func (b *Builder) prepareContextData(messages []database.Message, executions []*mcp.ToolExecution, processDetails map[string][]database.ProcessDetail) (*ContextData, error) {
|
||
return &ContextData{
|
||
Messages: messages,
|
||
Executions: executions,
|
||
ProcessDetails: processDetails,
|
||
SummarizedItems: make(map[string]string),
|
||
}, nil
|
||
}
|
||
|
||
// generateChainWithRetry 生成攻击链(带重试和压缩机制)
|
||
func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *ContextData, maxRetries int) (*Chain, error) {
|
||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||
b.logger.Info("尝试生成攻击链",
|
||
zap.Int("attempt", attempt+1),
|
||
zap.Int("maxRetries", maxRetries))
|
||
|
||
// 构建提示词
|
||
prompt, err := b.buildChainGenerationPrompt(contextData)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("构建提示词失败: %w", err)
|
||
}
|
||
|
||
// 调用AI生成攻击链
|
||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||
if err != nil {
|
||
// 检查是否是上下文过长错误
|
||
if strings.Contains(err.Error(), "context length") || strings.Contains(err.Error(), "too long") || strings.Contains(err.Error(), "context length exceeded") {
|
||
b.logger.Warn("上下文过长,尝试压缩",
|
||
zap.Int("attempt", attempt+1),
|
||
zap.Error(err))
|
||
|
||
// 压缩最长的子节点
|
||
if err := b.compressLongestItem(ctx, contextData); err != nil {
|
||
return nil, fmt.Errorf("压缩上下文失败: %w", err)
|
||
}
|
||
|
||
// 重试
|
||
continue
|
||
}
|
||
|
||
return nil, fmt.Errorf("AI生成失败: %w", err)
|
||
}
|
||
|
||
// 解析JSON(传入executions用于ID映射)
|
||
chain, err := b.parseChainJSON(chainJSON, contextData.Executions)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("解析攻击链JSON失败: %w", err)
|
||
}
|
||
|
||
return chain, nil
|
||
}
|
||
|
||
return nil, fmt.Errorf("生成攻击链失败:超过最大重试次数 %d", maxRetries)
|
||
}
|
||
|
||
// buildChainGenerationPrompt 构建攻击链生成提示词
|
||
func (b *Builder) buildChainGenerationPrompt(contextData *ContextData) (string, error) {
|
||
var promptBuilder strings.Builder
|
||
|
||
promptBuilder.WriteString(`你是一个专业的安全测试分析师。请根据以下对话和工具执行记录,生成清晰、有教育意义的攻击链图。
|
||
|
||
## 核心原则
|
||
|
||
**目标:让不懂渗透测试的同学可以通过这个攻击链路学习到知识,而不是无数个节点看花眼。**
|
||
**即便某些工具执行或漏洞挖掘没有成功,只要它们提供了关键线索、错误提示或下一步思路,也要被保留下来。**
|
||
|
||
## 任务要求
|
||
|
||
1. **节点类型(简化,只保留3种)**:
|
||
- **target(目标)**:从用户输入中提取测试目标(IP、域名、URL等)
|
||
- **重要:如果对话中测试了多个不同的目标(如先测试A网页,后测试B网页),必须为每个不同的目标创建独立的target节点**
|
||
- 每个target节点只关联属于它的action节点(通过工具执行参数中的目标来判断)
|
||
- 不同目标的action节点之间**不应该**建立关联关系
|
||
- **action(行动)**:**工具执行 + AI分析结果 = 一个action节点**
|
||
- 将每个工具执行和AI对该工具结果的分析合并为一个action节点
|
||
- 节点标签应该清晰描述"做了什么"、"得到了什么结果或线索"(例如:"使用Nmap扫描端口,发现22、80、443端口开放" 或 "尝试SQLmap,虽然失败但提示存在WAF拦截")
|
||
- 默认关注成功的执行;但如果执行失败却提供了有价值的线索(错误信息、资产指纹、下一步建议等),也要保留,记为"带线索的失败"行动
|
||
- **重要:action节点必须关联到正确的target节点(通过工具执行参数判断目标)**
|
||
- **vulnerability(漏洞)**:从工具执行结果和AI分析中提取的**真实漏洞**(不是所有发现都是漏洞)。若验证失败但能明确表明某个漏洞利用方向不可行,可作为行动节点的线索描述,而不是漏洞节点。
|
||
|
||
2. **过滤规则(重要!)**:
|
||
- **默认忽略**彻底无效的信息:完全没有输出、没有任何线索的失败执行仍需过滤
|
||
- **必须保留**下列失败执行:
|
||
- 错误信息里包含了潜在线索、受限条件、可复现的报错
|
||
- 虽未找到漏洞,但收集到了资产信息、技术栈或后续测试方向
|
||
- 用户特别关注的失败尝试
|
||
- **保留策略**:只要行动节点能给后续测试提供启发,就保留;否则忽略
|
||
|
||
3. **建立清晰的关联关系**:
|
||
- target → action:目标指向属于它的所有行动(通过工具执行参数判断目标)
|
||
- action → action:行动之间的逻辑顺序(按时间顺序,但只连接有逻辑关系的)
|
||
- **重要:只连接属于同一目标的action节点,不同目标的action节点之间不应该连接**
|
||
- action → vulnerability:行动发现的漏洞
|
||
- vulnerability → vulnerability:漏洞间的因果关系(如SQL注入 → 信息泄露)
|
||
- **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接**
|
||
|
||
4. **节点属性**:
|
||
- 每个节点需要:id, type, label, risk_score, metadata
|
||
- action节点需要:
|
||
- tool_name: 工具名称
|
||
- tool_intent: 工具调用意图(如"端口扫描"、"漏洞扫描")
|
||
- ai_analysis: AI对工具结果的分析总结(简洁,不超过100字,失败节点需解释线索价值)
|
||
- findings: 关键发现(列表)
|
||
- status: "success" | "failed_insight"(失败但有价值的线索)
|
||
- hints: ["下一步建议1", "限制条件2"](失败节点可提供的线索列表)
|
||
- vulnerability节点需要:type, description, severity, location
|
||
|
||
## 对话数据
|
||
|
||
`)
|
||
|
||
// 添加消息
|
||
promptBuilder.WriteString("\n### 对话消息\n\n")
|
||
for i, msg := range contextData.Messages {
|
||
promptBuilder.WriteString(fmt.Sprintf("消息%d [%s]:\n", i+1, msg.Role))
|
||
|
||
isUserMessage := strings.EqualFold(msg.Role, "user")
|
||
// 用户输入必须原样提供给攻击链模型
|
||
if isUserMessage {
|
||
promptBuilder.WriteString(fmt.Sprintf("%s\n\n", msg.Content))
|
||
} else if summary, ok := contextData.SummarizedItems[msg.ID]; ok {
|
||
promptBuilder.WriteString(fmt.Sprintf("[已总结] %s\n\n", summary))
|
||
} else {
|
||
content := msg.Content
|
||
if len(content) > 5000 {
|
||
content = content[:5000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("%s\n\n", content))
|
||
}
|
||
|
||
// 添加过程详情
|
||
if details, ok := contextData.ProcessDetails[msg.ID]; ok {
|
||
for _, detail := range details {
|
||
if detail.EventType == "thinking" {
|
||
thinkingText := detail.Message
|
||
if summary, ok := contextData.SummarizedItems[detail.ID]; ok {
|
||
thinkingText = "[已总结] " + summary
|
||
} else if len(thinkingText) > 2000 {
|
||
thinkingText = thinkingText[:2000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("思考过程: %s\n", thinkingText))
|
||
}
|
||
}
|
||
}
|
||
promptBuilder.WriteString("\n")
|
||
}
|
||
|
||
// 添加工具执行记录(关联对应的AI回复)
|
||
promptBuilder.WriteString("\n### 工具执行记录(包含对应的AI分析)\n\n")
|
||
|
||
// 构建工具执行ID到消息的映射(找到工具执行后AI的回复)
|
||
execToMessageMap := b.buildExecutionToMessageMap(contextData)
|
||
|
||
for i, exec := range contextData.Executions {
|
||
// 检查是否是错误/失败的执行
|
||
isError := exec.Error != "" || (exec.Result != nil && exec.Result.IsError)
|
||
|
||
statusText := "成功"
|
||
if isError {
|
||
statusText = "失败(可能包含线索)"
|
||
}
|
||
|
||
promptBuilder.WriteString(fmt.Sprintf("执行%d [%s] (ID: %s) - 状态: %s\n", i+1, exec.ToolName, exec.ID, statusText))
|
||
promptBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments)))
|
||
|
||
if isError && exec.Error != "" {
|
||
promptBuilder.WriteString(fmt.Sprintf("错误信息: %s\n", exec.Error))
|
||
}
|
||
|
||
// 检查是否已总结
|
||
var resultText string
|
||
if exec.Result != nil {
|
||
for _, content := range exec.Result.Content {
|
||
if content.Type == "text" {
|
||
resultText += content.Text + "\n"
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查结果是否为空或无效
|
||
if strings.TrimSpace(resultText) == "" {
|
||
if isError {
|
||
promptBuilder.WriteString("工具执行结果: [失败但未返回正文]\n")
|
||
} else {
|
||
promptBuilder.WriteString("工具执行结果: **已忽略(结果为空)**\n\n")
|
||
continue
|
||
}
|
||
} else {
|
||
if summary, ok := contextData.SummarizedItems[exec.ID]; ok {
|
||
promptBuilder.WriteString(fmt.Sprintf("工具执行结果: [已总结] %s\n", summary))
|
||
} else {
|
||
if len(resultText) > 5000 {
|
||
resultText = resultText[:5000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("工具执行结果: %s\n", resultText))
|
||
}
|
||
}
|
||
|
||
// 添加对应的AI分析(工具执行后AI的回复)
|
||
if aiMessage, ok := execToMessageMap[exec.ID]; ok {
|
||
aiContent := aiMessage.Content
|
||
if len(aiContent) > 2000 {
|
||
aiContent = aiContent[:2000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("AI分析: %s\n", aiContent))
|
||
}
|
||
|
||
promptBuilder.WriteString("\n")
|
||
}
|
||
|
||
promptBuilder.WriteString(`
|
||
|
||
## 输出格式
|
||
|
||
请以JSON格式返回攻击链,格式如下:
|
||
|
||
{
|
||
"nodes": [
|
||
{
|
||
"id": "node_1",
|
||
"type": "target|action|vulnerability",
|
||
"label": "节点标签(清晰、简洁,action节点要描述"做了什么"和"发现了什么")",
|
||
"risk_score": 0-100,
|
||
"tool_execution_id": "执行记录的真实ID(action节点必须使用上面执行记录中的ID字段)",
|
||
"metadata": {
|
||
"target": "目标(target节点)",
|
||
"tool_name": "工具名称(action节点)",
|
||
"tool_intent": "工具调用意图(action节点,如"端口扫描"、"漏洞扫描")",
|
||
"ai_analysis": "AI对工具结果的分析总结(action节点,不超过100字)",
|
||
"findings": ["发现1", "发现2"](action节点,关键发现列表),
|
||
"vulnerability_type": "漏洞类型(vulnerability节点)",
|
||
"description": "描述(vulnerability节点)",
|
||
"severity": "critical|high|medium|low(vulnerability节点)",
|
||
"location": "漏洞位置(vulnerability节点)"
|
||
}
|
||
}
|
||
],
|
||
"edges": [
|
||
{
|
||
"source": "node_1",
|
||
"target": "node_2",
|
||
"type": "leads_to|discovers|enables",
|
||
"weight": 1-5
|
||
}
|
||
]
|
||
}
|
||
|
||
## 重要要求
|
||
|
||
1. **节点合并**:
|
||
- 每个工具执行和对应的AI分析必须合并为一个action节点
|
||
- action节点的label要清晰描述"做了什么"、"结果/线索是什么"
|
||
- 例如:"使用Nmap扫描192.168.1.1,发现22、80、443端口开放" 或 "执行Sqlmap被WAF拦截,提示403并暴露防护厂商"
|
||
- 若为失败但有线索的行动,请在metadata.status中标记为"failed_insight",并在findings/hints里写清线索价值
|
||
|
||
2. **过滤无效节点**:
|
||
- **必须忽略**没有任何输出、没有线索的失败执行
|
||
- **必须保留**失败但提供关键线索的执行,确保metadata里解释清楚
|
||
- 只保留对学习或溯源有帮助的节点
|
||
|
||
3. **简化结构**:
|
||
- 只创建target、action、vulnerability三种节点
|
||
- 不要创建discovery、decision等节点
|
||
- 让攻击链清晰、有教育意义
|
||
|
||
4. **关联关系**:
|
||
- target → action:目标指向属于它的所有行动(通过工具执行参数判断目标)
|
||
- action → action:按时间顺序连接,但只连接有逻辑关系的
|
||
- **重要:只连接属于同一目标的action节点,不同目标的action节点之间不应该连接**
|
||
- action → vulnerability:行动发现的漏洞
|
||
- vulnerability → vulnerability:漏洞间的因果关系
|
||
- **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接**
|
||
|
||
5. **多目标处理(重要!)**:
|
||
- 如果对话中测试了多个不同的目标(如先测试A网页,后测试B网页),必须:
|
||
- 为每个不同的目标创建独立的target节点
|
||
- 每个target节点只关联属于它的action和vulnerability节点
|
||
- 不同目标的节点之间**不应该**建立任何关联关系
|
||
- 这样会形成多个独立的攻击链分支,每个分支对应一个测试目标
|
||
|
||
6. **节点数量控制**:
|
||
- 如果节点太多(>20个),优先保留最重要的节点
|
||
- 合并相似的action节点(如同一工具的连续调用,如果结果相似)
|
||
|
||
只返回JSON,不要包含其他解释文字。`)
|
||
|
||
return promptBuilder.String(), nil
|
||
}
|
||
|
||
// buildExecutionToMessageMap 构建工具执行ID到AI消息的映射
|
||
// 找到每个工具执行后AI的回复消息
|
||
func (b *Builder) buildExecutionToMessageMap(contextData *ContextData) map[string]database.Message {
|
||
execToMessageMap := make(map[string]database.Message)
|
||
|
||
// 遍历消息,找到包含工具执行ID的消息(通常是assistant消息)
|
||
for _, msg := range contextData.Messages {
|
||
if msg.Role != "assistant" {
|
||
continue
|
||
}
|
||
|
||
// 检查消息中是否引用了工具执行ID
|
||
// 通常工具执行后,AI会在回复中引用这些执行ID
|
||
for _, execID := range msg.MCPExecutionIDs {
|
||
// 找到对应的工具执行
|
||
for _, exec := range contextData.Executions {
|
||
if exec.ID == execID {
|
||
// 如果这个执行还没有关联的消息,或者当前消息时间更晚,则更新
|
||
if existingMsg, exists := execToMessageMap[execID]; !exists || msg.CreatedAt.After(existingMsg.CreatedAt) {
|
||
execToMessageMap[execID] = msg
|
||
}
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果通过MCPExecutionIDs找不到,尝试按时间顺序匹配
|
||
// 找到每个工具执行后最近的assistant消息
|
||
for _, exec := range contextData.Executions {
|
||
if _, exists := execToMessageMap[exec.ID]; exists {
|
||
continue
|
||
}
|
||
|
||
// 找到执行时间之后最近的assistant消息
|
||
var closestMsg *database.Message
|
||
for i := range contextData.Messages {
|
||
msg := &contextData.Messages[i]
|
||
if msg.Role == "assistant" && msg.CreatedAt.After(exec.StartTime) {
|
||
if closestMsg == nil || msg.CreatedAt.Before(closestMsg.CreatedAt) {
|
||
closestMsg = msg
|
||
}
|
||
}
|
||
}
|
||
|
||
if closestMsg != nil {
|
||
execToMessageMap[exec.ID] = *closestMsg
|
||
}
|
||
}
|
||
|
||
return execToMessageMap
|
||
}
|
||
|
||
// formatArguments 格式化工具参数
|
||
func (b *Builder) formatArguments(args map[string]interface{}) string {
|
||
if args == nil {
|
||
return "{}"
|
||
}
|
||
jsonData, _ := json.Marshal(args)
|
||
return string(jsonData)
|
||
}
|
||
|
||
// compressLongestItem 压缩最长的子节点
|
||
func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error {
|
||
var longestID string
|
||
var longestType string
|
||
var longestContent string
|
||
maxLength := 0
|
||
|
||
// 查找最长的消息
|
||
for _, msg := range contextData.Messages {
|
||
if strings.EqualFold(msg.Role, "user") {
|
||
continue
|
||
}
|
||
if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized {
|
||
continue
|
||
}
|
||
length := len(msg.Content)
|
||
if length > maxLength {
|
||
maxLength = length
|
||
longestID = msg.ID
|
||
longestType = "message"
|
||
longestContent = msg.Content
|
||
}
|
||
}
|
||
|
||
// 查找最长的工具执行结果
|
||
for _, exec := range contextData.Executions {
|
||
if _, alreadySummarized := contextData.SummarizedItems[exec.ID]; alreadySummarized {
|
||
continue
|
||
}
|
||
if exec.Result != nil {
|
||
var resultText string
|
||
for _, content := range exec.Result.Content {
|
||
if content.Type == "text" {
|
||
resultText += content.Text + "\n"
|
||
}
|
||
}
|
||
length := len(resultText)
|
||
if length > maxLength {
|
||
maxLength = length
|
||
longestID = exec.ID
|
||
longestType = "execution"
|
||
longestContent = resultText
|
||
}
|
||
}
|
||
}
|
||
|
||
// 查找最长的思考过程
|
||
for _, details := range contextData.ProcessDetails {
|
||
for _, detail := range details {
|
||
if detail.EventType == "thinking" {
|
||
if _, alreadySummarized := contextData.SummarizedItems[detail.ID]; alreadySummarized {
|
||
continue
|
||
}
|
||
length := len(detail.Message)
|
||
if length > maxLength {
|
||
maxLength = length
|
||
longestID = detail.ID
|
||
longestType = "thinking"
|
||
longestContent = detail.Message
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if longestID == "" {
|
||
return fmt.Errorf("没有找到需要压缩的内容")
|
||
}
|
||
|
||
b.logger.Info("压缩最长子节点",
|
||
zap.String("id", longestID),
|
||
zap.String("type", longestType),
|
||
zap.Int("length", maxLength))
|
||
|
||
// 使用AI总结
|
||
summary, err := b.summarizeContent(ctx, longestType, longestContent)
|
||
if err != nil {
|
||
return fmt.Errorf("总结内容失败: %w", err)
|
||
}
|
||
|
||
// 保存总结
|
||
contextData.SummarizedItems[longestID] = summary
|
||
|
||
b.logger.Info("压缩完成",
|
||
zap.String("id", longestID),
|
||
zap.Int("originalLength", maxLength),
|
||
zap.Int("summaryLength", len(summary)))
|
||
|
||
return nil
|
||
}
|
||
|
||
// summarizeContent 总结内容
|
||
func (b *Builder) summarizeContent(ctx context.Context, contentType, content string) (string, error) {
|
||
var prompt string
|
||
switch contentType {
|
||
case "message":
|
||
prompt = fmt.Sprintf(`请总结以下AI回复的关键信息,保留所有重要的安全发现、漏洞信息和测试结果。用简洁的中文总结,不超过500字。
|
||
|
||
AI回复:
|
||
%s
|
||
|
||
总结:`, content)
|
||
case "execution":
|
||
prompt = fmt.Sprintf(`请总结以下工具执行结果的关键信息,保留所有发现的漏洞、重要发现和测试结果。用简洁的中文总结,不超过500字。
|
||
|
||
工具执行结果:
|
||
%s
|
||
|
||
总结:`, content)
|
||
case "thinking":
|
||
prompt = fmt.Sprintf(`请总结以下AI思考过程的关键决策和思路,保留所有重要的决策点和测试策略。用简洁的中文总结,不超过300字。
|
||
|
||
思考过程:
|
||
%s
|
||
|
||
总结:`, content)
|
||
default:
|
||
return "", fmt.Errorf("未知的内容类型: %s", contentType)
|
||
}
|
||
|
||
requestBody := map[string]interface{}{
|
||
"model": b.openAIConfig.Model,
|
||
"messages": []map[string]interface{}{
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个专业的安全测试分析师,擅长总结安全测试相关的信息。请用简洁的中文总结关键信息。",
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": prompt,
|
||
},
|
||
},
|
||
"temperature": 0.3,
|
||
"max_tokens": 1000,
|
||
}
|
||
|
||
jsonData, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||
|
||
resp, err := b.openAIClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var apiResponse struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if len(apiResponse.Choices) == 0 {
|
||
return "", fmt.Errorf("API未返回有效响应")
|
||
}
|
||
|
||
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
|
||
}
|
||
|
||
// callAIForChainGeneration 调用AI生成攻击链
|
||
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
|
||
requestBody := map[string]interface{}{
|
||
"model": b.openAIConfig.Model,
|
||
"messages": []map[string]interface{}{
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个专业的安全测试分析师,擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。",
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": prompt,
|
||
},
|
||
},
|
||
"temperature": 0.3,
|
||
"max_tokens": 8000,
|
||
}
|
||
|
||
jsonData, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||
|
||
resp, err := b.openAIClient.Do(req)
|
||
if err != nil {
|
||
// 检查是否是上下文过长错误
|
||
if strings.Contains(err.Error(), "context") || strings.Contains(err.Error(), "length") {
|
||
return "", fmt.Errorf("context length exceeded")
|
||
}
|
||
return "", fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
bodyStr := string(body)
|
||
// 检查是否是上下文过长错误
|
||
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
|
||
return "", fmt.Errorf("context length exceeded")
|
||
}
|
||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, bodyStr)
|
||
}
|
||
|
||
var apiResponse struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if len(apiResponse.Choices) == 0 {
|
||
return "", fmt.Errorf("API未返回有效响应")
|
||
}
|
||
|
||
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
|
||
// 尝试提取JSON(可能包含markdown代码块)
|
||
content = strings.TrimPrefix(content, "```json")
|
||
content = strings.TrimPrefix(content, "```")
|
||
content = strings.TrimSuffix(content, "```")
|
||
content = strings.TrimSpace(content)
|
||
|
||
return content, nil
|
||
}
|
||
|
||
// ChainJSON mirrors the JSON document the model is asked to return: a list of
// nodes and a list of edges that refer to the nodes by their "id" values.
type ChainJSON struct {
	// Nodes as emitted by the model.
	Nodes []struct {
		ID              string                 `json:"id"`
		Type            string                 `json:"type"`
		Label           string                 `json:"label"`
		RiskScore       int                    `json:"risk_score"`
		ToolExecutionID string                 `json:"tool_execution_id,omitempty"`
		Metadata        map[string]interface{} `json:"metadata"`
	} `json:"nodes"`
	// Edges as emitted by the model; Source and Target hold node IDs.
	Edges []struct {
		Source string `json:"source"`
		Target string `json:"target"`
		Type   string `json:"type"`
		Weight int    `json:"weight"`
	} `json:"edges"`
}
|
||
|
||
// parseChainJSON 解析攻击链JSON
|
||
func (b *Builder) parseChainJSON(chainJSON string, executions []*mcp.ToolExecution) (*Chain, error) {
|
||
var chainData ChainJSON
|
||
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
|
||
return nil, fmt.Errorf("解析JSON失败: %w", err)
|
||
}
|
||
|
||
// 创建execution ID映射(AI可能返回简单的索引或ID,需要映射到真实的execution ID)
|
||
executionMap := make(map[string]string) // AI返回的ID -> 真实execution ID
|
||
for i, exec := range executions {
|
||
// 支持多种可能的AI返回格式
|
||
executionMap[fmt.Sprintf("exec_%d", i+1)] = exec.ID
|
||
executionMap[fmt.Sprintf("execution_%d", i+1)] = exec.ID
|
||
executionMap[exec.ID] = exec.ID // 如果AI直接返回真实ID
|
||
executionMap[fmt.Sprintf("tool_%d", i+1)] = exec.ID // AI可能用tool_1格式
|
||
executionMap[fmt.Sprintf("执行%d", i+1)] = exec.ID // 中文格式
|
||
executionMap[fmt.Sprintf("执行_%d", i+1)] = exec.ID
|
||
}
|
||
|
||
// 创建节点ID映射(AI返回的ID -> 新的UUID)
|
||
nodeIDMap := make(map[string]string)
|
||
|
||
// 转换为Chain结构,并过滤无效节点
|
||
nodes := make([]Node, 0, len(chainData.Nodes))
|
||
for _, n := range chainData.Nodes {
|
||
// 过滤无效节点
|
||
if b.shouldFilterNode(n, executions) {
|
||
b.logger.Info("过滤无效节点",
|
||
zap.String("nodeID", n.ID),
|
||
zap.String("nodeType", n.Type),
|
||
zap.String("label", n.Label))
|
||
continue
|
||
}
|
||
|
||
// 生成新的UUID节点ID
|
||
newNodeID := fmt.Sprintf("node_%s", uuid.New().String())
|
||
nodeIDMap[n.ID] = newNodeID
|
||
|
||
node := Node{
|
||
ID: newNodeID,
|
||
Type: n.Type,
|
||
Label: n.Label,
|
||
RiskScore: n.RiskScore,
|
||
Metadata: n.Metadata,
|
||
}
|
||
if node.Metadata == nil {
|
||
node.Metadata = make(map[string]interface{})
|
||
}
|
||
|
||
// 处理tool_execution_id:如果是action或vulnerability节点,需要映射到真实的execution ID
|
||
if n.ToolExecutionID != "" {
|
||
if realExecID, ok := executionMap[n.ToolExecutionID]; ok {
|
||
node.ToolExecutionID = realExecID
|
||
} else {
|
||
// 检查是否是真实的execution ID(UUID格式)
|
||
// 如果是,直接使用;如果不是,尝试从节点ID推断
|
||
if len(n.ToolExecutionID) > 20 { // UUID通常很长
|
||
node.ToolExecutionID = n.ToolExecutionID
|
||
} else {
|
||
// 可能是简单的ID,尝试从节点ID推断
|
||
if realExecID, ok := executionMap[n.ID]; ok {
|
||
node.ToolExecutionID = realExecID
|
||
} else {
|
||
b.logger.Warn("无法映射tool_execution_id",
|
||
zap.String("nodeID", n.ID),
|
||
zap.String("toolExecutionID", n.ToolExecutionID))
|
||
// 对于action节点,如果没有有效的execution ID,清空它(避免外键约束失败)
|
||
if n.Type == "action" {
|
||
node.ToolExecutionID = ""
|
||
}
|
||
}
|
||
}
|
||
}
|
||
} else if n.Type == "action" || n.Type == "vulnerability" {
|
||
// 如果AI没有提供tool_execution_id,尝试从节点ID推断
|
||
// 例如:tool_1 -> 查找exec_1
|
||
if realExecID, ok := executionMap[n.ID]; ok {
|
||
node.ToolExecutionID = realExecID
|
||
} else {
|
||
b.logger.Warn("action/vulnerability节点缺少tool_execution_id",
|
||
zap.String("nodeID", n.ID),
|
||
zap.String("nodeType", n.Type))
|
||
}
|
||
}
|
||
|
||
nodes = append(nodes, node)
|
||
}
|
||
|
||
// 转换边,更新source和target为新的节点ID
|
||
edges := make([]Edge, 0, len(chainData.Edges))
|
||
for _, e := range chainData.Edges {
|
||
sourceID, ok := nodeIDMap[e.Source]
|
||
if !ok {
|
||
b.logger.Warn("边的源节点ID未找到", zap.String("source", e.Source))
|
||
continue
|
||
}
|
||
|
||
targetID, ok := nodeIDMap[e.Target]
|
||
if !ok {
|
||
b.logger.Warn("边的目标节点ID未找到", zap.String("target", e.Target))
|
||
continue
|
||
}
|
||
|
||
edge := Edge{
|
||
ID: fmt.Sprintf("edge_%s", uuid.New().String()),
|
||
Source: sourceID,
|
||
Target: targetID,
|
||
Type: e.Type,
|
||
Weight: e.Weight,
|
||
}
|
||
edges = append(edges, edge)
|
||
}
|
||
|
||
// 过滤掉指向已删除节点的边
|
||
filteredEdges := make([]Edge, 0, len(edges))
|
||
for _, edge := range edges {
|
||
// 检查source和target节点是否都存在
|
||
sourceExists := false
|
||
targetExists := false
|
||
for _, node := range nodes {
|
||
if node.ID == edge.Source {
|
||
sourceExists = true
|
||
}
|
||
if node.ID == edge.Target {
|
||
targetExists = true
|
||
}
|
||
}
|
||
|
||
if sourceExists && targetExists {
|
||
filteredEdges = append(filteredEdges, edge)
|
||
} else {
|
||
b.logger.Warn("过滤无效边",
|
||
zap.String("edgeID", edge.ID),
|
||
zap.String("source", edge.Source),
|
||
zap.String("target", edge.Target),
|
||
zap.Bool("sourceExists", sourceExists),
|
||
zap.Bool("targetExists", targetExists))
|
||
}
|
||
}
|
||
|
||
return &Chain{
|
||
Nodes: nodes,
|
||
Edges: filteredEdges,
|
||
}, nil
|
||
}
|
||
|
||
// shouldFilterNode 判断是否应该过滤掉这个节点
|
||
func (b *Builder) shouldFilterNode(n struct {
|
||
ID string `json:"id"`
|
||
Type string `json:"type"`
|
||
Label string `json:"label"`
|
||
RiskScore int `json:"risk_score"`
|
||
ToolExecutionID string `json:"tool_execution_id,omitempty"`
|
||
Metadata map[string]interface{} `json:"metadata"`
|
||
}, executions []*mcp.ToolExecution) bool {
|
||
// 只允许target、action、vulnerability三种节点类型
|
||
if n.Type != "target" && n.Type != "action" && n.Type != "vulnerability" {
|
||
return true
|
||
}
|
||
|
||
// 对于action节点,检查对应的工具执行是否有效
|
||
if n.Type == "action" {
|
||
if n.ToolExecutionID == "" {
|
||
// 没有关联工具执行的action节点,可能是无效的
|
||
return true
|
||
}
|
||
|
||
// 查找对应的工具执行
|
||
var exec *mcp.ToolExecution
|
||
for _, e := range executions {
|
||
if e.ID == n.ToolExecutionID {
|
||
exec = e
|
||
break
|
||
}
|
||
}
|
||
|
||
if exec == nil {
|
||
// 找不到对应的工具执行,可能是无效的
|
||
return true
|
||
}
|
||
|
||
// 检查工具执行是否错误或失败
|
||
if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) {
|
||
if !hasInsightfulFailure(n.Metadata) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
// 检查工具执行结果是否为空
|
||
if exec.Result == nil || len(exec.Result.Content) == 0 {
|
||
if !hasInsightfulFailure(n.Metadata) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
// 检查结果文本是否为空
|
||
var resultText string
|
||
if exec.Result != nil {
|
||
for _, content := range exec.Result.Content {
|
||
if content.Type == "text" {
|
||
resultText += content.Text
|
||
}
|
||
}
|
||
}
|
||
if strings.TrimSpace(resultText) == "" {
|
||
if !hasInsightfulFailure(n.Metadata) {
|
||
return true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查节点标签是否为空或无效
|
||
if strings.TrimSpace(n.Label) == "" {
|
||
return true
|
||
}
|
||
|
||
// 检查标签中是否包含错误/失败的关键词
|
||
labelLower := strings.ToLower(n.Label)
|
||
errorKeywords := []string{"错误", "失败", "无效", "error", "failed", "invalid", "empty", "空"}
|
||
for _, keyword := range errorKeywords {
|
||
if strings.Contains(labelLower, keyword) {
|
||
// 如果标签明确表示错误,但节点类型不是vulnerability,则过滤
|
||
if n.Type != "vulnerability" {
|
||
return true
|
||
}
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// hasInsightfulFailure reports whether node metadata marks a failed action as
// still worth keeping. That is the case when "status" (trimmed, compared
// case-insensitively) is failed_insight, failed_clue or failed_with_hint, when
// "hint" or "insight" is a non-blank string, or when "hints" is a non-empty
// list. Nil metadata is never insightful.
func hasInsightfulFailure(metadata map[string]interface{}) bool {
	if metadata == nil {
		return false
	}

	// nonBlank reports whether key holds a string with visible content.
	nonBlank := func(key string) bool {
		text, isString := metadata[key].(string)
		return isString && strings.TrimSpace(text) != ""
	}

	if status, isString := metadata["status"].(string); isString {
		switch strings.ToLower(strings.TrimSpace(status)) {
		case "failed_insight", "failed_clue", "failed_with_hint":
			return true
		}
	}

	if nonBlank("hint") || nonBlank("insight") {
		return true
	}

	list, isList := metadata["hints"].([]interface{})
	return isList && len(list) > 0
}
|