Add files via upload

This commit is contained in:
公明
2025-11-22 03:01:50 +08:00
committed by GitHub
parent af0e4b5ccb
commit bc5b368ece

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
@@ -25,6 +26,8 @@ type Builder struct {
logger *zap.Logger
openAIClient *http.Client
openAIConfig *config.OpenAIConfig
tokenCounter agent.TokenCounter
maxTokens int // 最大tokens限制默认100000
}
// Node 攻击链节点使用database包的类型
@@ -47,11 +50,26 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
IdleConnTimeout: 90 * time.Second,
}
maxTokens := 100000 // 默认100k tokens可以根据模型调整
// 根据模型设置合理的默认值
if openAIConfig != nil {
model := strings.ToLower(openAIConfig.Model)
if strings.Contains(model, "gpt-4") {
maxTokens = 128000 // gpt-4通常支持128k
} else if strings.Contains(model, "gpt-3.5") {
maxTokens = 16000 // gpt-3.5-turbo通常支持16k
} else if strings.Contains(model, "deepseek") {
maxTokens = 131072 // deepseek-chat通常支持131k
}
}
return &Builder{
db: db,
logger: logger,
openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport},
openAIConfig: openAIConfig,
tokenCounter: agent.NewTikTokenCounter(),
maxTokens: maxTokens,
}
}
@@ -212,6 +230,17 @@ func (b *Builder) prepareContextData(messages []database.Message, executions []*
// generateChainWithRetry 生成攻击链(带重试和压缩机制)
func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *ContextData, maxRetries int) (*Chain, error) {
// 在第一次尝试前先检查tokens并压缩如果需要
totalTokens, err := b.countPromptTokens(contextData)
if err == nil && totalTokens > b.maxTokens {
b.logger.Info("检测到tokens超过限制提前压缩",
zap.Int("totalTokens", totalTokens),
zap.Int("maxTokens", b.maxTokens))
if err := b.compressContextData(ctx, contextData); err != nil {
return nil, fmt.Errorf("压缩上下文失败: %w", err)
}
}
for attempt := 0; attempt < maxRetries; attempt++ {
b.logger.Info("尝试生成攻击链",
zap.Int("attempt", attempt+1),
@@ -232,8 +261,8 @@ func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *Conte
zap.Int("attempt", attempt+1),
zap.Error(err))
// 压缩最长的子节点
if err := b.compressLongestItem(ctx, contextData); err != nil {
// 使用分片压缩
if err := b.compressContextData(ctx, contextData); err != nil {
return nil, fmt.Errorf("压缩上下文失败: %w", err)
}
@@ -552,94 +581,434 @@ func (b *Builder) formatArguments(args map[string]interface{}) string {
return string(jsonData)
}
// compressLongestItem 压缩最长的子节点
func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error {
var longestID string
var longestType string
var longestContent string
maxLength := 0
// countPromptTokens 计算prompt的总tokens数
func (b *Builder) countPromptTokens(contextData *ContextData) (int, error) {
prompt, err := b.buildChainGenerationPrompt(contextData)
if err != nil {
return 0, fmt.Errorf("构建提示词失败: %w", err)
}
// 查找最长的消息
if b.tokenCounter == nil || b.openAIConfig == nil {
// 如果没有token计数器或配置使用简单的估算4个字符=1个token
return len(prompt) / 4, nil
}
model := b.openAIConfig.Model
if model == "" {
model = "gpt-4" // 默认模型
}
count, err := b.tokenCounter.Count(model, prompt)
if err != nil {
// 如果计算失败,使用估算
return len(prompt) / 4, nil
}
return count, nil
}
// compressContextData 使用分片压缩方式压缩上下文数据
func (b *Builder) compressContextData(ctx context.Context, contextData *ContextData) error {
// 计算当前tokens
totalTokens, err := b.countPromptTokens(contextData)
if err != nil {
return fmt.Errorf("计算tokens失败: %w", err)
}
b.logger.Info("开始压缩上下文",
zap.Int("totalTokens", totalTokens),
zap.Int("maxTokens", b.maxTokens))
// 如果tokens在限制内不需要压缩
if totalTokens <= b.maxTokens {
return nil
}
// 计算需要分成多少份
numChunks := (totalTokens + b.maxTokens - 1) / b.maxTokens // 向上取整
if numChunks < 2 {
numChunks = 2 // 至少分成2份
}
b.logger.Info("将上下文分成多个片段进行压缩",
zap.Int("totalTokens", totalTokens),
zap.Int("maxTokens", b.maxTokens),
zap.Int("numChunks", numChunks))
// 按时间顺序将数据分成多个片段
chunks, err := b.splitContextDataByTime(contextData, numChunks)
if err != nil {
return fmt.Errorf("分割上下文数据失败: %w", err)
}
// 对每个片段进行摘要
summaries := make([]string, 0, len(chunks))
for i, chunk := range chunks {
b.logger.Info("压缩片段",
zap.Int("chunkIndex", i+1),
zap.Int("totalChunks", len(chunks)),
zap.Int("chunkSize", len(chunk.Messages)+len(chunk.Executions)))
summary, err := b.summarizeContextChunk(ctx, chunk)
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "Authentication") || strings.Contains(err.Error(), "api key") || strings.Contains(err.Error(), "invalid") {
return fmt.Errorf("压缩片段%d失败API认证错误请检查OpenAI配置: %w", i+1, err)
}
return fmt.Errorf("压缩片段%d失败: %w", i+1, err)
}
summaries = append(summaries, summary)
}
// 将摘要合并到contextData中
// 保留用户消息,清空其他数据,用摘要替换
var userMessages []database.Message
for _, msg := range contextData.Messages {
if strings.EqualFold(msg.Role, "user") {
continue
}
if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized {
continue
}
length := len(msg.Content)
if length > maxLength {
maxLength = length
longestID = msg.ID
longestType = "message"
longestContent = msg.Content
userMessages = append(userMessages, msg)
}
}
// 查找最长的工具执行结果
// 清空非用户消息和执行记录
contextData.Messages = userMessages
contextData.Executions = []*mcp.ToolExecution{}
contextData.ProcessDetails = make(map[string][]database.ProcessDetail)
// 创建一个综合摘要消息
combinedSummary := strings.Join(summaries, "\n\n---\n\n")
summaryMsg := database.Message{
ID: uuid.New().String(),
Role: "assistant",
Content: fmt.Sprintf("[上下文摘要 - 包含%d个片段的压缩内容]\n\n%s", len(summaries), combinedSummary),
CreatedAt: time.Now(),
}
contextData.Messages = append(contextData.Messages, summaryMsg)
// 检查压缩后的tokens
compressedTokens, err := b.countPromptTokens(contextData)
if err != nil {
return fmt.Errorf("计算压缩后tokens失败: %w", err)
}
b.logger.Info("压缩完成",
zap.Int("originalTokens", totalTokens),
zap.Int("compressedTokens", compressedTokens),
zap.Int("reduction", totalTokens-compressedTokens))
// 如果压缩后仍然超过限制,递归压缩
if compressedTokens > b.maxTokens {
b.logger.Info("压缩后仍然超过限制,继续递归压缩",
zap.Int("compressedTokens", compressedTokens),
zap.Int("maxTokens", b.maxTokens))
return b.compressContextData(ctx, contextData)
}
return nil
}
// ContextChunk 上下文数据片段
type ContextChunk struct {
Messages []database.Message
Executions []*mcp.ToolExecution
ProcessDetails map[string][]database.ProcessDetail
}
// splitContextDataByTime 按时间顺序将上下文数据分成多个片段
func (b *Builder) splitContextDataByTime(contextData *ContextData, numChunks int) ([]*ContextChunk, error) {
if numChunks <= 0 {
return nil, fmt.Errorf("片段数量必须大于0")
}
// 收集所有带时间戳的项目
type timeItem struct {
time time.Time
itemType string // "message", "execution", "thinking"
message *database.Message
execution *mcp.ToolExecution
processDetail *database.ProcessDetail
}
var items []timeItem
// 添加消息(跳过已总结的)
for i := range contextData.Messages {
msg := &contextData.Messages[i]
if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized {
continue
}
items = append(items, timeItem{
time: msg.CreatedAt,
itemType: "message",
message: msg,
})
}
// 添加工具执行(跳过已总结的)
for _, exec := range contextData.Executions {
if _, alreadySummarized := contextData.SummarizedItems[exec.ID]; alreadySummarized {
continue
}
if exec.Result != nil {
var resultText string
for _, content := range exec.Result.Content {
if content.Type == "text" {
resultText += content.Text + "\n"
}
}
length := len(resultText)
if length > maxLength {
maxLength = length
longestID = exec.ID
longestType = "execution"
longestContent = resultText
}
}
items = append(items, timeItem{
time: exec.StartTime,
itemType: "execution",
execution: exec,
})
}
// 查找最长的思考过程
// 添加思考过程(跳过已总结的)
for _, details := range contextData.ProcessDetails {
for _, detail := range details {
for i := range details {
detail := &details[i]
if detail.EventType == "thinking" {
if _, alreadySummarized := contextData.SummarizedItems[detail.ID]; alreadySummarized {
continue
}
length := len(detail.Message)
if length > maxLength {
maxLength = length
longestID = detail.ID
longestType = "thinking"
longestContent = detail.Message
}
items = append(items, timeItem{
time: detail.CreatedAt,
itemType: "thinking",
processDetail: detail,
})
}
}
}
if longestID == "" {
return fmt.Errorf("没有找到需要压缩的内容")
if len(items) == 0 {
return nil, fmt.Errorf("没有可分割的数据")
}
b.logger.Info("压缩最长子节点",
zap.String("id", longestID),
zap.String("type", longestType),
zap.Int("length", maxLength))
// 按时间排序
sort.Slice(items, func(i, j int) bool {
return items[i].time.Before(items[j].time)
})
// 计算每个片段的大小
chunkSize := (len(items) + numChunks - 1) / numChunks // 向上取整
// 创建片段
chunks := make([]*ContextChunk, 0, numChunks)
for i := 0; i < len(items); i += chunkSize {
end := i + chunkSize
if end > len(items) {
end = len(items)
}
chunk := &ContextChunk{
Messages: []database.Message{},
Executions: []*mcp.ToolExecution{},
ProcessDetails: make(map[string][]database.ProcessDetail),
}
for j := i; j < end; j++ {
item := items[j]
switch item.itemType {
case "message":
chunk.Messages = append(chunk.Messages, *item.message)
case "execution":
chunk.Executions = append(chunk.Executions, item.execution)
case "thinking":
if item.processDetail != nil {
msgID := item.processDetail.MessageID
chunk.ProcessDetails[msgID] = append(chunk.ProcessDetails[msgID], *item.processDetail)
}
}
}
chunks = append(chunks, chunk)
}
return chunks, nil
}
// getModelMaxContextLength 获取模型的最大上下文长度
func (b *Builder) getModelMaxContextLength() int {
if b.openAIConfig == nil {
return 131072 // 默认值
}
model := strings.ToLower(b.openAIConfig.Model)
if strings.Contains(model, "gpt-4") {
return 128000
} else if strings.Contains(model, "gpt-3.5") {
return 16000
} else if strings.Contains(model, "deepseek") {
return 131072
}
return 131072 // 默认值
}
// summarizeContextChunk 总结一个上下文片段
func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk) (string, error) {
// 先构建内容
content, err := b.buildChunkContent(chunk)
if err != nil {
return "", err
}
// 使用AI总结
summary, err := b.summarizeContent(ctx, longestType, longestContent)
promptTemplate := `请详细总结以下安全测试对话片段的关键信息。虽然需要压缩内容,但必须保留所有重要的技术细节和上下文信息,确保后续攻击链生成时能够准确理解整个测试过程。
**必须详细保留的内容:**
1. **所有工具执行记录**
- 工具名称、执行参数、执行结果(包括成功和失败)
- 失败执行的错误信息、状态码、响应头等关键线索
- 工具输出的关键数据(端口、服务版本、漏洞信息等)
- 每个工具执行的时间顺序和上下文关系
2. **所有发现的漏洞和潜在安全问题**
- 漏洞类型、严重程度、位置、利用方式
- 验证过程和结果
- 漏洞之间的关联关系
3. **所有测试目标和资产信息**
- IP地址、域名、URL、端口等
- 发现的服务、技术栈、版本信息
- 资产之间的关联关系
4. **所有测试步骤和决策过程**
- 每个测试步骤的详细描述(做了什么、为什么做、结果如何)
- AI的分析思路和决策依据
- 失败尝试的原因和从中获得的线索
5. **所有关键发现和线索**
- 成功发现的详细信息
- 失败但提供线索的尝试(错误信息、限制条件、下一步建议等)
- 收集到的任何有价值的信息(凭据、令牌、配置信息等)
**总结要求:**
- 用结构化的方式组织信息,按时间顺序或逻辑顺序排列
- 对于每个工具执行,必须包含:工具名、目标、参数、结果/错误、关键发现
- 对于每个漏洞,必须包含:类型、位置、严重程度、验证结果
- 保留所有技术细节,不要过度简化
- 确保后续AI能够根据这个摘要完整重建攻击链
对话片段:
%s
请给出详细且结构化的技术摘要建议1000-2000字确保信息完整`
// 检查prompt tokens如果超过限制需要进一步压缩内容
maxContextLength := b.getModelMaxContextLength()
maxPromptTokens := maxContextLength - 2000 // 留出空间给响应和系统消息
// 尝试构建完整prompt并检查tokens
fullPrompt := fmt.Sprintf(promptTemplate, content)
promptTokens, err := b.countTextTokens(fullPrompt)
if err != nil {
return fmt.Errorf("总结内容失败: %w", err)
// 如果计算失败,使用估算
promptTokens = len(fullPrompt) / 4
}
// 保存总结
contextData.SummarizedItems[longestID] = summary
// 如果prompt太大需要进一步压缩内容
if promptTokens > maxPromptTokens {
b.logger.Warn("片段内容过大,需要进一步压缩",
zap.Int("promptTokens", promptTokens),
zap.Int("maxPromptTokens", maxPromptTokens))
b.logger.Info("压缩完成",
zap.String("id", longestID),
zap.Int("originalLength", maxLength),
zap.Int("summaryLength", len(summary)))
// 递归压缩将chunk进一步分割
compressedContent, err := b.compressLargeChunk(ctx, chunk, maxPromptTokens)
if err != nil {
return "", fmt.Errorf("压缩大片段失败: %w", err)
}
content = compressedContent
}
return nil
prompt := fmt.Sprintf(promptTemplate, content)
// 检查配置
if b.openAIConfig == nil {
return "", fmt.Errorf("OpenAI配置未初始化")
}
if b.openAIConfig.APIKey == "" {
return "", fmt.Errorf("OpenAI API Key未配置")
}
if b.openAIConfig.Model == "" {
return "", fmt.Errorf("OpenAI Model未配置")
}
// 直接调用AI API进行总结
requestBody := map[string]interface{}{
"model": b.openAIConfig.Model,
"messages": []map[string]interface{}{
{
"role": "system",
"content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试对话片段,这些摘要将用于后续构建完整的攻击链图。
**你的专业背景:**
- 精通各种安全测试工具Nmap、SQLMap、Burp Suite、Metasploit等的使用和结果分析
- 熟悉常见漏洞类型SQL注入、XSS、文件上传、命令执行、目录遍历等的识别和验证
- 理解攻击链的构建逻辑:从信息收集 → 漏洞发现 → 漏洞利用 → 权限提升 → 横向移动
- 能够识别失败尝试中的有价值线索错误信息、状态码、WAF指纹、技术栈信息等
**你的总结原则:**
1. **完整性优先**虽然需要压缩但必须保留所有技术细节确保后续AI能够完整重建攻击链
2. **结构化组织**:按时间顺序或逻辑顺序组织信息,让信息易于理解和追踪
3. **技术精准**使用准确的技术术语保留具体的数值、版本号、端口号、URL等关键数据
4. **上下文关联**:保留测试步骤之间的因果关系和逻辑关联
5. **失败价值**:即使是失败的尝试,只要提供了线索(错误信息、限制条件、下一步建议),也要详细记录
**你需要特别关注的信息类型:**
- 工具执行:工具名、目标、参数、完整结果(包括错误和失败)
- 漏洞发现:类型、位置、严重程度、验证方法、利用结果
- 资产信息IP、域名、端口、服务版本、技术栈
- 测试策略:为什么选择这个工具、为什么测试这个目标、发现了什么线索
- 关键数据:凭据、令牌、配置信息、敏感文件内容
请用专业、详细、结构化的中文进行总结,确保信息完整且易于后续处理。`,
},
{
"role": "user",
"content": prompt,
},
},
"temperature": 0.3,
"max_tokens": 4000, // 增加摘要长度,以容纳更详细的内容
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
resp, err := b.openAIClient.Do(req)
if err != nil {
return "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
}
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
}
if len(apiResponse.Choices) == 0 {
return "", fmt.Errorf("API未返回有效响应")
}
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
}
// compressLongestItem 压缩最长的子节点(保留作为备用方法)
func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error {
// 使用新的分片压缩方法
return b.compressContextData(ctx, contextData)
}
// summarizeContent 总结内容
@@ -675,8 +1044,28 @@ AI回复
"model": b.openAIConfig.Model,
"messages": []map[string]interface{}{
{
"role": "system",
"content": "你是一个专业的安全测试分析师,擅长总结安全测试相关的信息。请用简洁的中文总结关键信息。",
"role": "system",
"content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试过程中的关键信息,这些摘要将用于构建攻击链图。
**你的专业背景:**
- 精通各种安全测试工具的使用和结果分析Nmap、SQLMap、Burp Suite、Metasploit、Nuclei等
- 熟悉常见漏洞类型的识别和验证SQL注入、XSS、文件上传、命令执行、目录遍历、SSRF等
- 理解攻击链的构建逻辑和测试流程
- 能够识别失败尝试中的有价值线索
**你的总结原则:**
1. **保留技术细节**:保留所有重要的技术信息,包括工具名、参数、结果、错误信息、状态码等
2. **突出关键发现**:重点记录发现的漏洞、安全问题、资产信息、凭据等
3. **记录失败线索**:即使是失败的尝试,如果提供了错误信息、限制条件或下一步建议,也要详细记录
4. **保持准确性**:使用准确的技术术语,保留具体的数值、版本号、端口号等关键数据
5. **结构化表达**:用清晰、有条理的方式组织信息
**根据内容类型,你需要特别关注:**
- **AI回复**:提取安全发现、漏洞信息、测试结果、分析思路、决策依据
- **工具执行**:记录工具名、目标、参数、完整结果(成功或失败)、关键发现、错误信息
- **思考过程**:提取关键决策点、测试策略、分析思路、下一步计划
请用专业、准确、简洁的中文进行总结,确保信息完整且易于理解。`,
},
{
"role": "user",
@@ -730,6 +1119,146 @@ AI回复
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
}
// buildChunkContent 构建chunk的文本内容
func (b *Builder) buildChunkContent(chunk *ContextChunk) (string, error) {
var contentBuilder strings.Builder
// 添加消息
for _, msg := range chunk.Messages {
if strings.EqualFold(msg.Role, "user") {
contentBuilder.WriteString(fmt.Sprintf("用户消息: %s\n\n", msg.Content))
} else {
contentBuilder.WriteString(fmt.Sprintf("AI回复: %s\n\n", msg.Content))
}
}
// 添加工具执行
for _, exec := range chunk.Executions {
contentBuilder.WriteString(fmt.Sprintf("工具执行 [%s] (ID: %s):\n", exec.ToolName, exec.ID))
contentBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments)))
if exec.Error != "" {
contentBuilder.WriteString(fmt.Sprintf("错误: %s\n", exec.Error))
}
if exec.Result != nil {
var resultText string
for _, content := range exec.Result.Content {
if content.Type == "text" {
resultText += content.Text + "\n"
}
}
if resultText != "" {
// 如果结果太长,截断
if len(resultText) > 10000 {
resultText = resultText[:10000] + "\n... [内容已截断]"
}
contentBuilder.WriteString(fmt.Sprintf("结果: %s\n", resultText))
}
}
contentBuilder.WriteString("\n")
}
// 添加思考过程
for _, details := range chunk.ProcessDetails {
for _, detail := range details {
if detail.EventType == "thinking" {
thinkingText := detail.Message
// 如果思考过程太长,截断
if len(thinkingText) > 5000 {
thinkingText = thinkingText[:5000] + "\n... [内容已截断]"
}
contentBuilder.WriteString(fmt.Sprintf("思考过程: %s\n\n", thinkingText))
}
}
}
content := contentBuilder.String()
if content == "" {
return "", fmt.Errorf("片段内容为空")
}
return content, nil
}
// compressLargeChunk 压缩过大的chunk递归分割
func (b *Builder) compressLargeChunk(ctx context.Context, chunk *ContextChunk, maxTokens int) (string, error) {
// 将chunk进一步分割成更小的子chunk
// 简单策略:按消息和执行数量平均分割
totalItems := len(chunk.Messages) + len(chunk.Executions)
if totalItems <= 1 {
// 如果只有一个项目,直接截断内容
content, _ := b.buildChunkContent(chunk)
if len(content) > maxTokens*4 { // 粗略估算1 token ≈ 4字符
content = content[:maxTokens*4] + "\n... [内容过大,已截断]"
}
return content, nil
}
// 分成2个子chunk
mid := totalItems / 2
subChunk1 := &ContextChunk{
Messages: []database.Message{},
Executions: []*mcp.ToolExecution{},
ProcessDetails: make(map[string][]database.ProcessDetail),
}
subChunk2 := &ContextChunk{
Messages: []database.Message{},
Executions: []*mcp.ToolExecution{},
ProcessDetails: make(map[string][]database.ProcessDetail),
}
// 分配消息
for i, msg := range chunk.Messages {
if i < mid {
subChunk1.Messages = append(subChunk1.Messages, msg)
} else {
subChunk2.Messages = append(subChunk2.Messages, msg)
}
}
// 分配执行
execStart := len(chunk.Messages)
for i, exec := range chunk.Executions {
if execStart+i < mid {
subChunk1.Executions = append(subChunk1.Executions, exec)
} else {
subChunk2.Executions = append(subChunk2.Executions, exec)
}
}
// 递归压缩子chunk
summary1, err := b.summarizeContextChunk(ctx, subChunk1)
if err != nil {
return "", fmt.Errorf("压缩子chunk1失败: %w", err)
}
summary2, err := b.summarizeContextChunk(ctx, subChunk2)
if err != nil {
return "", fmt.Errorf("压缩子chunk2失败: %w", err)
}
// 合并摘要
return fmt.Sprintf("片段1摘要\n%s\n\n---\n\n片段2摘要\n%s", summary1, summary2), nil
}
// countTextTokens 计算文本的tokens数
func (b *Builder) countTextTokens(text string) (int, error) {
if b.tokenCounter == nil || b.openAIConfig == nil {
return len(text) / 4, nil
}
model := b.openAIConfig.Model
if model == "" {
model = "gpt-4"
}
count, err := b.tokenCounter.Count(model, text)
if err != nil {
return len(text) / 4, nil
}
return count, nil
}
// callAIForChainGeneration 调用AI生成攻击链
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
requestBody := map[string]interface{}{
@@ -985,74 +1514,143 @@ func (b *Builder) shouldFilterNode(n struct {
return true
}
// 对于action节点检查对应的工具执行是否有效
if n.Type == "action" {
if n.ToolExecutionID == "" {
// 没有关联工具执行的action节点可能是无效的
return true
}
// 查找对应的工具执行
var exec *mcp.ToolExecution
for _, e := range executions {
if e.ID == n.ToolExecutionID {
exec = e
break
}
}
if exec == nil {
// 找不到对应的工具执行,可能是无效的
return true
}
// 检查工具执行是否错误或失败
if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) {
if !hasInsightfulFailure(n.Metadata) {
return true
}
}
// 检查工具执行结果是否为空
if exec.Result == nil || len(exec.Result.Content) == 0 {
if !hasInsightfulFailure(n.Metadata) {
return true
}
}
// 检查结果文本是否为空
var resultText string
if exec.Result != nil {
for _, content := range exec.Result.Content {
if content.Type == "text" {
resultText += content.Text
}
}
}
if strings.TrimSpace(resultText) == "" {
if !hasInsightfulFailure(n.Metadata) {
return true
}
}
}
// 检查节点标签是否为空或无效
if strings.TrimSpace(n.Label) == "" {
return true
}
// 检查标签中是否包含错误/失败的关键词
labelLower := strings.ToLower(n.Label)
errorKeywords := []string{"错误", "失败", "无效", "error", "failed", "invalid", "empty", "空"}
for _, keyword := range errorKeywords {
if strings.Contains(labelLower, keyword) {
// 如果标签明确表示错误但节点类型不是vulnerability则过滤
if n.Type != "vulnerability" {
return true
// 对于vulnerability节点即使没有tool_execution_id也应该保留漏洞可能不是直接来自工具执行
if n.Type == "vulnerability" {
// 只要标签有意义就保留
return false
}
// 对于target节点只要标签有意义就保留
if n.Type == "target" {
return false
}
// 对于action节点进行更宽松的检查
if n.Type == "action" {
// 如果executions为空可能是压缩后的场景只要标签有意义就保留
if len(executions) == 0 {
// 压缩场景下,只要标签不是明显无效就保留
labelLower := strings.ToLower(n.Label)
// 只过滤明显无效的标签
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
for _, keyword := range invalidKeywords {
if strings.Contains(labelLower, keyword) {
return true
}
}
return false
}
// 如果有tool_execution_id尝试查找对应的工具执行
if n.ToolExecutionID != "" {
var exec *mcp.ToolExecution
for _, e := range executions {
if e.ID == n.ToolExecutionID {
exec = e
break
}
}
if exec != nil {
// 找到了对应的工具执行,检查是否有效
// 检查工具执行是否错误或失败
if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) {
// 失败但有线索的应该保留
if !hasInsightfulFailure(n.Metadata) {
// 即使没有明确标记为有线索,如果标签描述了具体内容,也保留
labelLower := strings.ToLower(n.Label)
// 如果标签包含具体的技术信息(端口、服务、漏洞等),说明有价值
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
hasValuableInfo := false
for _, keyword := range valuableKeywords {
if strings.Contains(labelLower, keyword) {
hasValuableInfo = true
break
}
}
if !hasValuableInfo {
return true
}
}
}
// 检查工具执行结果是否为空
if exec.Result == nil || len(exec.Result.Content) == 0 {
// 结果为空,但如果有线索或标签有意义,也保留
if !hasInsightfulFailure(n.Metadata) {
labelLower := strings.ToLower(n.Label)
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
hasValuableInfo := false
for _, keyword := range valuableKeywords {
if strings.Contains(labelLower, keyword) {
hasValuableInfo = true
break
}
}
if !hasValuableInfo {
return true
}
}
} else {
// 检查结果文本是否为空
var resultText string
for _, content := range exec.Result.Content {
if content.Type == "text" {
resultText += content.Text
}
}
if strings.TrimSpace(resultText) == "" {
// 结果文本为空,但如果有线索或标签有意义,也保留
if !hasInsightfulFailure(n.Metadata) {
labelLower := strings.ToLower(n.Label)
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
hasValuableInfo := false
for _, keyword := range valuableKeywords {
if strings.Contains(labelLower, keyword) {
hasValuableInfo = true
break
}
}
if !hasValuableInfo {
return true
}
}
}
}
} else {
// 找不到对应的工具执行,但可能是压缩后的场景
// 只要标签有意义就保留不要因为找不到execution就过滤掉
labelLower := strings.ToLower(n.Label)
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
for _, keyword := range invalidKeywords {
if strings.Contains(labelLower, keyword) {
return true
}
}
// 标签有意义,保留
return false
}
} else {
// 没有tool_execution_id但可能是压缩后的场景或AI生成的节点
// 只要标签有意义就保留
labelLower := strings.ToLower(n.Label)
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
for _, keyword := range invalidKeywords {
if strings.Contains(labelLower, keyword) {
return true
}
}
// 标签有意义,保留
return false
}
}
// 默认保留(已经通过了所有检查)
return false
}