mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
1683 lines
57 KiB
Go
1683 lines
57 KiB
Go
package attackchain
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"sort"
|
||
"strings"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/agent"
|
||
"cyberstrike-ai/internal/config"
|
||
"cyberstrike-ai/internal/database"
|
||
"cyberstrike-ai/internal/mcp"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Builder 攻击链构建器
|
||
type Builder struct {
|
||
db *database.DB
|
||
logger *zap.Logger
|
||
openAIClient *http.Client
|
||
openAIConfig *config.OpenAIConfig
|
||
tokenCounter agent.TokenCounter
|
||
maxTokens int // 最大tokens限制,默认100000
|
||
}
|
||
|
||
// Node 攻击链节点(使用database包的类型)
|
||
type Node = database.AttackChainNode
|
||
|
||
// Edge 攻击链边(使用database包的类型)
|
||
type Edge = database.AttackChainEdge
|
||
|
||
// Chain 完整的攻击链
|
||
type Chain struct {
|
||
Nodes []Node `json:"nodes"`
|
||
Edges []Edge `json:"edges"`
|
||
}
|
||
|
||
// NewBuilder 创建新的攻击链构建器
|
||
func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder {
|
||
transport := &http.Transport{
|
||
MaxIdleConns: 100,
|
||
MaxIdleConnsPerHost: 10,
|
||
IdleConnTimeout: 90 * time.Second,
|
||
}
|
||
|
||
maxTokens := 100000 // 默认100k tokens,可以根据模型调整
|
||
// 根据模型设置合理的默认值
|
||
if openAIConfig != nil {
|
||
model := strings.ToLower(openAIConfig.Model)
|
||
if strings.Contains(model, "gpt-4") {
|
||
maxTokens = 128000 // gpt-4通常支持128k
|
||
} else if strings.Contains(model, "gpt-3.5") {
|
||
maxTokens = 16000 // gpt-3.5-turbo通常支持16k
|
||
} else if strings.Contains(model, "deepseek") {
|
||
maxTokens = 131072 // deepseek-chat通常支持131k
|
||
}
|
||
}
|
||
|
||
return &Builder{
|
||
db: db,
|
||
logger: logger,
|
||
openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport},
|
||
openAIConfig: openAIConfig,
|
||
tokenCounter: agent.NewTikTokenCounter(),
|
||
maxTokens: maxTokens,
|
||
}
|
||
}
|
||
|
||
// BuildChainFromConversation 从对话构建攻击链(一次性生成整个图)
|
||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||
b.logger.Info("开始构建攻击链(一次性生成)", zap.String("conversationId", conversationID))
|
||
|
||
// 1. 获取对话消息和工具执行记录
|
||
messages, err := b.db.GetMessages(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取对话消息失败: %w", err)
|
||
}
|
||
|
||
executions, err := b.getToolExecutionsByConversation(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取工具执行记录失败: %w", err)
|
||
}
|
||
|
||
// 获取过程详情
|
||
processDetailsMap, err := b.db.GetProcessDetailsByConversation(conversationID)
|
||
if err != nil {
|
||
b.logger.Warn("获取过程详情失败", zap.Error(err))
|
||
processDetailsMap = make(map[string][]database.ProcessDetail)
|
||
}
|
||
|
||
if len(executions) == 0 && len(messages) == 0 {
|
||
b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID))
|
||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||
}
|
||
|
||
// 2. 准备上下文数据
|
||
contextData, err := b.prepareContextData(messages, executions, processDetailsMap)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("准备上下文数据失败: %w", err)
|
||
}
|
||
|
||
// 3. 一次性生成攻击链(带重试和压缩机制)
|
||
chain, err := b.generateChainWithRetry(ctx, contextData, 5)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成攻击链失败: %w", err)
|
||
}
|
||
|
||
// 4. 保存到数据库
|
||
if err := b.saveChain(conversationID, chain.Nodes, chain.Edges); err != nil {
|
||
b.logger.Warn("保存攻击链失败", zap.Error(err))
|
||
// 不返回错误,继续返回结果
|
||
}
|
||
|
||
b.logger.Info("攻击链构建完成",
|
||
zap.String("conversationId", conversationID),
|
||
zap.Int("nodes", len(chain.Nodes)),
|
||
zap.Int("edges", len(chain.Edges)))
|
||
|
||
return chain, nil
|
||
}
|
||
|
||
// getToolExecutionsByConversation 获取对话的工具执行记录
|
||
func (b *Builder) getToolExecutionsByConversation(conversationID string) ([]*mcp.ToolExecution, error) {
|
||
// 通过conversation_id关联messages,再通过mcp_execution_ids关联tool_executions
|
||
// 简化实现:直接查询所有工具执行记录,然后过滤(实际应该优化查询)
|
||
allExecutions, err := b.db.LoadToolExecutions()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取对话的消息,提取mcp_execution_ids
|
||
messages, err := b.db.GetMessages(conversationID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 收集所有execution IDs
|
||
executionIDSet := make(map[string]bool)
|
||
for _, msg := range messages {
|
||
if len(msg.MCPExecutionIDs) > 0 {
|
||
for _, id := range msg.MCPExecutionIDs {
|
||
executionIDSet[id] = true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 过滤执行记录
|
||
var filteredExecutions []*mcp.ToolExecution
|
||
for _, exec := range allExecutions {
|
||
if executionIDSet[exec.ID] {
|
||
filteredExecutions = append(filteredExecutions, exec)
|
||
}
|
||
}
|
||
|
||
// 按时间排序
|
||
sort.Slice(filteredExecutions, func(i, j int) bool {
|
||
return filteredExecutions[i].StartTime.Before(filteredExecutions[j].StartTime)
|
||
})
|
||
|
||
return filteredExecutions, nil
|
||
}
|
||
|
||
// saveChain 保存攻击链到数据库
|
||
func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error {
|
||
// 先删除旧的攻击链数据
|
||
if err := b.db.DeleteAttackChain(conversationID); err != nil {
|
||
b.logger.Warn("删除旧攻击链失败", zap.Error(err))
|
||
}
|
||
|
||
// 保存节点
|
||
for _, node := range nodes {
|
||
metadataJSON, _ := json.Marshal(node.Metadata)
|
||
if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, node.ToolExecutionID, string(metadataJSON), node.RiskScore); err != nil {
|
||
b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 保存边
|
||
for _, edge := range edges {
|
||
if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil {
|
||
b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// LoadChainFromDatabase 从数据库加载攻击链
|
||
func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) {
|
||
nodes, err := b.db.LoadAttackChainNodes(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载攻击链节点失败: %w", err)
|
||
}
|
||
|
||
edges, err := b.db.LoadAttackChainEdges(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载攻击链边失败: %w", err)
|
||
}
|
||
|
||
return &Chain{
|
||
Nodes: nodes,
|
||
Edges: edges,
|
||
}, nil
|
||
}
|
||
|
||
// ContextData 上下文数据(用于一次性生成攻击链)
|
||
type ContextData struct {
|
||
Messages []database.Message `json:"messages"`
|
||
Executions []*mcp.ToolExecution `json:"executions"`
|
||
ProcessDetails map[string][]database.ProcessDetail `json:"process_details"`
|
||
SummarizedItems map[string]string `json:"summarized_items"` // 已总结的项目(key: 原始ID, value: 总结内容)
|
||
}
|
||
|
||
// prepareContextData 准备上下文数据
|
||
func (b *Builder) prepareContextData(messages []database.Message, executions []*mcp.ToolExecution, processDetails map[string][]database.ProcessDetail) (*ContextData, error) {
|
||
return &ContextData{
|
||
Messages: messages,
|
||
Executions: executions,
|
||
ProcessDetails: processDetails,
|
||
SummarizedItems: make(map[string]string),
|
||
}, nil
|
||
}
|
||
|
||
// generateChainWithRetry 生成攻击链(带重试和压缩机制)
|
||
func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *ContextData, maxRetries int) (*Chain, error) {
|
||
// 在第一次尝试前,先检查tokens并压缩(如果需要)
|
||
totalTokens, err := b.countPromptTokens(contextData)
|
||
if err == nil && totalTokens > b.maxTokens {
|
||
b.logger.Info("检测到tokens超过限制,提前压缩",
|
||
zap.Int("totalTokens", totalTokens),
|
||
zap.Int("maxTokens", b.maxTokens))
|
||
if err := b.compressContextData(ctx, contextData); err != nil {
|
||
return nil, fmt.Errorf("压缩上下文失败: %w", err)
|
||
}
|
||
}
|
||
|
||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||
b.logger.Info("尝试生成攻击链",
|
||
zap.Int("attempt", attempt+1),
|
||
zap.Int("maxRetries", maxRetries))
|
||
|
||
// 构建提示词
|
||
prompt, err := b.buildChainGenerationPrompt(contextData)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("构建提示词失败: %w", err)
|
||
}
|
||
|
||
// 调用AI生成攻击链
|
||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||
if err != nil {
|
||
// 检查是否是上下文过长错误
|
||
if strings.Contains(err.Error(), "context length") || strings.Contains(err.Error(), "too long") || strings.Contains(err.Error(), "context length exceeded") {
|
||
b.logger.Warn("上下文过长,尝试压缩",
|
||
zap.Int("attempt", attempt+1),
|
||
zap.Error(err))
|
||
|
||
// 使用分片压缩
|
||
if err := b.compressContextData(ctx, contextData); err != nil {
|
||
return nil, fmt.Errorf("压缩上下文失败: %w", err)
|
||
}
|
||
|
||
// 重试
|
||
continue
|
||
}
|
||
|
||
return nil, fmt.Errorf("AI生成失败: %w", err)
|
||
}
|
||
|
||
// 解析JSON(传入executions用于ID映射)
|
||
chain, err := b.parseChainJSON(chainJSON, contextData.Executions)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("解析攻击链JSON失败: %w", err)
|
||
}
|
||
|
||
return chain, nil
|
||
}
|
||
|
||
return nil, fmt.Errorf("生成攻击链失败:超过最大重试次数 %d", maxRetries)
|
||
}
|
||
|
||
// buildChainGenerationPrompt 构建攻击链生成提示词
|
||
func (b *Builder) buildChainGenerationPrompt(contextData *ContextData) (string, error) {
|
||
var promptBuilder strings.Builder
|
||
|
||
promptBuilder.WriteString(`你是一个专业的安全测试分析师。请根据以下对话和工具执行记录,生成清晰、有教育意义的攻击链图。
|
||
|
||
## 核心原则
|
||
|
||
**目标:让不懂渗透测试的同学可以通过这个攻击链路学习到知识,而不是无数个节点看花眼。**
|
||
**即便某些工具执行或漏洞挖掘没有成功,只要它们提供了关键线索、错误提示或下一步思路,也要被保留下来。**
|
||
|
||
## 任务要求
|
||
|
||
1. **节点类型(简化,只保留3种)**:
|
||
- **target(目标)**:从用户输入中提取测试目标(IP、域名、URL等)
|
||
- **重要:如果对话中测试了多个不同的目标(如先测试A网页,后测试B网页),必须为每个不同的目标创建独立的target节点**
|
||
- 每个target节点只关联属于它的action节点(通过工具执行参数中的目标来判断)
|
||
- 不同目标的action节点之间**不应该**建立关联关系
|
||
- **action(行动)**:**工具执行 + AI分析结果 = 一个action节点**
|
||
- 将每个工具执行和AI对该工具结果的分析合并为一个action节点
|
||
- 节点标签应该清晰描述"做了什么"、"得到了什么结果或线索"(例如:"使用Nmap扫描端口,发现22、80、443端口开放" 或 "尝试SQLmap,虽然失败但提示存在WAF拦截")
|
||
- 默认关注成功的执行;但如果执行失败却提供了有价值的线索(错误信息、资产指纹、下一步建议等),也要保留,记为"带线索的失败"行动
|
||
- **重要:action节点必须关联到正确的target节点(通过工具执行参数判断目标)**
|
||
- **vulnerability(漏洞)**:从工具执行结果和AI分析中提取的**真实漏洞**(不是所有发现都是漏洞)。若验证失败但能明确表明某个漏洞利用方向不可行,可作为行动节点的线索描述,而不是漏洞节点。
|
||
|
||
2. **过滤规则(重要!)**:
|
||
- **默认忽略**彻底无效的信息:完全没有输出、没有任何线索的失败执行仍需过滤
|
||
- **必须保留**下列失败执行:
|
||
- 错误信息里包含了潜在线索、受限条件、可复现的报错
|
||
- 虽未找到漏洞,但收集到了资产信息、技术栈或后续测试方向
|
||
- 用户特别关注的失败尝试
|
||
- **保留策略**:只要行动节点能给后续测试提供启发,就保留;否则忽略
|
||
|
||
3. **建立清晰的关联关系**:
|
||
- target → action:目标指向属于它的所有行动(通过工具执行参数判断目标)
|
||
- action → action:行动之间的逻辑顺序(按时间顺序,但只连接有逻辑关系的)
|
||
- **重要:只连接属于同一目标的action节点,不同目标的action节点之间不应该连接**
|
||
- action → vulnerability:行动发现的漏洞
|
||
- vulnerability → vulnerability:漏洞间的因果关系(如SQL注入 → 信息泄露)
|
||
- **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接**
|
||
|
||
4. **节点属性**:
|
||
- 每个节点需要:id, type, label, risk_score, metadata
|
||
- action节点需要:
|
||
- tool_name: 工具名称
|
||
- tool_intent: 工具调用意图(如"端口扫描"、"漏洞扫描")
|
||
- ai_analysis: AI对工具结果的分析总结(简洁,不超过100字,失败节点需解释线索价值)
|
||
- findings: 关键发现(列表)
|
||
- status: "success" | "failed_insight"(失败但有价值的线索)
|
||
- hints: ["下一步建议1", "限制条件2"](失败节点可提供的线索列表)
|
||
- vulnerability节点需要:type, description, severity, location
|
||
|
||
## 对话数据
|
||
|
||
`)
|
||
|
||
// 添加消息
|
||
promptBuilder.WriteString("\n### 对话消息\n\n")
|
||
for i, msg := range contextData.Messages {
|
||
promptBuilder.WriteString(fmt.Sprintf("消息%d [%s]:\n", i+1, msg.Role))
|
||
|
||
isUserMessage := strings.EqualFold(msg.Role, "user")
|
||
// 用户输入必须原样提供给攻击链模型
|
||
if isUserMessage {
|
||
promptBuilder.WriteString(fmt.Sprintf("%s\n\n", msg.Content))
|
||
} else if summary, ok := contextData.SummarizedItems[msg.ID]; ok {
|
||
promptBuilder.WriteString(fmt.Sprintf("[已总结] %s\n\n", summary))
|
||
} else {
|
||
content := msg.Content
|
||
if len(content) > 5000 {
|
||
content = content[:5000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("%s\n\n", content))
|
||
}
|
||
|
||
// 添加过程详情
|
||
if details, ok := contextData.ProcessDetails[msg.ID]; ok {
|
||
for _, detail := range details {
|
||
if detail.EventType == "thinking" {
|
||
thinkingText := detail.Message
|
||
if summary, ok := contextData.SummarizedItems[detail.ID]; ok {
|
||
thinkingText = "[已总结] " + summary
|
||
} else if len(thinkingText) > 2000 {
|
||
thinkingText = thinkingText[:2000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("思考过程: %s\n", thinkingText))
|
||
}
|
||
}
|
||
}
|
||
promptBuilder.WriteString("\n")
|
||
}
|
||
|
||
// 添加工具执行记录(关联对应的AI回复)
|
||
promptBuilder.WriteString("\n### 工具执行记录(包含对应的AI分析)\n\n")
|
||
|
||
// 构建工具执行ID到消息的映射(找到工具执行后AI的回复)
|
||
execToMessageMap := b.buildExecutionToMessageMap(contextData)
|
||
|
||
for i, exec := range contextData.Executions {
|
||
// 检查是否是错误/失败的执行
|
||
isError := exec.Error != "" || (exec.Result != nil && exec.Result.IsError)
|
||
|
||
statusText := "成功"
|
||
if isError {
|
||
statusText = "失败(可能包含线索)"
|
||
}
|
||
|
||
promptBuilder.WriteString(fmt.Sprintf("执行%d [%s] (ID: %s) - 状态: %s\n", i+1, exec.ToolName, exec.ID, statusText))
|
||
promptBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments)))
|
||
|
||
if isError && exec.Error != "" {
|
||
promptBuilder.WriteString(fmt.Sprintf("错误信息: %s\n", exec.Error))
|
||
}
|
||
|
||
// 检查是否已总结
|
||
var resultText string
|
||
if exec.Result != nil {
|
||
for _, content := range exec.Result.Content {
|
||
if content.Type == "text" {
|
||
resultText += content.Text + "\n"
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查结果是否为空或无效
|
||
if strings.TrimSpace(resultText) == "" {
|
||
if isError {
|
||
promptBuilder.WriteString("工具执行结果: [失败但未返回正文]\n")
|
||
} else {
|
||
promptBuilder.WriteString("工具执行结果: **已忽略(结果为空)**\n\n")
|
||
continue
|
||
}
|
||
} else {
|
||
if summary, ok := contextData.SummarizedItems[exec.ID]; ok {
|
||
promptBuilder.WriteString(fmt.Sprintf("工具执行结果: [已总结] %s\n", summary))
|
||
} else {
|
||
if len(resultText) > 5000 {
|
||
resultText = resultText[:5000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("工具执行结果: %s\n", resultText))
|
||
}
|
||
}
|
||
|
||
// 添加对应的AI分析(工具执行后AI的回复)
|
||
if aiMessage, ok := execToMessageMap[exec.ID]; ok {
|
||
aiContent := aiMessage.Content
|
||
if len(aiContent) > 2000 {
|
||
aiContent = aiContent[:2000] + "..."
|
||
}
|
||
promptBuilder.WriteString(fmt.Sprintf("AI分析: %s\n", aiContent))
|
||
}
|
||
|
||
promptBuilder.WriteString("\n")
|
||
}
|
||
|
||
promptBuilder.WriteString(`
|
||
|
||
## 输出格式
|
||
|
||
请以JSON格式返回攻击链,格式如下:
|
||
|
||
{
|
||
"nodes": [
|
||
{
|
||
"id": "node_1",
|
||
"type": "target|action|vulnerability",
|
||
"label": "节点标签(清晰、简洁,action节点要描述"做了什么"和"发现了什么")",
|
||
"risk_score": 0-100,
|
||
"tool_execution_id": "执行记录的真实ID(action节点必须使用上面执行记录中的ID字段)",
|
||
"metadata": {
|
||
"target": "目标(target节点)",
|
||
"tool_name": "工具名称(action节点)",
|
||
"tool_intent": "工具调用意图(action节点,如"端口扫描"、"漏洞扫描")",
|
||
"ai_analysis": "AI对工具结果的分析总结(action节点,不超过100字)",
|
||
"findings": ["发现1", "发现2"](action节点,关键发现列表),
|
||
"vulnerability_type": "漏洞类型(vulnerability节点)",
|
||
"description": "描述(vulnerability节点)",
|
||
"severity": "critical|high|medium|low(vulnerability节点)",
|
||
"location": "漏洞位置(vulnerability节点)"
|
||
}
|
||
}
|
||
],
|
||
"edges": [
|
||
{
|
||
"source": "node_1",
|
||
"target": "node_2",
|
||
"type": "leads_to|discovers|enables",
|
||
"weight": 1-5
|
||
}
|
||
]
|
||
}
|
||
|
||
## 重要要求
|
||
|
||
1. **节点合并**:
|
||
- 每个工具执行和对应的AI分析必须合并为一个action节点
|
||
- action节点的label要清晰描述"做了什么"、"结果/线索是什么"
|
||
- 例如:"使用Nmap扫描192.168.1.1,发现22、80、443端口开放" 或 "执行Sqlmap被WAF拦截,提示403并暴露防护厂商"
|
||
- 若为失败但有线索的行动,请在metadata.status中标记为"failed_insight",并在findings/hints里写清线索价值
|
||
|
||
2. **过滤无效节点**:
|
||
- **必须忽略**没有任何输出、没有线索的失败执行
|
||
- **必须保留**失败但提供关键线索的执行,确保metadata里解释清楚
|
||
- 只保留对学习或溯源有帮助的节点
|
||
|
||
3. **简化结构**:
|
||
- 只创建target、action、vulnerability三种节点
|
||
- 不要创建discovery、decision等节点
|
||
- 让攻击链清晰、有教育意义
|
||
|
||
4. **关联关系**:
|
||
- target → action:目标指向属于它的所有行动(通过工具执行参数判断目标)
|
||
- action → action:按时间顺序连接,但只连接有逻辑关系的
|
||
- **重要:只连接属于同一目标的action节点,不同目标的action节点之间不应该连接**
|
||
- action → vulnerability:行动发现的漏洞
|
||
- vulnerability → vulnerability:漏洞间的因果关系
|
||
- **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接**
|
||
|
||
5. **多目标处理(重要!)**:
|
||
- 如果对话中测试了多个不同的目标(如先测试A网页,后测试B网页),必须:
|
||
- 为每个不同的目标创建独立的target节点
|
||
- 每个target节点只关联属于它的action和vulnerability节点
|
||
- 不同目标的节点之间**不应该**建立任何关联关系
|
||
- 这样会形成多个独立的攻击链分支,每个分支对应一个测试目标
|
||
|
||
6. **节点数量控制**:
|
||
- 如果节点太多(>20个),优先保留最重要的节点
|
||
- 合并相似的action节点(如同一工具的连续调用,如果结果相似)
|
||
|
||
只返回JSON,不要包含其他解释文字。`)
|
||
|
||
return promptBuilder.String(), nil
|
||
}
|
||
|
||
// buildExecutionToMessageMap 构建工具执行ID到AI消息的映射
|
||
// 找到每个工具执行后AI的回复消息
|
||
func (b *Builder) buildExecutionToMessageMap(contextData *ContextData) map[string]database.Message {
|
||
execToMessageMap := make(map[string]database.Message)
|
||
|
||
// 遍历消息,找到包含工具执行ID的消息(通常是assistant消息)
|
||
for _, msg := range contextData.Messages {
|
||
if msg.Role != "assistant" {
|
||
continue
|
||
}
|
||
|
||
// 检查消息中是否引用了工具执行ID
|
||
// 通常工具执行后,AI会在回复中引用这些执行ID
|
||
for _, execID := range msg.MCPExecutionIDs {
|
||
// 找到对应的工具执行
|
||
for _, exec := range contextData.Executions {
|
||
if exec.ID == execID {
|
||
// 如果这个执行还没有关联的消息,或者当前消息时间更晚,则更新
|
||
if existingMsg, exists := execToMessageMap[execID]; !exists || msg.CreatedAt.After(existingMsg.CreatedAt) {
|
||
execToMessageMap[execID] = msg
|
||
}
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果通过MCPExecutionIDs找不到,尝试按时间顺序匹配
|
||
// 找到每个工具执行后最近的assistant消息
|
||
for _, exec := range contextData.Executions {
|
||
if _, exists := execToMessageMap[exec.ID]; exists {
|
||
continue
|
||
}
|
||
|
||
// 找到执行时间之后最近的assistant消息
|
||
var closestMsg *database.Message
|
||
for i := range contextData.Messages {
|
||
msg := &contextData.Messages[i]
|
||
if msg.Role == "assistant" && msg.CreatedAt.After(exec.StartTime) {
|
||
if closestMsg == nil || msg.CreatedAt.Before(closestMsg.CreatedAt) {
|
||
closestMsg = msg
|
||
}
|
||
}
|
||
}
|
||
|
||
if closestMsg != nil {
|
||
execToMessageMap[exec.ID] = *closestMsg
|
||
}
|
||
}
|
||
|
||
return execToMessageMap
|
||
}
|
||
|
||
// formatArguments 格式化工具参数
|
||
func (b *Builder) formatArguments(args map[string]interface{}) string {
|
||
if args == nil {
|
||
return "{}"
|
||
}
|
||
jsonData, _ := json.Marshal(args)
|
||
return string(jsonData)
|
||
}
|
||
|
||
// countPromptTokens 计算prompt的总tokens数
|
||
func (b *Builder) countPromptTokens(contextData *ContextData) (int, error) {
|
||
prompt, err := b.buildChainGenerationPrompt(contextData)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("构建提示词失败: %w", err)
|
||
}
|
||
|
||
if b.tokenCounter == nil || b.openAIConfig == nil {
|
||
// 如果没有token计数器或配置,使用简单的估算(4个字符=1个token)
|
||
return len(prompt) / 4, nil
|
||
}
|
||
|
||
model := b.openAIConfig.Model
|
||
if model == "" {
|
||
model = "gpt-4" // 默认模型
|
||
}
|
||
|
||
count, err := b.tokenCounter.Count(model, prompt)
|
||
if err != nil {
|
||
// 如果计算失败,使用估算
|
||
return len(prompt) / 4, nil
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
// compressContextData 使用分片压缩方式压缩上下文数据
|
||
func (b *Builder) compressContextData(ctx context.Context, contextData *ContextData) error {
|
||
// 计算当前tokens
|
||
totalTokens, err := b.countPromptTokens(contextData)
|
||
if err != nil {
|
||
return fmt.Errorf("计算tokens失败: %w", err)
|
||
}
|
||
|
||
b.logger.Info("开始压缩上下文",
|
||
zap.Int("totalTokens", totalTokens),
|
||
zap.Int("maxTokens", b.maxTokens))
|
||
|
||
// 如果tokens在限制内,不需要压缩
|
||
if totalTokens <= b.maxTokens {
|
||
return nil
|
||
}
|
||
|
||
// 计算需要分成多少份
|
||
numChunks := (totalTokens + b.maxTokens - 1) / b.maxTokens // 向上取整
|
||
if numChunks < 2 {
|
||
numChunks = 2 // 至少分成2份
|
||
}
|
||
|
||
b.logger.Info("将上下文分成多个片段进行压缩",
|
||
zap.Int("totalTokens", totalTokens),
|
||
zap.Int("maxTokens", b.maxTokens),
|
||
zap.Int("numChunks", numChunks))
|
||
|
||
// 按时间顺序将数据分成多个片段
|
||
chunks, err := b.splitContextDataByTime(contextData, numChunks)
|
||
if err != nil {
|
||
return fmt.Errorf("分割上下文数据失败: %w", err)
|
||
}
|
||
|
||
// 对每个片段进行摘要
|
||
summaries := make([]string, 0, len(chunks))
|
||
for i, chunk := range chunks {
|
||
b.logger.Info("压缩片段",
|
||
zap.Int("chunkIndex", i+1),
|
||
zap.Int("totalChunks", len(chunks)),
|
||
zap.Int("chunkSize", len(chunk.Messages)+len(chunk.Executions)))
|
||
|
||
summary, err := b.summarizeContextChunk(ctx, chunk)
|
||
if err != nil {
|
||
// 检查是否是认证错误
|
||
if strings.Contains(err.Error(), "Authentication") || strings.Contains(err.Error(), "api key") || strings.Contains(err.Error(), "invalid") {
|
||
return fmt.Errorf("压缩片段%d失败(API认证错误,请检查OpenAI配置): %w", i+1, err)
|
||
}
|
||
return fmt.Errorf("压缩片段%d失败: %w", i+1, err)
|
||
}
|
||
summaries = append(summaries, summary)
|
||
}
|
||
|
||
// 将摘要合并到contextData中
|
||
// 保留用户消息,清空其他数据,用摘要替换
|
||
var userMessages []database.Message
|
||
for _, msg := range contextData.Messages {
|
||
if strings.EqualFold(msg.Role, "user") {
|
||
userMessages = append(userMessages, msg)
|
||
}
|
||
}
|
||
|
||
// 清空非用户消息和执行记录
|
||
contextData.Messages = userMessages
|
||
contextData.Executions = []*mcp.ToolExecution{}
|
||
contextData.ProcessDetails = make(map[string][]database.ProcessDetail)
|
||
|
||
// 创建一个综合摘要消息
|
||
combinedSummary := strings.Join(summaries, "\n\n---\n\n")
|
||
summaryMsg := database.Message{
|
||
ID: uuid.New().String(),
|
||
Role: "assistant",
|
||
Content: fmt.Sprintf("[上下文摘要 - 包含%d个片段的压缩内容]\n\n%s", len(summaries), combinedSummary),
|
||
CreatedAt: time.Now(),
|
||
}
|
||
contextData.Messages = append(contextData.Messages, summaryMsg)
|
||
|
||
// 检查压缩后的tokens
|
||
compressedTokens, err := b.countPromptTokens(contextData)
|
||
if err != nil {
|
||
return fmt.Errorf("计算压缩后tokens失败: %w", err)
|
||
}
|
||
|
||
b.logger.Info("压缩完成",
|
||
zap.Int("originalTokens", totalTokens),
|
||
zap.Int("compressedTokens", compressedTokens),
|
||
zap.Int("reduction", totalTokens-compressedTokens))
|
||
|
||
// 如果压缩后仍然超过限制,递归压缩
|
||
if compressedTokens > b.maxTokens {
|
||
b.logger.Info("压缩后仍然超过限制,继续递归压缩",
|
||
zap.Int("compressedTokens", compressedTokens),
|
||
zap.Int("maxTokens", b.maxTokens))
|
||
return b.compressContextData(ctx, contextData)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ContextChunk 上下文数据片段
|
||
type ContextChunk struct {
|
||
Messages []database.Message
|
||
Executions []*mcp.ToolExecution
|
||
ProcessDetails map[string][]database.ProcessDetail
|
||
}
|
||
|
||
// splitContextDataByTime 按时间顺序将上下文数据分成多个片段
|
||
func (b *Builder) splitContextDataByTime(contextData *ContextData, numChunks int) ([]*ContextChunk, error) {
|
||
if numChunks <= 0 {
|
||
return nil, fmt.Errorf("片段数量必须大于0")
|
||
}
|
||
|
||
// 收集所有带时间戳的项目
|
||
type timeItem struct {
|
||
time time.Time
|
||
itemType string // "message", "execution", "thinking"
|
||
message *database.Message
|
||
execution *mcp.ToolExecution
|
||
processDetail *database.ProcessDetail
|
||
}
|
||
|
||
var items []timeItem
|
||
|
||
// 添加消息(跳过已总结的)
|
||
for i := range contextData.Messages {
|
||
msg := &contextData.Messages[i]
|
||
if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized {
|
||
continue
|
||
}
|
||
items = append(items, timeItem{
|
||
time: msg.CreatedAt,
|
||
itemType: "message",
|
||
message: msg,
|
||
})
|
||
}
|
||
|
||
// 添加工具执行(跳过已总结的)
|
||
for _, exec := range contextData.Executions {
|
||
if _, alreadySummarized := contextData.SummarizedItems[exec.ID]; alreadySummarized {
|
||
continue
|
||
}
|
||
items = append(items, timeItem{
|
||
time: exec.StartTime,
|
||
itemType: "execution",
|
||
execution: exec,
|
||
})
|
||
}
|
||
|
||
// 添加思考过程(跳过已总结的)
|
||
for _, details := range contextData.ProcessDetails {
|
||
for i := range details {
|
||
detail := &details[i]
|
||
if detail.EventType == "thinking" {
|
||
if _, alreadySummarized := contextData.SummarizedItems[detail.ID]; alreadySummarized {
|
||
continue
|
||
}
|
||
items = append(items, timeItem{
|
||
time: detail.CreatedAt,
|
||
itemType: "thinking",
|
||
processDetail: detail,
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(items) == 0 {
|
||
return nil, fmt.Errorf("没有可分割的数据")
|
||
}
|
||
|
||
// 按时间排序
|
||
sort.Slice(items, func(i, j int) bool {
|
||
return items[i].time.Before(items[j].time)
|
||
})
|
||
|
||
// 计算每个片段的大小
|
||
chunkSize := (len(items) + numChunks - 1) / numChunks // 向上取整
|
||
|
||
// 创建片段
|
||
chunks := make([]*ContextChunk, 0, numChunks)
|
||
for i := 0; i < len(items); i += chunkSize {
|
||
end := i + chunkSize
|
||
if end > len(items) {
|
||
end = len(items)
|
||
}
|
||
|
||
chunk := &ContextChunk{
|
||
Messages: []database.Message{},
|
||
Executions: []*mcp.ToolExecution{},
|
||
ProcessDetails: make(map[string][]database.ProcessDetail),
|
||
}
|
||
|
||
for j := i; j < end; j++ {
|
||
item := items[j]
|
||
switch item.itemType {
|
||
case "message":
|
||
chunk.Messages = append(chunk.Messages, *item.message)
|
||
case "execution":
|
||
chunk.Executions = append(chunk.Executions, item.execution)
|
||
case "thinking":
|
||
if item.processDetail != nil {
|
||
msgID := item.processDetail.MessageID
|
||
chunk.ProcessDetails[msgID] = append(chunk.ProcessDetails[msgID], *item.processDetail)
|
||
}
|
||
}
|
||
}
|
||
|
||
chunks = append(chunks, chunk)
|
||
}
|
||
|
||
return chunks, nil
|
||
}
|
||
|
||
// getModelMaxContextLength 获取模型的最大上下文长度
|
||
func (b *Builder) getModelMaxContextLength() int {
|
||
if b.openAIConfig == nil {
|
||
return 131072 // 默认值
|
||
}
|
||
model := strings.ToLower(b.openAIConfig.Model)
|
||
if strings.Contains(model, "gpt-4") {
|
||
return 128000
|
||
} else if strings.Contains(model, "gpt-3.5") {
|
||
return 16000
|
||
} else if strings.Contains(model, "deepseek") {
|
||
return 131072
|
||
}
|
||
return 131072 // 默认值
|
||
}
|
||
|
||
// summarizeContextChunk 总结一个上下文片段
|
||
func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk) (string, error) {
|
||
// 先构建内容
|
||
content, err := b.buildChunkContent(chunk)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 使用AI总结
|
||
promptTemplate := `请详细总结以下安全测试对话片段的关键信息。虽然需要压缩内容,但必须保留所有重要的技术细节和上下文信息,确保后续攻击链生成时能够准确理解整个测试过程。
|
||
|
||
**必须详细保留的内容:**
|
||
1. **所有工具执行记录**:
|
||
- 工具名称、执行参数、执行结果(包括成功和失败)
|
||
- 失败执行的错误信息、状态码、响应头等关键线索
|
||
- 工具输出的关键数据(端口、服务版本、漏洞信息等)
|
||
- 每个工具执行的时间顺序和上下文关系
|
||
|
||
2. **所有发现的漏洞和潜在安全问题**:
|
||
- 漏洞类型、严重程度、位置、利用方式
|
||
- 验证过程和结果
|
||
- 漏洞之间的关联关系
|
||
|
||
3. **所有测试目标和资产信息**:
|
||
- IP地址、域名、URL、端口等
|
||
- 发现的服务、技术栈、版本信息
|
||
- 资产之间的关联关系
|
||
|
||
4. **所有测试步骤和决策过程**:
|
||
- 每个测试步骤的详细描述(做了什么、为什么做、结果如何)
|
||
- AI的分析思路和决策依据
|
||
- 失败尝试的原因和从中获得的线索
|
||
|
||
5. **所有关键发现和线索**:
|
||
- 成功发现的详细信息
|
||
- 失败但提供线索的尝试(错误信息、限制条件、下一步建议等)
|
||
- 收集到的任何有价值的信息(凭据、令牌、配置信息等)
|
||
|
||
**总结要求:**
|
||
- 用结构化的方式组织信息,按时间顺序或逻辑顺序排列
|
||
- 对于每个工具执行,必须包含:工具名、目标、参数、结果/错误、关键发现
|
||
- 对于每个漏洞,必须包含:类型、位置、严重程度、验证结果
|
||
- 保留所有技术细节,不要过度简化
|
||
- 确保后续AI能够根据这个摘要完整重建攻击链
|
||
|
||
对话片段:
|
||
%s
|
||
|
||
请给出详细且结构化的技术摘要(建议1000-2000字,确保信息完整):`
|
||
|
||
// 检查prompt tokens,如果超过限制,需要进一步压缩内容
|
||
maxContextLength := b.getModelMaxContextLength()
|
||
maxPromptTokens := maxContextLength - 2000 // 留出空间给响应和系统消息
|
||
|
||
// 尝试构建完整prompt并检查tokens
|
||
fullPrompt := fmt.Sprintf(promptTemplate, content)
|
||
promptTokens, err := b.countTextTokens(fullPrompt)
|
||
if err != nil {
|
||
// 如果计算失败,使用估算
|
||
promptTokens = len(fullPrompt) / 4
|
||
}
|
||
|
||
// 如果prompt太大,需要进一步压缩内容
|
||
if promptTokens > maxPromptTokens {
|
||
b.logger.Warn("片段内容过大,需要进一步压缩",
|
||
zap.Int("promptTokens", promptTokens),
|
||
zap.Int("maxPromptTokens", maxPromptTokens))
|
||
|
||
// 递归压缩:将chunk进一步分割
|
||
compressedContent, err := b.compressLargeChunk(ctx, chunk, maxPromptTokens)
|
||
if err != nil {
|
||
return "", fmt.Errorf("压缩大片段失败: %w", err)
|
||
}
|
||
content = compressedContent
|
||
}
|
||
|
||
prompt := fmt.Sprintf(promptTemplate, content)
|
||
|
||
// 检查配置
|
||
if b.openAIConfig == nil {
|
||
return "", fmt.Errorf("OpenAI配置未初始化")
|
||
}
|
||
if b.openAIConfig.APIKey == "" {
|
||
return "", fmt.Errorf("OpenAI API Key未配置")
|
||
}
|
||
if b.openAIConfig.Model == "" {
|
||
return "", fmt.Errorf("OpenAI Model未配置")
|
||
}
|
||
|
||
// 直接调用AI API进行总结
|
||
requestBody := map[string]interface{}{
|
||
"model": b.openAIConfig.Model,
|
||
"messages": []map[string]interface{}{
|
||
{
|
||
"role": "system",
|
||
"content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试对话片段,这些摘要将用于后续构建完整的攻击链图。
|
||
|
||
**你的专业背景:**
|
||
- 精通各种安全测试工具(Nmap、SQLMap、Burp Suite、Metasploit等)的使用和结果分析
|
||
- 熟悉常见漏洞类型(SQL注入、XSS、文件上传、命令执行、目录遍历等)的识别和验证
|
||
- 理解攻击链的构建逻辑:从信息收集 → 漏洞发现 → 漏洞利用 → 权限提升 → 横向移动
|
||
- 能够识别失败尝试中的有价值线索(错误信息、状态码、WAF指纹、技术栈信息等)
|
||
|
||
**你的总结原则:**
|
||
1. **完整性优先**:虽然需要压缩,但必须保留所有技术细节,确保后续AI能够完整重建攻击链
|
||
2. **结构化组织**:按时间顺序或逻辑顺序组织信息,让信息易于理解和追踪
|
||
3. **技术精准**:使用准确的技术术语,保留具体的数值、版本号、端口号、URL等关键数据
|
||
4. **上下文关联**:保留测试步骤之间的因果关系和逻辑关联
|
||
5. **失败价值**:即使是失败的尝试,只要提供了线索(错误信息、限制条件、下一步建议),也要详细记录
|
||
|
||
**你需要特别关注的信息类型:**
|
||
- 工具执行:工具名、目标、参数、完整结果(包括错误和失败)
|
||
- 漏洞发现:类型、位置、严重程度、验证方法、利用结果
|
||
- 资产信息:IP、域名、端口、服务版本、技术栈
|
||
- 测试策略:为什么选择这个工具、为什么测试这个目标、发现了什么线索
|
||
- 关键数据:凭据、令牌、配置信息、敏感文件内容
|
||
|
||
请用专业、详细、结构化的中文进行总结,确保信息完整且易于后续处理。`,
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": prompt,
|
||
},
|
||
},
|
||
"temperature": 0.3,
|
||
"max_tokens": 4000, // 增加摘要长度,以容纳更详细的内容
|
||
}
|
||
|
||
jsonData, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||
|
||
resp, err := b.openAIClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var apiResponse struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if len(apiResponse.Choices) == 0 {
|
||
return "", fmt.Errorf("API未返回有效响应")
|
||
}
|
||
|
||
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
|
||
}
|
||
|
||
// compressLongestItem 压缩最长的子节点(保留作为备用方法)
|
||
func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error {
|
||
// 使用新的分片压缩方法
|
||
return b.compressContextData(ctx, contextData)
|
||
}
|
||
|
||
// summarizeContent 总结内容
|
||
func (b *Builder) summarizeContent(ctx context.Context, contentType, content string) (string, error) {
|
||
var prompt string
|
||
switch contentType {
|
||
case "message":
|
||
prompt = fmt.Sprintf(`请总结以下AI回复的关键信息,保留所有重要的安全发现、漏洞信息和测试结果。用简洁的中文总结,不超过500字。
|
||
|
||
AI回复:
|
||
%s
|
||
|
||
总结:`, content)
|
||
case "execution":
|
||
prompt = fmt.Sprintf(`请总结以下工具执行结果的关键信息,保留所有发现的漏洞、重要发现和测试结果。用简洁的中文总结,不超过500字。
|
||
|
||
工具执行结果:
|
||
%s
|
||
|
||
总结:`, content)
|
||
case "thinking":
|
||
prompt = fmt.Sprintf(`请总结以下AI思考过程的关键决策和思路,保留所有重要的决策点和测试策略。用简洁的中文总结,不超过300字。
|
||
|
||
思考过程:
|
||
%s
|
||
|
||
总结:`, content)
|
||
default:
|
||
return "", fmt.Errorf("未知的内容类型: %s", contentType)
|
||
}
|
||
|
||
requestBody := map[string]interface{}{
|
||
"model": b.openAIConfig.Model,
|
||
"messages": []map[string]interface{}{
|
||
{
|
||
"role": "system",
|
||
"content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试过程中的关键信息,这些摘要将用于构建攻击链图。
|
||
|
||
**你的专业背景:**
|
||
- 精通各种安全测试工具的使用和结果分析(Nmap、SQLMap、Burp Suite、Metasploit、Nuclei等)
|
||
- 熟悉常见漏洞类型的识别和验证(SQL注入、XSS、文件上传、命令执行、目录遍历、SSRF等)
|
||
- 理解攻击链的构建逻辑和测试流程
|
||
- 能够识别失败尝试中的有价值线索
|
||
|
||
**你的总结原则:**
|
||
1. **保留技术细节**:保留所有重要的技术信息,包括工具名、参数、结果、错误信息、状态码等
|
||
2. **突出关键发现**:重点记录发现的漏洞、安全问题、资产信息、凭据等
|
||
3. **记录失败线索**:即使是失败的尝试,如果提供了错误信息、限制条件或下一步建议,也要详细记录
|
||
4. **保持准确性**:使用准确的技术术语,保留具体的数值、版本号、端口号等关键数据
|
||
5. **结构化表达**:用清晰、有条理的方式组织信息
|
||
|
||
**根据内容类型,你需要特别关注:**
|
||
- **AI回复**:提取安全发现、漏洞信息、测试结果、分析思路、决策依据
|
||
- **工具执行**:记录工具名、目标、参数、完整结果(成功或失败)、关键发现、错误信息
|
||
- **思考过程**:提取关键决策点、测试策略、分析思路、下一步计划
|
||
|
||
请用专业、准确、简洁的中文进行总结,确保信息完整且易于理解。`,
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": prompt,
|
||
},
|
||
},
|
||
"temperature": 0.3,
|
||
"max_tokens": 1000,
|
||
}
|
||
|
||
jsonData, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||
|
||
resp, err := b.openAIClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
var apiResponse struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if len(apiResponse.Choices) == 0 {
|
||
return "", fmt.Errorf("API未返回有效响应")
|
||
}
|
||
|
||
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
|
||
}
|
||
|
||
// buildChunkContent 构建chunk的文本内容
|
||
func (b *Builder) buildChunkContent(chunk *ContextChunk) (string, error) {
|
||
var contentBuilder strings.Builder
|
||
|
||
// 添加消息
|
||
for _, msg := range chunk.Messages {
|
||
if strings.EqualFold(msg.Role, "user") {
|
||
contentBuilder.WriteString(fmt.Sprintf("用户消息: %s\n\n", msg.Content))
|
||
} else {
|
||
contentBuilder.WriteString(fmt.Sprintf("AI回复: %s\n\n", msg.Content))
|
||
}
|
||
}
|
||
|
||
// 添加工具执行
|
||
for _, exec := range chunk.Executions {
|
||
contentBuilder.WriteString(fmt.Sprintf("工具执行 [%s] (ID: %s):\n", exec.ToolName, exec.ID))
|
||
contentBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments)))
|
||
|
||
if exec.Error != "" {
|
||
contentBuilder.WriteString(fmt.Sprintf("错误: %s\n", exec.Error))
|
||
}
|
||
|
||
if exec.Result != nil {
|
||
var resultText string
|
||
for _, content := range exec.Result.Content {
|
||
if content.Type == "text" {
|
||
resultText += content.Text + "\n"
|
||
}
|
||
}
|
||
if resultText != "" {
|
||
// 如果结果太长,截断
|
||
if len(resultText) > 10000 {
|
||
resultText = resultText[:10000] + "\n... [内容已截断]"
|
||
}
|
||
contentBuilder.WriteString(fmt.Sprintf("结果: %s\n", resultText))
|
||
}
|
||
}
|
||
contentBuilder.WriteString("\n")
|
||
}
|
||
|
||
// 添加思考过程
|
||
for _, details := range chunk.ProcessDetails {
|
||
for _, detail := range details {
|
||
if detail.EventType == "thinking" {
|
||
thinkingText := detail.Message
|
||
// 如果思考过程太长,截断
|
||
if len(thinkingText) > 5000 {
|
||
thinkingText = thinkingText[:5000] + "\n... [内容已截断]"
|
||
}
|
||
contentBuilder.WriteString(fmt.Sprintf("思考过程: %s\n\n", thinkingText))
|
||
}
|
||
}
|
||
}
|
||
|
||
content := contentBuilder.String()
|
||
if content == "" {
|
||
return "", fmt.Errorf("片段内容为空")
|
||
}
|
||
return content, nil
|
||
}
|
||
|
||
// compressLargeChunk 压缩过大的chunk(递归分割)
|
||
func (b *Builder) compressLargeChunk(ctx context.Context, chunk *ContextChunk, maxTokens int) (string, error) {
|
||
// 将chunk进一步分割成更小的子chunk
|
||
// 简单策略:按消息和执行数量平均分割
|
||
totalItems := len(chunk.Messages) + len(chunk.Executions)
|
||
if totalItems <= 1 {
|
||
// 如果只有一个项目,直接截断内容
|
||
content, _ := b.buildChunkContent(chunk)
|
||
if len(content) > maxTokens*4 { // 粗略估算:1 token ≈ 4字符
|
||
content = content[:maxTokens*4] + "\n... [内容过大,已截断]"
|
||
}
|
||
return content, nil
|
||
}
|
||
|
||
// 分成2个子chunk
|
||
mid := totalItems / 2
|
||
subChunk1 := &ContextChunk{
|
||
Messages: []database.Message{},
|
||
Executions: []*mcp.ToolExecution{},
|
||
ProcessDetails: make(map[string][]database.ProcessDetail),
|
||
}
|
||
subChunk2 := &ContextChunk{
|
||
Messages: []database.Message{},
|
||
Executions: []*mcp.ToolExecution{},
|
||
ProcessDetails: make(map[string][]database.ProcessDetail),
|
||
}
|
||
|
||
// 分配消息
|
||
for i, msg := range chunk.Messages {
|
||
if i < mid {
|
||
subChunk1.Messages = append(subChunk1.Messages, msg)
|
||
} else {
|
||
subChunk2.Messages = append(subChunk2.Messages, msg)
|
||
}
|
||
}
|
||
|
||
// 分配执行
|
||
execStart := len(chunk.Messages)
|
||
for i, exec := range chunk.Executions {
|
||
if execStart+i < mid {
|
||
subChunk1.Executions = append(subChunk1.Executions, exec)
|
||
} else {
|
||
subChunk2.Executions = append(subChunk2.Executions, exec)
|
||
}
|
||
}
|
||
|
||
// 递归压缩子chunk
|
||
summary1, err := b.summarizeContextChunk(ctx, subChunk1)
|
||
if err != nil {
|
||
return "", fmt.Errorf("压缩子chunk1失败: %w", err)
|
||
}
|
||
|
||
summary2, err := b.summarizeContextChunk(ctx, subChunk2)
|
||
if err != nil {
|
||
return "", fmt.Errorf("压缩子chunk2失败: %w", err)
|
||
}
|
||
|
||
// 合并摘要
|
||
return fmt.Sprintf("片段1摘要:\n%s\n\n---\n\n片段2摘要:\n%s", summary1, summary2), nil
|
||
}
|
||
|
||
// countTextTokens 计算文本的tokens数
|
||
func (b *Builder) countTextTokens(text string) (int, error) {
|
||
if b.tokenCounter == nil || b.openAIConfig == nil {
|
||
return len(text) / 4, nil
|
||
}
|
||
|
||
model := b.openAIConfig.Model
|
||
if model == "" {
|
||
model = "gpt-4"
|
||
}
|
||
|
||
count, err := b.tokenCounter.Count(model, text)
|
||
if err != nil {
|
||
return len(text) / 4, nil
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
// callAIForChainGeneration 调用AI生成攻击链
|
||
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
|
||
requestBody := map[string]interface{}{
|
||
"model": b.openAIConfig.Model,
|
||
"messages": []map[string]interface{}{
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个专业的安全测试分析师,擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。",
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": prompt,
|
||
},
|
||
},
|
||
"temperature": 0.3,
|
||
"max_tokens": 8000,
|
||
}
|
||
|
||
jsonData, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||
|
||
resp, err := b.openAIClient.Do(req)
|
||
if err != nil {
|
||
// 检查是否是上下文过长错误
|
||
if strings.Contains(err.Error(), "context") || strings.Contains(err.Error(), "length") {
|
||
return "", fmt.Errorf("context length exceeded")
|
||
}
|
||
return "", fmt.Errorf("请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
bodyStr := string(body)
|
||
// 检查是否是上下文过长错误
|
||
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
|
||
return "", fmt.Errorf("context length exceeded")
|
||
}
|
||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, bodyStr)
|
||
}
|
||
|
||
var apiResponse struct {
|
||
Choices []struct {
|
||
Message struct {
|
||
Content string `json:"content"`
|
||
} `json:"message"`
|
||
} `json:"choices"`
|
||
}
|
||
|
||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||
}
|
||
|
||
if len(apiResponse.Choices) == 0 {
|
||
return "", fmt.Errorf("API未返回有效响应")
|
||
}
|
||
|
||
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
|
||
// 尝试提取JSON(可能包含markdown代码块)
|
||
content = strings.TrimPrefix(content, "```json")
|
||
content = strings.TrimPrefix(content, "```")
|
||
content = strings.TrimSuffix(content, "```")
|
||
content = strings.TrimSpace(content)
|
||
|
||
return content, nil
|
||
}
|
||
|
||
// ChainJSON 攻击链JSON结构
|
||
type ChainJSON struct {
|
||
Nodes []struct {
|
||
ID string `json:"id"`
|
||
Type string `json:"type"`
|
||
Label string `json:"label"`
|
||
RiskScore int `json:"risk_score"`
|
||
ToolExecutionID string `json:"tool_execution_id,omitempty"`
|
||
Metadata map[string]interface{} `json:"metadata"`
|
||
} `json:"nodes"`
|
||
Edges []struct {
|
||
Source string `json:"source"`
|
||
Target string `json:"target"`
|
||
Type string `json:"type"`
|
||
Weight int `json:"weight"`
|
||
} `json:"edges"`
|
||
}
|
||
|
||
// parseChainJSON 解析攻击链JSON
|
||
func (b *Builder) parseChainJSON(chainJSON string, executions []*mcp.ToolExecution) (*Chain, error) {
|
||
var chainData ChainJSON
|
||
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
|
||
return nil, fmt.Errorf("解析JSON失败: %w", err)
|
||
}
|
||
|
||
// 创建execution ID映射(AI可能返回简单的索引或ID,需要映射到真实的execution ID)
|
||
executionMap := make(map[string]string) // AI返回的ID -> 真实execution ID
|
||
for i, exec := range executions {
|
||
// 支持多种可能的AI返回格式
|
||
executionMap[fmt.Sprintf("exec_%d", i+1)] = exec.ID
|
||
executionMap[fmt.Sprintf("execution_%d", i+1)] = exec.ID
|
||
executionMap[exec.ID] = exec.ID // 如果AI直接返回真实ID
|
||
executionMap[fmt.Sprintf("tool_%d", i+1)] = exec.ID // AI可能用tool_1格式
|
||
executionMap[fmt.Sprintf("执行%d", i+1)] = exec.ID // 中文格式
|
||
executionMap[fmt.Sprintf("执行_%d", i+1)] = exec.ID
|
||
}
|
||
|
||
// 创建节点ID映射(AI返回的ID -> 新的UUID)
|
||
nodeIDMap := make(map[string]string)
|
||
|
||
// 转换为Chain结构,并过滤无效节点
|
||
nodes := make([]Node, 0, len(chainData.Nodes))
|
||
for _, n := range chainData.Nodes {
|
||
// 过滤无效节点
|
||
if b.shouldFilterNode(n, executions) {
|
||
b.logger.Info("过滤无效节点",
|
||
zap.String("nodeID", n.ID),
|
||
zap.String("nodeType", n.Type),
|
||
zap.String("label", n.Label))
|
||
continue
|
||
}
|
||
|
||
// 生成新的UUID节点ID
|
||
newNodeID := fmt.Sprintf("node_%s", uuid.New().String())
|
||
nodeIDMap[n.ID] = newNodeID
|
||
|
||
node := Node{
|
||
ID: newNodeID,
|
||
Type: n.Type,
|
||
Label: n.Label,
|
||
RiskScore: n.RiskScore,
|
||
Metadata: n.Metadata,
|
||
}
|
||
if node.Metadata == nil {
|
||
node.Metadata = make(map[string]interface{})
|
||
}
|
||
|
||
// 处理tool_execution_id:如果是action或vulnerability节点,需要映射到真实的execution ID
|
||
if n.ToolExecutionID != "" {
|
||
if realExecID, ok := executionMap[n.ToolExecutionID]; ok {
|
||
node.ToolExecutionID = realExecID
|
||
} else {
|
||
// 检查是否是真实的execution ID(UUID格式)
|
||
// 如果是,直接使用;如果不是,尝试从节点ID推断
|
||
if len(n.ToolExecutionID) > 20 { // UUID通常很长
|
||
node.ToolExecutionID = n.ToolExecutionID
|
||
} else {
|
||
// 可能是简单的ID,尝试从节点ID推断
|
||
if realExecID, ok := executionMap[n.ID]; ok {
|
||
node.ToolExecutionID = realExecID
|
||
} else {
|
||
b.logger.Warn("无法映射tool_execution_id",
|
||
zap.String("nodeID", n.ID),
|
||
zap.String("toolExecutionID", n.ToolExecutionID))
|
||
// 对于action节点,如果没有有效的execution ID,清空它(避免外键约束失败)
|
||
if n.Type == "action" {
|
||
node.ToolExecutionID = ""
|
||
}
|
||
}
|
||
}
|
||
}
|
||
} else if n.Type == "action" || n.Type == "vulnerability" {
|
||
// 如果AI没有提供tool_execution_id,尝试从节点ID推断
|
||
// 例如:tool_1 -> 查找exec_1
|
||
if realExecID, ok := executionMap[n.ID]; ok {
|
||
node.ToolExecutionID = realExecID
|
||
} else {
|
||
b.logger.Warn("action/vulnerability节点缺少tool_execution_id",
|
||
zap.String("nodeID", n.ID),
|
||
zap.String("nodeType", n.Type))
|
||
}
|
||
}
|
||
|
||
nodes = append(nodes, node)
|
||
}
|
||
|
||
// 转换边,更新source和target为新的节点ID
|
||
edges := make([]Edge, 0, len(chainData.Edges))
|
||
for _, e := range chainData.Edges {
|
||
sourceID, ok := nodeIDMap[e.Source]
|
||
if !ok {
|
||
b.logger.Warn("边的源节点ID未找到", zap.String("source", e.Source))
|
||
continue
|
||
}
|
||
|
||
targetID, ok := nodeIDMap[e.Target]
|
||
if !ok {
|
||
b.logger.Warn("边的目标节点ID未找到", zap.String("target", e.Target))
|
||
continue
|
||
}
|
||
|
||
edge := Edge{
|
||
ID: fmt.Sprintf("edge_%s", uuid.New().String()),
|
||
Source: sourceID,
|
||
Target: targetID,
|
||
Type: e.Type,
|
||
Weight: e.Weight,
|
||
}
|
||
edges = append(edges, edge)
|
||
}
|
||
|
||
// 过滤掉指向已删除节点的边
|
||
filteredEdges := make([]Edge, 0, len(edges))
|
||
for _, edge := range edges {
|
||
// 检查source和target节点是否都存在
|
||
sourceExists := false
|
||
targetExists := false
|
||
for _, node := range nodes {
|
||
if node.ID == edge.Source {
|
||
sourceExists = true
|
||
}
|
||
if node.ID == edge.Target {
|
||
targetExists = true
|
||
}
|
||
}
|
||
|
||
if sourceExists && targetExists {
|
||
filteredEdges = append(filteredEdges, edge)
|
||
} else {
|
||
b.logger.Warn("过滤无效边",
|
||
zap.String("edgeID", edge.ID),
|
||
zap.String("source", edge.Source),
|
||
zap.String("target", edge.Target),
|
||
zap.Bool("sourceExists", sourceExists),
|
||
zap.Bool("targetExists", targetExists))
|
||
}
|
||
}
|
||
|
||
return &Chain{
|
||
Nodes: nodes,
|
||
Edges: filteredEdges,
|
||
}, nil
|
||
}
|
||
|
||
// shouldFilterNode 判断是否应该过滤掉这个节点
|
||
func (b *Builder) shouldFilterNode(n struct {
|
||
ID string `json:"id"`
|
||
Type string `json:"type"`
|
||
Label string `json:"label"`
|
||
RiskScore int `json:"risk_score"`
|
||
ToolExecutionID string `json:"tool_execution_id,omitempty"`
|
||
Metadata map[string]interface{} `json:"metadata"`
|
||
}, executions []*mcp.ToolExecution) bool {
|
||
// 只允许target、action、vulnerability三种节点类型
|
||
if n.Type != "target" && n.Type != "action" && n.Type != "vulnerability" {
|
||
return true
|
||
}
|
||
|
||
// 检查节点标签是否为空或无效
|
||
if strings.TrimSpace(n.Label) == "" {
|
||
return true
|
||
}
|
||
|
||
// 对于vulnerability节点,即使没有tool_execution_id也应该保留(漏洞可能不是直接来自工具执行)
|
||
if n.Type == "vulnerability" {
|
||
// 只要标签有意义就保留
|
||
return false
|
||
}
|
||
|
||
// 对于target节点,只要标签有意义就保留
|
||
if n.Type == "target" {
|
||
return false
|
||
}
|
||
|
||
// 对于action节点,进行更宽松的检查
|
||
if n.Type == "action" {
|
||
// 如果executions为空(可能是压缩后的场景),只要标签有意义就保留
|
||
if len(executions) == 0 {
|
||
// 压缩场景下,只要标签不是明显无效就保留
|
||
labelLower := strings.ToLower(n.Label)
|
||
// 只过滤明显无效的标签
|
||
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
|
||
for _, keyword := range invalidKeywords {
|
||
if strings.Contains(labelLower, keyword) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 如果有tool_execution_id,尝试查找对应的工具执行
|
||
if n.ToolExecutionID != "" {
|
||
var exec *mcp.ToolExecution
|
||
for _, e := range executions {
|
||
if e.ID == n.ToolExecutionID {
|
||
exec = e
|
||
break
|
||
}
|
||
}
|
||
|
||
if exec != nil {
|
||
// 找到了对应的工具执行,检查是否有效
|
||
// 检查工具执行是否错误或失败
|
||
if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) {
|
||
// 失败但有线索的应该保留
|
||
if !hasInsightfulFailure(n.Metadata) {
|
||
// 即使没有明确标记为有线索,如果标签描述了具体内容,也保留
|
||
labelLower := strings.ToLower(n.Label)
|
||
// 如果标签包含具体的技术信息(端口、服务、漏洞等),说明有价值
|
||
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
|
||
hasValuableInfo := false
|
||
for _, keyword := range valuableKeywords {
|
||
if strings.Contains(labelLower, keyword) {
|
||
hasValuableInfo = true
|
||
break
|
||
}
|
||
}
|
||
if !hasValuableInfo {
|
||
return true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 检查工具执行结果是否为空
|
||
if exec.Result == nil || len(exec.Result.Content) == 0 {
|
||
// 结果为空,但如果有线索或标签有意义,也保留
|
||
if !hasInsightfulFailure(n.Metadata) {
|
||
labelLower := strings.ToLower(n.Label)
|
||
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
|
||
hasValuableInfo := false
|
||
for _, keyword := range valuableKeywords {
|
||
if strings.Contains(labelLower, keyword) {
|
||
hasValuableInfo = true
|
||
break
|
||
}
|
||
}
|
||
if !hasValuableInfo {
|
||
return true
|
||
}
|
||
}
|
||
} else {
|
||
// 检查结果文本是否为空
|
||
var resultText string
|
||
for _, content := range exec.Result.Content {
|
||
if content.Type == "text" {
|
||
resultText += content.Text
|
||
}
|
||
}
|
||
if strings.TrimSpace(resultText) == "" {
|
||
// 结果文本为空,但如果有线索或标签有意义,也保留
|
||
if !hasInsightfulFailure(n.Metadata) {
|
||
labelLower := strings.ToLower(n.Label)
|
||
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
|
||
hasValuableInfo := false
|
||
for _, keyword := range valuableKeywords {
|
||
if strings.Contains(labelLower, keyword) {
|
||
hasValuableInfo = true
|
||
break
|
||
}
|
||
}
|
||
if !hasValuableInfo {
|
||
return true
|
||
}
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
// 找不到对应的工具执行,但可能是压缩后的场景
|
||
// 只要标签有意义就保留(不要因为找不到execution就过滤掉)
|
||
labelLower := strings.ToLower(n.Label)
|
||
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
|
||
for _, keyword := range invalidKeywords {
|
||
if strings.Contains(labelLower, keyword) {
|
||
return true
|
||
}
|
||
}
|
||
// 标签有意义,保留
|
||
return false
|
||
}
|
||
} else {
|
||
// 没有tool_execution_id,但可能是压缩后的场景或AI生成的节点
|
||
// 只要标签有意义就保留
|
||
labelLower := strings.ToLower(n.Label)
|
||
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
|
||
for _, keyword := range invalidKeywords {
|
||
if strings.Contains(labelLower, keyword) {
|
||
return true
|
||
}
|
||
}
|
||
// 标签有意义,保留
|
||
return false
|
||
}
|
||
}
|
||
|
||
// 默认保留(已经通过了所有检查)
|
||
return false
|
||
}
|
||
|
||
func hasInsightfulFailure(metadata map[string]interface{}) bool {
|
||
if metadata == nil {
|
||
return false
|
||
}
|
||
|
||
if status, ok := metadata["status"].(string); ok {
|
||
normalized := strings.ToLower(strings.TrimSpace(status))
|
||
if normalized == "failed_insight" || normalized == "failed_clue" || normalized == "failed_with_hint" {
|
||
return true
|
||
}
|
||
}
|
||
|
||
if hint, ok := metadata["hint"].(string); ok && strings.TrimSpace(hint) != "" {
|
||
return true
|
||
}
|
||
|
||
if hints, ok := metadata["hints"].([]interface{}); ok && len(hints) > 0 {
|
||
return true
|
||
}
|
||
|
||
if insight, ok := metadata["insight"].(string); ok && strings.TrimSpace(insight) != "" {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|