From bc5b368ecedd12de23429a8519980d13dad995f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sat, 22 Nov 2025 03:01:50 +0800 Subject: [PATCH] Add files via upload --- internal/attackchain/builder.go | 840 +++++++++++++++++++++++++++----- 1 file changed, 719 insertions(+), 121 deletions(-) diff --git a/internal/attackchain/builder.go b/internal/attackchain/builder.go index fec10178..83be2218 100644 --- a/internal/attackchain/builder.go +++ b/internal/attackchain/builder.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/mcp" @@ -25,6 +26,8 @@ type Builder struct { logger *zap.Logger openAIClient *http.Client openAIConfig *config.OpenAIConfig + tokenCounter agent.TokenCounter + maxTokens int // 最大tokens限制,默认100000 } // Node 攻击链节点(使用database包的类型) @@ -47,11 +50,26 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap. IdleConnTimeout: 90 * time.Second, } + maxTokens := 100000 // 默认100k tokens,可以根据模型调整 + // 根据模型设置合理的默认值 + if openAIConfig != nil { + model := strings.ToLower(openAIConfig.Model) + if strings.Contains(model, "gpt-4") { + maxTokens = 128000 // gpt-4通常支持128k + } else if strings.Contains(model, "gpt-3.5") { + maxTokens = 16000 // gpt-3.5-turbo通常支持16k + } else if strings.Contains(model, "deepseek") { + maxTokens = 131072 // deepseek-chat通常支持131k + } + } + return &Builder{ db: db, logger: logger, openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport}, openAIConfig: openAIConfig, + tokenCounter: agent.NewTikTokenCounter(), + maxTokens: maxTokens, } } @@ -212,6 +230,17 @@ func (b *Builder) prepareContextData(messages []database.Message, executions []* // generateChainWithRetry 生成攻击链(带重试和压缩机制) func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *ContextData, maxRetries int) (*Chain, error) { + // 在第一次尝试前,先检查tokens并压缩(如果需要) + totalTokens, err := b.countPromptTokens(contextData) + if err == nil && totalTokens > b.maxTokens { + b.logger.Info("检测到tokens超过限制,提前压缩", + zap.Int("totalTokens", totalTokens), + zap.Int("maxTokens", b.maxTokens)) + if err := b.compressContextData(ctx, contextData); err != nil { + return nil, fmt.Errorf("压缩上下文失败: %w", err) + } + } + for attempt := 0; attempt < maxRetries; attempt++ { b.logger.Info("尝试生成攻击链", zap.Int("attempt", attempt+1), @@ -232,8 +261,8 @@ func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *Conte zap.Int("attempt", attempt+1), zap.Error(err)) - // 压缩最长的子节点 - if err := b.compressLongestItem(ctx, contextData); err != nil { + // 使用分片压缩 + if err := b.compressContextData(ctx, contextData); err != nil { return nil, fmt.Errorf("压缩上下文失败: %w", err) } @@ -552,94 +581,434 @@ func (b *Builder) formatArguments(args map[string]interface{}) string { return string(jsonData) } -// compressLongestItem 压缩最长的子节点 -func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error { - var longestID string - var longestType string - var longestContent string - maxLength := 0 +// countPromptTokens 计算prompt的总tokens数 +func (b *Builder) countPromptTokens(contextData *ContextData) (int, error) { + prompt, err := b.buildChainGenerationPrompt(contextData) + if err != nil { + return 0, fmt.Errorf("构建提示词失败: %w", err) + } - // 查找最长的消息 + if b.tokenCounter == nil || b.openAIConfig == nil { + // 如果没有token计数器或配置,使用简单的估算(4个字符=1个token) + return len(prompt) / 4, nil + } + + model := b.openAIConfig.Model + if model == "" { + model = "gpt-4" // 默认模型 + } + + count, err := b.tokenCounter.Count(model, prompt) + if err != nil { + // 如果计算失败,使用估算 + return len(prompt) / 4, nil + } + return count, nil +} + +// compressContextData 使用分片压缩方式压缩上下文数据 +func (b *Builder) compressContextData(ctx context.Context, contextData *ContextData) error { + // 计算当前tokens + totalTokens, err := b.countPromptTokens(contextData) + if err != nil { + return fmt.Errorf("计算tokens失败: %w", err) + } + + b.logger.Info("开始压缩上下文", + zap.Int("totalTokens", totalTokens), + zap.Int("maxTokens", b.maxTokens)) + + // 如果tokens在限制内,不需要压缩 + if totalTokens <= b.maxTokens { + return nil + } + + // 计算需要分成多少份 + numChunks := (totalTokens + b.maxTokens - 1) / b.maxTokens // 向上取整 + if numChunks < 2 { + numChunks = 2 // 至少分成2份 + } + + b.logger.Info("将上下文分成多个片段进行压缩", + zap.Int("totalTokens", totalTokens), + zap.Int("maxTokens", b.maxTokens), + zap.Int("numChunks", numChunks)) + + // 按时间顺序将数据分成多个片段 + chunks, err := b.splitContextDataByTime(contextData, numChunks) + if err != nil { + return fmt.Errorf("分割上下文数据失败: %w", err) + } + + // 对每个片段进行摘要 + summaries := make([]string, 0, len(chunks)) + for i, chunk := range chunks { + b.logger.Info("压缩片段", + zap.Int("chunkIndex", i+1), + zap.Int("totalChunks", len(chunks)), + zap.Int("chunkSize", len(chunk.Messages)+len(chunk.Executions))) + + summary, err := b.summarizeContextChunk(ctx, chunk) + if err != nil { + // 检查是否是认证错误 + if strings.Contains(err.Error(), "Authentication") || strings.Contains(err.Error(), "api key") || strings.Contains(err.Error(), "invalid") { + return fmt.Errorf("压缩片段%d失败(API认证错误,请检查OpenAI配置): %w", i+1, err) + } + return fmt.Errorf("压缩片段%d失败: %w", i+1, err) + } + summaries = append(summaries, summary) + } + + // 将摘要合并到contextData中 + // 保留用户消息,清空其他数据,用摘要替换 + var userMessages []database.Message for _, msg := range contextData.Messages { if strings.EqualFold(msg.Role, "user") { - continue - } - if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized { - continue - } - length := len(msg.Content) - if length > maxLength { - maxLength = length - longestID = msg.ID - longestType = "message" - longestContent = msg.Content + userMessages = append(userMessages, msg) } } - // 查找最长的工具执行结果 + // 清空非用户消息和执行记录 + contextData.Messages = userMessages + contextData.Executions = []*mcp.ToolExecution{} + contextData.ProcessDetails = make(map[string][]database.ProcessDetail) + + // 创建一个综合摘要消息 + combinedSummary := strings.Join(summaries, "\n\n---\n\n") + summaryMsg := database.Message{ + ID: uuid.New().String(), + Role: "assistant", + Content: fmt.Sprintf("[上下文摘要 - 包含%d个片段的压缩内容]\n\n%s", len(summaries), combinedSummary), + CreatedAt: time.Now(), + } + contextData.Messages = append(contextData.Messages, summaryMsg) + + // 检查压缩后的tokens + compressedTokens, err := b.countPromptTokens(contextData) + if err != nil { + return fmt.Errorf("计算压缩后tokens失败: %w", err) + } + + b.logger.Info("压缩完成", + zap.Int("originalTokens", totalTokens), + zap.Int("compressedTokens", compressedTokens), + zap.Int("reduction", totalTokens-compressedTokens)) + + // 如果压缩后仍然超过限制,递归压缩 + if compressedTokens > b.maxTokens { + b.logger.Info("压缩后仍然超过限制,继续递归压缩", + zap.Int("compressedTokens", compressedTokens), + zap.Int("maxTokens", b.maxTokens)) + return b.compressContextData(ctx, contextData) + } + + return nil +} + +// ContextChunk 上下文数据片段 +type ContextChunk struct { + Messages []database.Message + Executions []*mcp.ToolExecution + ProcessDetails map[string][]database.ProcessDetail +} + +// splitContextDataByTime 按时间顺序将上下文数据分成多个片段 +func (b *Builder) splitContextDataByTime(contextData *ContextData, numChunks int) ([]*ContextChunk, error) { + if numChunks <= 0 { + return nil, fmt.Errorf("片段数量必须大于0") + } + + // 收集所有带时间戳的项目 + type timeItem struct { + time time.Time + itemType string // "message", "execution", "thinking" + message *database.Message + execution *mcp.ToolExecution + processDetail *database.ProcessDetail + } + + var items []timeItem + + // 添加消息(跳过已总结的) + for i := range contextData.Messages { + msg := &contextData.Messages[i] + if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized { + continue + } + items = append(items, timeItem{ + time: msg.CreatedAt, + itemType: "message", + message: msg, + }) + } + + // 添加工具执行(跳过已总结的) for _, exec := range contextData.Executions { if _, alreadySummarized := contextData.SummarizedItems[exec.ID]; alreadySummarized { continue } - if exec.Result != nil { - var resultText string - for _, content := range exec.Result.Content { - if content.Type == "text" { - resultText += content.Text + "\n" - } - } - length := len(resultText) - if length > maxLength { - maxLength = length - longestID = exec.ID - longestType = "execution" - longestContent = resultText - } - } + items = append(items, timeItem{ + time: exec.StartTime, + itemType: "execution", + execution: exec, + }) } - // 查找最长的思考过程 + // 添加思考过程(跳过已总结的) for _, details := range contextData.ProcessDetails { - for _, detail := range details { + for i := range details { + detail := &details[i] if detail.EventType == "thinking" { if _, alreadySummarized := contextData.SummarizedItems[detail.ID]; alreadySummarized { continue } - length := len(detail.Message) - if length > maxLength { - maxLength = length - longestID = detail.ID - longestType = "thinking" - longestContent = detail.Message - } + items = append(items, timeItem{ + time: detail.CreatedAt, + itemType: "thinking", + processDetail: detail, + }) } } } - if longestID == "" { - return fmt.Errorf("没有找到需要压缩的内容") + if len(items) == 0 { + return nil, fmt.Errorf("没有可分割的数据") } - b.logger.Info("压缩最长子节点", - zap.String("id", longestID), - zap.String("type", longestType), - zap.Int("length", maxLength)) + // 按时间排序 + sort.Slice(items, func(i, j int) bool { + return items[i].time.Before(items[j].time) + }) + + // 计算每个片段的大小 + chunkSize := (len(items) + numChunks - 1) / numChunks // 向上取整 + + // 创建片段 + chunks := make([]*ContextChunk, 0, numChunks) + for i := 0; i < len(items); i += chunkSize { + end := i + chunkSize + if end > len(items) { + end = len(items) + } + + chunk := &ContextChunk{ + Messages: []database.Message{}, + Executions: []*mcp.ToolExecution{}, + ProcessDetails: make(map[string][]database.ProcessDetail), + } + + for j := i; j < end; j++ { + item := items[j] + switch item.itemType { + case "message": + chunk.Messages = append(chunk.Messages, *item.message) + case "execution": + chunk.Executions = append(chunk.Executions, item.execution) + case "thinking": + if item.processDetail != nil { + msgID := item.processDetail.MessageID + chunk.ProcessDetails[msgID] = append(chunk.ProcessDetails[msgID], *item.processDetail) + } + } + } + + chunks = append(chunks, chunk) + } + + return chunks, nil +} + +// getModelMaxContextLength 获取模型的最大上下文长度 +func (b *Builder) getModelMaxContextLength() int { + if b.openAIConfig == nil { + return 131072 // 默认值 + } + model := strings.ToLower(b.openAIConfig.Model) + if strings.Contains(model, "gpt-4") { + return 128000 + } else if strings.Contains(model, "gpt-3.5") { + return 16000 + } else if strings.Contains(model, "deepseek") { + return 131072 + } + return 131072 // 默认值 +} + +// summarizeContextChunk 总结一个上下文片段 +func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk) (string, error) { + // 先构建内容 + content, err := b.buildChunkContent(chunk) + if err != nil { + return "", err + } // 使用AI总结 - summary, err := b.summarizeContent(ctx, longestType, longestContent) + promptTemplate := `请详细总结以下安全测试对话片段的关键信息。虽然需要压缩内容,但必须保留所有重要的技术细节和上下文信息,确保后续攻击链生成时能够准确理解整个测试过程。 + +**必须详细保留的内容:** +1. **所有工具执行记录**: + - 工具名称、执行参数、执行结果(包括成功和失败) + - 失败执行的错误信息、状态码、响应头等关键线索 + - 工具输出的关键数据(端口、服务版本、漏洞信息等) + - 每个工具执行的时间顺序和上下文关系 + +2. **所有发现的漏洞和潜在安全问题**: + - 漏洞类型、严重程度、位置、利用方式 + - 验证过程和结果 + - 漏洞之间的关联关系 + +3. **所有测试目标和资产信息**: + - IP地址、域名、URL、端口等 + - 发现的服务、技术栈、版本信息 + - 资产之间的关联关系 + +4. **所有测试步骤和决策过程**: + - 每个测试步骤的详细描述(做了什么、为什么做、结果如何) + - AI的分析思路和决策依据 + - 失败尝试的原因和从中获得的线索 + +5. **所有关键发现和线索**: + - 成功发现的详细信息 + - 失败但提供线索的尝试(错误信息、限制条件、下一步建议等) + - 收集到的任何有价值的信息(凭据、令牌、配置信息等) + +**总结要求:** +- 用结构化的方式组织信息,按时间顺序或逻辑顺序排列 +- 对于每个工具执行,必须包含:工具名、目标、参数、结果/错误、关键发现 +- 对于每个漏洞,必须包含:类型、位置、严重程度、验证结果 +- 保留所有技术细节,不要过度简化 +- 确保后续AI能够根据这个摘要完整重建攻击链 + +对话片段: +%s + +请给出详细且结构化的技术摘要(建议1000-2000字,确保信息完整):` + + // 检查prompt tokens,如果超过限制,需要进一步压缩内容 + maxContextLength := b.getModelMaxContextLength() + maxPromptTokens := maxContextLength - 2000 // 留出空间给响应和系统消息 + + // 尝试构建完整prompt并检查tokens + fullPrompt := fmt.Sprintf(promptTemplate, content) + promptTokens, err := b.countTextTokens(fullPrompt) if err != nil { - return fmt.Errorf("总结内容失败: %w", err) + // 如果计算失败,使用估算 + promptTokens = len(fullPrompt) / 4 } - // 保存总结 - contextData.SummarizedItems[longestID] = summary + // 如果prompt太大,需要进一步压缩内容 + if promptTokens > maxPromptTokens { + b.logger.Warn("片段内容过大,需要进一步压缩", + zap.Int("promptTokens", promptTokens), + zap.Int("maxPromptTokens", maxPromptTokens)) - b.logger.Info("压缩完成", - zap.String("id", longestID), - zap.Int("originalLength", maxLength), - zap.Int("summaryLength", len(summary))) + // 递归压缩:将chunk进一步分割 + compressedContent, err := b.compressLargeChunk(ctx, chunk, maxPromptTokens) + if err != nil { + return "", fmt.Errorf("压缩大片段失败: %w", err) + } + content = compressedContent + } - return nil + prompt := fmt.Sprintf(promptTemplate, content) + + // 检查配置 + if b.openAIConfig == nil { + return "", fmt.Errorf("OpenAI配置未初始化") + } + if b.openAIConfig.APIKey == "" { + return "", fmt.Errorf("OpenAI API Key未配置") + } + if b.openAIConfig.Model == "" { + return "", fmt.Errorf("OpenAI Model未配置") + } + + // 直接调用AI API进行总结 + requestBody := map[string]interface{}{ + "model": b.openAIConfig.Model, + "messages": []map[string]interface{}{ + { + "role": "system", + "content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试对话片段,这些摘要将用于后续构建完整的攻击链图。 + +**你的专业背景:** +- 精通各种安全测试工具(Nmap、SQLMap、Burp Suite、Metasploit等)的使用和结果分析 +- 熟悉常见漏洞类型(SQL注入、XSS、文件上传、命令执行、目录遍历等)的识别和验证 +- 理解攻击链的构建逻辑:从信息收集 → 漏洞发现 → 漏洞利用 → 权限提升 → 横向移动 +- 能够识别失败尝试中的有价值线索(错误信息、状态码、WAF指纹、技术栈信息等) + +**你的总结原则:** +1. **完整性优先**:虽然需要压缩,但必须保留所有技术细节,确保后续AI能够完整重建攻击链 +2. **结构化组织**:按时间顺序或逻辑顺序组织信息,让信息易于理解和追踪 +3. **技术精准**:使用准确的技术术语,保留具体的数值、版本号、端口号、URL等关键数据 +4. **上下文关联**:保留测试步骤之间的因果关系和逻辑关联 +5. **失败价值**:即使是失败的尝试,只要提供了线索(错误信息、限制条件、下一步建议),也要详细记录 + +**你需要特别关注的信息类型:** +- 工具执行:工具名、目标、参数、完整结果(包括错误和失败) +- 漏洞发现:类型、位置、严重程度、验证方法、利用结果 +- 资产信息:IP、域名、端口、服务版本、技术栈 +- 测试策略:为什么选择这个工具、为什么测试这个目标、发现了什么线索 +- 关键数据:凭据、令牌、配置信息、敏感文件内容 + +请用专业、详细、结构化的中文进行总结,确保信息完整且易于后续处理。`, + }, + { + "role": "user", + "content": prompt, + }, + }, + "temperature": 0.3, + "max_tokens": 4000, // 增加摘要长度,以容纳更详细的内容 + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return "", fmt.Errorf("序列化请求失败: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("创建请求失败: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey) + + resp, err := b.openAIClient.Do(req) + if err != nil { + return "", fmt.Errorf("请求失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body)) + } + + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + + if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { + return "", fmt.Errorf("解析响应失败: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return "", fmt.Errorf("API未返回有效响应") + } + + return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil +} + +// compressLongestItem 压缩最长的子节点(保留作为备用方法) +func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error { + // 使用新的分片压缩方法 + return b.compressContextData(ctx, contextData) } // summarizeContent 总结内容 @@ -675,8 +1044,28 @@ AI回复: "model": b.openAIConfig.Model, "messages": []map[string]interface{}{ { - "role": "system", - "content": "你是一个专业的安全测试分析师,擅长总结安全测试相关的信息。请用简洁的中文总结关键信息。", + "role": "system", + "content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试过程中的关键信息,这些摘要将用于构建攻击链图。 + +**你的专业背景:** +- 精通各种安全测试工具的使用和结果分析(Nmap、SQLMap、Burp Suite、Metasploit、Nuclei等) +- 熟悉常见漏洞类型的识别和验证(SQL注入、XSS、文件上传、命令执行、目录遍历、SSRF等) +- 理解攻击链的构建逻辑和测试流程 +- 能够识别失败尝试中的有价值线索 + +**你的总结原则:** +1. **保留技术细节**:保留所有重要的技术信息,包括工具名、参数、结果、错误信息、状态码等 +2. **突出关键发现**:重点记录发现的漏洞、安全问题、资产信息、凭据等 +3. **记录失败线索**:即使是失败的尝试,如果提供了错误信息、限制条件或下一步建议,也要详细记录 +4. **保持准确性**:使用准确的技术术语,保留具体的数值、版本号、端口号等关键数据 +5. **结构化表达**:用清晰、有条理的方式组织信息 + +**根据内容类型,你需要特别关注:** +- **AI回复**:提取安全发现、漏洞信息、测试结果、分析思路、决策依据 +- **工具执行**:记录工具名、目标、参数、完整结果(成功或失败)、关键发现、错误信息 +- **思考过程**:提取关键决策点、测试策略、分析思路、下一步计划 + +请用专业、准确、简洁的中文进行总结,确保信息完整且易于理解。`, }, { "role": "user", @@ -730,6 +1119,146 @@ AI回复: return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil } +// buildChunkContent 构建chunk的文本内容 +func (b *Builder) buildChunkContent(chunk *ContextChunk) (string, error) { + var contentBuilder strings.Builder + + // 添加消息 + for _, msg := range chunk.Messages { + if strings.EqualFold(msg.Role, "user") { + contentBuilder.WriteString(fmt.Sprintf("用户消息: %s\n\n", msg.Content)) + } else { + contentBuilder.WriteString(fmt.Sprintf("AI回复: %s\n\n", msg.Content)) + } + } + + // 添加工具执行 + for _, exec := range chunk.Executions { + contentBuilder.WriteString(fmt.Sprintf("工具执行 [%s] (ID: %s):\n", exec.ToolName, exec.ID)) + contentBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments))) + + if exec.Error != "" { + contentBuilder.WriteString(fmt.Sprintf("错误: %s\n", exec.Error)) + } + + if exec.Result != nil { + var resultText string + for _, content := range exec.Result.Content { + if content.Type == "text" { + resultText += content.Text + "\n" + } + } + if resultText != "" { + // 如果结果太长,截断 + if len(resultText) > 10000 { + resultText = resultText[:10000] + "\n... [内容已截断]" + } + contentBuilder.WriteString(fmt.Sprintf("结果: %s\n", resultText)) + } + } + contentBuilder.WriteString("\n") + } + + // 添加思考过程 + for _, details := range chunk.ProcessDetails { + for _, detail := range details { + if detail.EventType == "thinking" { + thinkingText := detail.Message + // 如果思考过程太长,截断 + if len(thinkingText) > 5000 { + thinkingText = thinkingText[:5000] + "\n... [内容已截断]" + } + contentBuilder.WriteString(fmt.Sprintf("思考过程: %s\n\n", thinkingText)) + } + } + } + + content := contentBuilder.String() + if content == "" { + return "", fmt.Errorf("片段内容为空") + } + return content, nil +} + +// compressLargeChunk 压缩过大的chunk(递归分割) +func (b *Builder) compressLargeChunk(ctx context.Context, chunk *ContextChunk, maxTokens int) (string, error) { + // 将chunk进一步分割成更小的子chunk + // 简单策略:按消息和执行数量平均分割 + totalItems := len(chunk.Messages) + len(chunk.Executions) + if totalItems <= 1 { + // 如果只有一个项目,直接截断内容 + content, _ := b.buildChunkContent(chunk) + if len(content) > maxTokens*4 { // 粗略估算:1 token ≈ 4字符 + content = content[:maxTokens*4] + "\n... [内容过大,已截断]" + } + return content, nil + } + + // 分成2个子chunk + mid := totalItems / 2 + subChunk1 := &ContextChunk{ + Messages: []database.Message{}, + Executions: []*mcp.ToolExecution{}, + ProcessDetails: make(map[string][]database.ProcessDetail), + } + subChunk2 := &ContextChunk{ + Messages: []database.Message{}, + Executions: []*mcp.ToolExecution{}, + ProcessDetails: make(map[string][]database.ProcessDetail), + } + + // 分配消息 + for i, msg := range chunk.Messages { + if i < mid { + subChunk1.Messages = append(subChunk1.Messages, msg) + } else { + subChunk2.Messages = append(subChunk2.Messages, msg) + } + } + + // 分配执行 + execStart := len(chunk.Messages) + for i, exec := range chunk.Executions { + if execStart+i < mid { + subChunk1.Executions = append(subChunk1.Executions, exec) + } else { + subChunk2.Executions = append(subChunk2.Executions, exec) + } + } + + // 递归压缩子chunk + summary1, err := b.summarizeContextChunk(ctx, subChunk1) + if err != nil { + return "", fmt.Errorf("压缩子chunk1失败: %w", err) + } + + summary2, err := b.summarizeContextChunk(ctx, subChunk2) + if err != nil { + return "", fmt.Errorf("压缩子chunk2失败: %w", err) + } + + // 合并摘要 + return fmt.Sprintf("片段1摘要:\n%s\n\n---\n\n片段2摘要:\n%s", summary1, summary2), nil +} + +// countTextTokens 计算文本的tokens数 +func (b *Builder) countTextTokens(text string) (int, error) { + if b.tokenCounter == nil || b.openAIConfig == nil { + return len(text) / 4, nil + } + + model := b.openAIConfig.Model + if model == "" { + model = "gpt-4" + } + + count, err := b.tokenCounter.Count(model, text) + if err != nil { + return len(text) / 4, nil + } + return count, nil +} + // callAIForChainGeneration 调用AI生成攻击链 func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) { requestBody := map[string]interface{}{ @@ -985,74 +1514,143 @@ func (b *Builder) shouldFilterNode(n struct { return true } - // 对于action节点,检查对应的工具执行是否有效 - if n.Type == "action" { - if n.ToolExecutionID == "" { - // 没有关联工具执行的action节点,可能是无效的 - return true - } - - // 查找对应的工具执行 - var exec *mcp.ToolExecution - for _, e := range executions { - if e.ID == n.ToolExecutionID { - exec = e - break - } - } - - if exec == nil { - // 找不到对应的工具执行,可能是无效的 - return true - } - - // 检查工具执行是否错误或失败 - if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) { - if !hasInsightfulFailure(n.Metadata) { - return true - } - } - - // 检查工具执行结果是否为空 - if exec.Result == nil || len(exec.Result.Content) == 0 { - if !hasInsightfulFailure(n.Metadata) { - return true - } - } - - // 检查结果文本是否为空 - var resultText string - if exec.Result != nil { - for _, content := range exec.Result.Content { - if content.Type == "text" { - resultText += content.Text - } - } - } - if strings.TrimSpace(resultText) == "" { - if !hasInsightfulFailure(n.Metadata) { - return true - } - } - } - // 检查节点标签是否为空或无效 if strings.TrimSpace(n.Label) == "" { return true } - // 检查标签中是否包含错误/失败的关键词 - labelLower := strings.ToLower(n.Label) - errorKeywords := []string{"错误", "失败", "无效", "error", "failed", "invalid", "empty", "空"} - for _, keyword := range errorKeywords { - if strings.Contains(labelLower, keyword) { - // 如果标签明确表示错误,但节点类型不是vulnerability,则过滤 - if n.Type != "vulnerability" { - return true + // 对于vulnerability节点,即使没有tool_execution_id也应该保留(漏洞可能不是直接来自工具执行) + if n.Type == "vulnerability" { + // 只要标签有意义就保留 + return false + } + + // 对于target节点,只要标签有意义就保留 + if n.Type == "target" { + return false + } + + // 对于action节点,进行更宽松的检查 + if n.Type == "action" { + // 如果executions为空(可能是压缩后的场景),只要标签有意义就保留 + if len(executions) == 0 { + // 压缩场景下,只要标签不是明显无效就保留 + labelLower := strings.ToLower(n.Label) + // 只过滤明显无效的标签 + invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"} + for _, keyword := range invalidKeywords { + if strings.Contains(labelLower, keyword) { + return true + } } + return false + } + + // 如果有tool_execution_id,尝试查找对应的工具执行 + if n.ToolExecutionID != "" { + var exec *mcp.ToolExecution + for _, e := range executions { + if e.ID == n.ToolExecutionID { + exec = e + break + } + } + + if exec != nil { + // 找到了对应的工具执行,检查是否有效 + // 检查工具执行是否错误或失败 + if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) { + // 失败但有线索的应该保留 + if !hasInsightfulFailure(n.Metadata) { + // 即使没有明确标记为有线索,如果标签描述了具体内容,也保留 + labelLower := strings.ToLower(n.Label) + // 如果标签包含具体的技术信息(端口、服务、漏洞等),说明有价值 + valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"} + hasValuableInfo := false + for _, keyword := range valuableKeywords { + if strings.Contains(labelLower, keyword) { + hasValuableInfo = true + break + } + } + if !hasValuableInfo { + return true + } + } + } + + // 检查工具执行结果是否为空 + if exec.Result == nil || len(exec.Result.Content) == 0 { + // 结果为空,但如果有线索或标签有意义,也保留 + if !hasInsightfulFailure(n.Metadata) { + labelLower := strings.ToLower(n.Label) + valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"} + hasValuableInfo := false + for _, keyword := range valuableKeywords { + if strings.Contains(labelLower, keyword) { + hasValuableInfo = true + break + } + } + if !hasValuableInfo { + return true + } + } + } else { + // 检查结果文本是否为空 + var resultText string + for _, content := range exec.Result.Content { + if content.Type == "text" { + resultText += content.Text + } + } + if strings.TrimSpace(resultText) == "" { + // 结果文本为空,但如果有线索或标签有意义,也保留 + if !hasInsightfulFailure(n.Metadata) { + labelLower := strings.ToLower(n.Label) + valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"} + hasValuableInfo := false + for _, keyword := range valuableKeywords { + if strings.Contains(labelLower, keyword) { + hasValuableInfo = true + break + } + } + if !hasValuableInfo { + return true + } + } + } + } + } else { + // 找不到对应的工具执行,但可能是压缩后的场景 + // 只要标签有意义就保留(不要因为找不到execution就过滤掉) + labelLower := strings.ToLower(n.Label) + invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"} + for _, keyword := range invalidKeywords { + if strings.Contains(labelLower, keyword) { + return true + } + } + // 标签有意义,保留 + return false + } + } else { + // 没有tool_execution_id,但可能是压缩后的场景或AI生成的节点 + // 只要标签有意义就保留 + labelLower := strings.ToLower(n.Label) + invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"} + for _, keyword := range invalidKeywords { + if strings.Contains(labelLower, keyword) { + return true + } + } + // 标签有意义,保留 + return false } } + // 默认保留(已经通过了所有检查) return false }