mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-04 18:12:34 +02:00
Add files via upload
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -25,6 +26,8 @@ type Builder struct {
|
||||
logger *zap.Logger
|
||||
openAIClient *http.Client
|
||||
openAIConfig *config.OpenAIConfig
|
||||
tokenCounter agent.TokenCounter
|
||||
maxTokens int // 最大tokens限制,默认100000
|
||||
}
|
||||
|
||||
// Node 攻击链节点(使用database包的类型)
|
||||
@@ -47,11 +50,26 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
maxTokens := 100000 // 默认100k tokens,可以根据模型调整
|
||||
// 根据模型设置合理的默认值
|
||||
if openAIConfig != nil {
|
||||
model := strings.ToLower(openAIConfig.Model)
|
||||
if strings.Contains(model, "gpt-4") {
|
||||
maxTokens = 128000 // gpt-4通常支持128k
|
||||
} else if strings.Contains(model, "gpt-3.5") {
|
||||
maxTokens = 16000 // gpt-3.5-turbo通常支持16k
|
||||
} else if strings.Contains(model, "deepseek") {
|
||||
maxTokens = 131072 // deepseek-chat通常支持131k
|
||||
}
|
||||
}
|
||||
|
||||
return &Builder{
|
||||
db: db,
|
||||
logger: logger,
|
||||
openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport},
|
||||
openAIConfig: openAIConfig,
|
||||
tokenCounter: agent.NewTikTokenCounter(),
|
||||
maxTokens: maxTokens,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,6 +230,17 @@ func (b *Builder) prepareContextData(messages []database.Message, executions []*
|
||||
|
||||
// generateChainWithRetry 生成攻击链(带重试和压缩机制)
|
||||
func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *ContextData, maxRetries int) (*Chain, error) {
|
||||
// 在第一次尝试前,先检查tokens并压缩(如果需要)
|
||||
totalTokens, err := b.countPromptTokens(contextData)
|
||||
if err == nil && totalTokens > b.maxTokens {
|
||||
b.logger.Info("检测到tokens超过限制,提前压缩",
|
||||
zap.Int("totalTokens", totalTokens),
|
||||
zap.Int("maxTokens", b.maxTokens))
|
||||
if err := b.compressContextData(ctx, contextData); err != nil {
|
||||
return nil, fmt.Errorf("压缩上下文失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
b.logger.Info("尝试生成攻击链",
|
||||
zap.Int("attempt", attempt+1),
|
||||
@@ -232,8 +261,8 @@ func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *Conte
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Error(err))
|
||||
|
||||
// 压缩最长的子节点
|
||||
if err := b.compressLongestItem(ctx, contextData); err != nil {
|
||||
// 使用分片压缩
|
||||
if err := b.compressContextData(ctx, contextData); err != nil {
|
||||
return nil, fmt.Errorf("压缩上下文失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -552,94 +581,434 @@ func (b *Builder) formatArguments(args map[string]interface{}) string {
|
||||
return string(jsonData)
|
||||
}
|
||||
|
||||
// compressLongestItem 压缩最长的子节点
|
||||
func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error {
|
||||
var longestID string
|
||||
var longestType string
|
||||
var longestContent string
|
||||
maxLength := 0
|
||||
// countPromptTokens 计算prompt的总tokens数
|
||||
func (b *Builder) countPromptTokens(contextData *ContextData) (int, error) {
|
||||
prompt, err := b.buildChainGenerationPrompt(contextData)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("构建提示词失败: %w", err)
|
||||
}
|
||||
|
||||
// 查找最长的消息
|
||||
if b.tokenCounter == nil || b.openAIConfig == nil {
|
||||
// 如果没有token计数器或配置,使用简单的估算(4个字符=1个token)
|
||||
return len(prompt) / 4, nil
|
||||
}
|
||||
|
||||
model := b.openAIConfig.Model
|
||||
if model == "" {
|
||||
model = "gpt-4" // 默认模型
|
||||
}
|
||||
|
||||
count, err := b.tokenCounter.Count(model, prompt)
|
||||
if err != nil {
|
||||
// 如果计算失败,使用估算
|
||||
return len(prompt) / 4, nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// compressContextData 使用分片压缩方式压缩上下文数据
|
||||
func (b *Builder) compressContextData(ctx context.Context, contextData *ContextData) error {
|
||||
// 计算当前tokens
|
||||
totalTokens, err := b.countPromptTokens(contextData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("计算tokens失败: %w", err)
|
||||
}
|
||||
|
||||
b.logger.Info("开始压缩上下文",
|
||||
zap.Int("totalTokens", totalTokens),
|
||||
zap.Int("maxTokens", b.maxTokens))
|
||||
|
||||
// 如果tokens在限制内,不需要压缩
|
||||
if totalTokens <= b.maxTokens {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 计算需要分成多少份
|
||||
numChunks := (totalTokens + b.maxTokens - 1) / b.maxTokens // 向上取整
|
||||
if numChunks < 2 {
|
||||
numChunks = 2 // 至少分成2份
|
||||
}
|
||||
|
||||
b.logger.Info("将上下文分成多个片段进行压缩",
|
||||
zap.Int("totalTokens", totalTokens),
|
||||
zap.Int("maxTokens", b.maxTokens),
|
||||
zap.Int("numChunks", numChunks))
|
||||
|
||||
// 按时间顺序将数据分成多个片段
|
||||
chunks, err := b.splitContextDataByTime(contextData, numChunks)
|
||||
if err != nil {
|
||||
return fmt.Errorf("分割上下文数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 对每个片段进行摘要
|
||||
summaries := make([]string, 0, len(chunks))
|
||||
for i, chunk := range chunks {
|
||||
b.logger.Info("压缩片段",
|
||||
zap.Int("chunkIndex", i+1),
|
||||
zap.Int("totalChunks", len(chunks)),
|
||||
zap.Int("chunkSize", len(chunk.Messages)+len(chunk.Executions)))
|
||||
|
||||
summary, err := b.summarizeContextChunk(ctx, chunk)
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "Authentication") || strings.Contains(err.Error(), "api key") || strings.Contains(err.Error(), "invalid") {
|
||||
return fmt.Errorf("压缩片段%d失败(API认证错误,请检查OpenAI配置): %w", i+1, err)
|
||||
}
|
||||
return fmt.Errorf("压缩片段%d失败: %w", i+1, err)
|
||||
}
|
||||
summaries = append(summaries, summary)
|
||||
}
|
||||
|
||||
// 将摘要合并到contextData中
|
||||
// 保留用户消息,清空其他数据,用摘要替换
|
||||
var userMessages []database.Message
|
||||
for _, msg := range contextData.Messages {
|
||||
if strings.EqualFold(msg.Role, "user") {
|
||||
continue
|
||||
}
|
||||
if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized {
|
||||
continue
|
||||
}
|
||||
length := len(msg.Content)
|
||||
if length > maxLength {
|
||||
maxLength = length
|
||||
longestID = msg.ID
|
||||
longestType = "message"
|
||||
longestContent = msg.Content
|
||||
userMessages = append(userMessages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// 查找最长的工具执行结果
|
||||
// 清空非用户消息和执行记录
|
||||
contextData.Messages = userMessages
|
||||
contextData.Executions = []*mcp.ToolExecution{}
|
||||
contextData.ProcessDetails = make(map[string][]database.ProcessDetail)
|
||||
|
||||
// 创建一个综合摘要消息
|
||||
combinedSummary := strings.Join(summaries, "\n\n---\n\n")
|
||||
summaryMsg := database.Message{
|
||||
ID: uuid.New().String(),
|
||||
Role: "assistant",
|
||||
Content: fmt.Sprintf("[上下文摘要 - 包含%d个片段的压缩内容]\n\n%s", len(summaries), combinedSummary),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
contextData.Messages = append(contextData.Messages, summaryMsg)
|
||||
|
||||
// 检查压缩后的tokens
|
||||
compressedTokens, err := b.countPromptTokens(contextData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("计算压缩后tokens失败: %w", err)
|
||||
}
|
||||
|
||||
b.logger.Info("压缩完成",
|
||||
zap.Int("originalTokens", totalTokens),
|
||||
zap.Int("compressedTokens", compressedTokens),
|
||||
zap.Int("reduction", totalTokens-compressedTokens))
|
||||
|
||||
// 如果压缩后仍然超过限制,递归压缩
|
||||
if compressedTokens > b.maxTokens {
|
||||
b.logger.Info("压缩后仍然超过限制,继续递归压缩",
|
||||
zap.Int("compressedTokens", compressedTokens),
|
||||
zap.Int("maxTokens", b.maxTokens))
|
||||
return b.compressContextData(ctx, contextData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextChunk 上下文数据片段
|
||||
type ContextChunk struct {
|
||||
Messages []database.Message
|
||||
Executions []*mcp.ToolExecution
|
||||
ProcessDetails map[string][]database.ProcessDetail
|
||||
}
|
||||
|
||||
// splitContextDataByTime 按时间顺序将上下文数据分成多个片段
|
||||
func (b *Builder) splitContextDataByTime(contextData *ContextData, numChunks int) ([]*ContextChunk, error) {
|
||||
if numChunks <= 0 {
|
||||
return nil, fmt.Errorf("片段数量必须大于0")
|
||||
}
|
||||
|
||||
// 收集所有带时间戳的项目
|
||||
type timeItem struct {
|
||||
time time.Time
|
||||
itemType string // "message", "execution", "thinking"
|
||||
message *database.Message
|
||||
execution *mcp.ToolExecution
|
||||
processDetail *database.ProcessDetail
|
||||
}
|
||||
|
||||
var items []timeItem
|
||||
|
||||
// 添加消息(跳过已总结的)
|
||||
for i := range contextData.Messages {
|
||||
msg := &contextData.Messages[i]
|
||||
if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized {
|
||||
continue
|
||||
}
|
||||
items = append(items, timeItem{
|
||||
time: msg.CreatedAt,
|
||||
itemType: "message",
|
||||
message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
// 添加工具执行(跳过已总结的)
|
||||
for _, exec := range contextData.Executions {
|
||||
if _, alreadySummarized := contextData.SummarizedItems[exec.ID]; alreadySummarized {
|
||||
continue
|
||||
}
|
||||
if exec.Result != nil {
|
||||
var resultText string
|
||||
for _, content := range exec.Result.Content {
|
||||
if content.Type == "text" {
|
||||
resultText += content.Text + "\n"
|
||||
}
|
||||
}
|
||||
length := len(resultText)
|
||||
if length > maxLength {
|
||||
maxLength = length
|
||||
longestID = exec.ID
|
||||
longestType = "execution"
|
||||
longestContent = resultText
|
||||
}
|
||||
}
|
||||
items = append(items, timeItem{
|
||||
time: exec.StartTime,
|
||||
itemType: "execution",
|
||||
execution: exec,
|
||||
})
|
||||
}
|
||||
|
||||
// 查找最长的思考过程
|
||||
// 添加思考过程(跳过已总结的)
|
||||
for _, details := range contextData.ProcessDetails {
|
||||
for _, detail := range details {
|
||||
for i := range details {
|
||||
detail := &details[i]
|
||||
if detail.EventType == "thinking" {
|
||||
if _, alreadySummarized := contextData.SummarizedItems[detail.ID]; alreadySummarized {
|
||||
continue
|
||||
}
|
||||
length := len(detail.Message)
|
||||
if length > maxLength {
|
||||
maxLength = length
|
||||
longestID = detail.ID
|
||||
longestType = "thinking"
|
||||
longestContent = detail.Message
|
||||
}
|
||||
items = append(items, timeItem{
|
||||
time: detail.CreatedAt,
|
||||
itemType: "thinking",
|
||||
processDetail: detail,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if longestID == "" {
|
||||
return fmt.Errorf("没有找到需要压缩的内容")
|
||||
if len(items) == 0 {
|
||||
return nil, fmt.Errorf("没有可分割的数据")
|
||||
}
|
||||
|
||||
b.logger.Info("压缩最长子节点",
|
||||
zap.String("id", longestID),
|
||||
zap.String("type", longestType),
|
||||
zap.Int("length", maxLength))
|
||||
// 按时间排序
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return items[i].time.Before(items[j].time)
|
||||
})
|
||||
|
||||
// 计算每个片段的大小
|
||||
chunkSize := (len(items) + numChunks - 1) / numChunks // 向上取整
|
||||
|
||||
// 创建片段
|
||||
chunks := make([]*ContextChunk, 0, numChunks)
|
||||
for i := 0; i < len(items); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(items) {
|
||||
end = len(items)
|
||||
}
|
||||
|
||||
chunk := &ContextChunk{
|
||||
Messages: []database.Message{},
|
||||
Executions: []*mcp.ToolExecution{},
|
||||
ProcessDetails: make(map[string][]database.ProcessDetail),
|
||||
}
|
||||
|
||||
for j := i; j < end; j++ {
|
||||
item := items[j]
|
||||
switch item.itemType {
|
||||
case "message":
|
||||
chunk.Messages = append(chunk.Messages, *item.message)
|
||||
case "execution":
|
||||
chunk.Executions = append(chunk.Executions, item.execution)
|
||||
case "thinking":
|
||||
if item.processDetail != nil {
|
||||
msgID := item.processDetail.MessageID
|
||||
chunk.ProcessDetails[msgID] = append(chunk.ProcessDetails[msgID], *item.processDetail)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// getModelMaxContextLength 获取模型的最大上下文长度
|
||||
func (b *Builder) getModelMaxContextLength() int {
|
||||
if b.openAIConfig == nil {
|
||||
return 131072 // 默认值
|
||||
}
|
||||
model := strings.ToLower(b.openAIConfig.Model)
|
||||
if strings.Contains(model, "gpt-4") {
|
||||
return 128000
|
||||
} else if strings.Contains(model, "gpt-3.5") {
|
||||
return 16000
|
||||
} else if strings.Contains(model, "deepseek") {
|
||||
return 131072
|
||||
}
|
||||
return 131072 // 默认值
|
||||
}
|
||||
|
||||
// summarizeContextChunk 总结一个上下文片段
|
||||
func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk) (string, error) {
|
||||
// 先构建内容
|
||||
content, err := b.buildChunkContent(chunk)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 使用AI总结
|
||||
summary, err := b.summarizeContent(ctx, longestType, longestContent)
|
||||
promptTemplate := `请详细总结以下安全测试对话片段的关键信息。虽然需要压缩内容,但必须保留所有重要的技术细节和上下文信息,确保后续攻击链生成时能够准确理解整个测试过程。
|
||||
|
||||
**必须详细保留的内容:**
|
||||
1. **所有工具执行记录**:
|
||||
- 工具名称、执行参数、执行结果(包括成功和失败)
|
||||
- 失败执行的错误信息、状态码、响应头等关键线索
|
||||
- 工具输出的关键数据(端口、服务版本、漏洞信息等)
|
||||
- 每个工具执行的时间顺序和上下文关系
|
||||
|
||||
2. **所有发现的漏洞和潜在安全问题**:
|
||||
- 漏洞类型、严重程度、位置、利用方式
|
||||
- 验证过程和结果
|
||||
- 漏洞之间的关联关系
|
||||
|
||||
3. **所有测试目标和资产信息**:
|
||||
- IP地址、域名、URL、端口等
|
||||
- 发现的服务、技术栈、版本信息
|
||||
- 资产之间的关联关系
|
||||
|
||||
4. **所有测试步骤和决策过程**:
|
||||
- 每个测试步骤的详细描述(做了什么、为什么做、结果如何)
|
||||
- AI的分析思路和决策依据
|
||||
- 失败尝试的原因和从中获得的线索
|
||||
|
||||
5. **所有关键发现和线索**:
|
||||
- 成功发现的详细信息
|
||||
- 失败但提供线索的尝试(错误信息、限制条件、下一步建议等)
|
||||
- 收集到的任何有价值的信息(凭据、令牌、配置信息等)
|
||||
|
||||
**总结要求:**
|
||||
- 用结构化的方式组织信息,按时间顺序或逻辑顺序排列
|
||||
- 对于每个工具执行,必须包含:工具名、目标、参数、结果/错误、关键发现
|
||||
- 对于每个漏洞,必须包含:类型、位置、严重程度、验证结果
|
||||
- 保留所有技术细节,不要过度简化
|
||||
- 确保后续AI能够根据这个摘要完整重建攻击链
|
||||
|
||||
对话片段:
|
||||
%s
|
||||
|
||||
请给出详细且结构化的技术摘要(建议1000-2000字,确保信息完整):`
|
||||
|
||||
// 检查prompt tokens,如果超过限制,需要进一步压缩内容
|
||||
maxContextLength := b.getModelMaxContextLength()
|
||||
maxPromptTokens := maxContextLength - 2000 // 留出空间给响应和系统消息
|
||||
|
||||
// 尝试构建完整prompt并检查tokens
|
||||
fullPrompt := fmt.Sprintf(promptTemplate, content)
|
||||
promptTokens, err := b.countTextTokens(fullPrompt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("总结内容失败: %w", err)
|
||||
// 如果计算失败,使用估算
|
||||
promptTokens = len(fullPrompt) / 4
|
||||
}
|
||||
|
||||
// 保存总结
|
||||
contextData.SummarizedItems[longestID] = summary
|
||||
// 如果prompt太大,需要进一步压缩内容
|
||||
if promptTokens > maxPromptTokens {
|
||||
b.logger.Warn("片段内容过大,需要进一步压缩",
|
||||
zap.Int("promptTokens", promptTokens),
|
||||
zap.Int("maxPromptTokens", maxPromptTokens))
|
||||
|
||||
b.logger.Info("压缩完成",
|
||||
zap.String("id", longestID),
|
||||
zap.Int("originalLength", maxLength),
|
||||
zap.Int("summaryLength", len(summary)))
|
||||
// 递归压缩:将chunk进一步分割
|
||||
compressedContent, err := b.compressLargeChunk(ctx, chunk, maxPromptTokens)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("压缩大片段失败: %w", err)
|
||||
}
|
||||
content = compressedContent
|
||||
}
|
||||
|
||||
return nil
|
||||
prompt := fmt.Sprintf(promptTemplate, content)
|
||||
|
||||
// 检查配置
|
||||
if b.openAIConfig == nil {
|
||||
return "", fmt.Errorf("OpenAI配置未初始化")
|
||||
}
|
||||
if b.openAIConfig.APIKey == "" {
|
||||
return "", fmt.Errorf("OpenAI API Key未配置")
|
||||
}
|
||||
if b.openAIConfig.Model == "" {
|
||||
return "", fmt.Errorf("OpenAI Model未配置")
|
||||
}
|
||||
|
||||
// 直接调用AI API进行总结
|
||||
requestBody := map[string]interface{}{
|
||||
"model": b.openAIConfig.Model,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试对话片段,这些摘要将用于后续构建完整的攻击链图。
|
||||
|
||||
**你的专业背景:**
|
||||
- 精通各种安全测试工具(Nmap、SQLMap、Burp Suite、Metasploit等)的使用和结果分析
|
||||
- 熟悉常见漏洞类型(SQL注入、XSS、文件上传、命令执行、目录遍历等)的识别和验证
|
||||
- 理解攻击链的构建逻辑:从信息收集 → 漏洞发现 → 漏洞利用 → 权限提升 → 横向移动
|
||||
- 能够识别失败尝试中的有价值线索(错误信息、状态码、WAF指纹、技术栈信息等)
|
||||
|
||||
**你的总结原则:**
|
||||
1. **完整性优先**:虽然需要压缩,但必须保留所有技术细节,确保后续AI能够完整重建攻击链
|
||||
2. **结构化组织**:按时间顺序或逻辑顺序组织信息,让信息易于理解和追踪
|
||||
3. **技术精准**:使用准确的技术术语,保留具体的数值、版本号、端口号、URL等关键数据
|
||||
4. **上下文关联**:保留测试步骤之间的因果关系和逻辑关联
|
||||
5. **失败价值**:即使是失败的尝试,只要提供了线索(错误信息、限制条件、下一步建议),也要详细记录
|
||||
|
||||
**你需要特别关注的信息类型:**
|
||||
- 工具执行:工具名、目标、参数、完整结果(包括错误和失败)
|
||||
- 漏洞发现:类型、位置、严重程度、验证方法、利用结果
|
||||
- 资产信息:IP、域名、端口、服务版本、技术栈
|
||||
- 测试策略:为什么选择这个工具、为什么测试这个目标、发现了什么线索
|
||||
- 关键数据:凭据、令牌、配置信息、敏感文件内容
|
||||
|
||||
请用专业、详细、结构化的中文进行总结,确保信息完整且易于后续处理。`,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 4000, // 增加摘要长度,以容纳更详细的内容
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||||
|
||||
resp, err := b.openAIClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return "", fmt.Errorf("API未返回有效响应")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
// compressLongestItem 压缩最长的子节点(保留作为备用方法)
|
||||
func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error {
|
||||
// 使用新的分片压缩方法
|
||||
return b.compressContextData(ctx, contextData)
|
||||
}
|
||||
|
||||
// summarizeContent 总结内容
|
||||
@@ -675,8 +1044,28 @@ AI回复:
|
||||
"model": b.openAIConfig.Model,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的安全测试分析师,擅长总结安全测试相关的信息。请用简洁的中文总结关键信息。",
|
||||
"role": "system",
|
||||
"content": `你是一个资深的安全测试分析师和渗透测试专家,拥有丰富的实战经验。你的任务是总结安全测试过程中的关键信息,这些摘要将用于构建攻击链图。
|
||||
|
||||
**你的专业背景:**
|
||||
- 精通各种安全测试工具的使用和结果分析(Nmap、SQLMap、Burp Suite、Metasploit、Nuclei等)
|
||||
- 熟悉常见漏洞类型的识别和验证(SQL注入、XSS、文件上传、命令执行、目录遍历、SSRF等)
|
||||
- 理解攻击链的构建逻辑和测试流程
|
||||
- 能够识别失败尝试中的有价值线索
|
||||
|
||||
**你的总结原则:**
|
||||
1. **保留技术细节**:保留所有重要的技术信息,包括工具名、参数、结果、错误信息、状态码等
|
||||
2. **突出关键发现**:重点记录发现的漏洞、安全问题、资产信息、凭据等
|
||||
3. **记录失败线索**:即使是失败的尝试,如果提供了错误信息、限制条件或下一步建议,也要详细记录
|
||||
4. **保持准确性**:使用准确的技术术语,保留具体的数值、版本号、端口号等关键数据
|
||||
5. **结构化表达**:用清晰、有条理的方式组织信息
|
||||
|
||||
**根据内容类型,你需要特别关注:**
|
||||
- **AI回复**:提取安全发现、漏洞信息、测试结果、分析思路、决策依据
|
||||
- **工具执行**:记录工具名、目标、参数、完整结果(成功或失败)、关键发现、错误信息
|
||||
- **思考过程**:提取关键决策点、测试策略、分析思路、下一步计划
|
||||
|
||||
请用专业、准确、简洁的中文进行总结,确保信息完整且易于理解。`,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -730,6 +1119,146 @@ AI回复:
|
||||
return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
// buildChunkContent 构建chunk的文本内容
|
||||
func (b *Builder) buildChunkContent(chunk *ContextChunk) (string, error) {
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
// 添加消息
|
||||
for _, msg := range chunk.Messages {
|
||||
if strings.EqualFold(msg.Role, "user") {
|
||||
contentBuilder.WriteString(fmt.Sprintf("用户消息: %s\n\n", msg.Content))
|
||||
} else {
|
||||
contentBuilder.WriteString(fmt.Sprintf("AI回复: %s\n\n", msg.Content))
|
||||
}
|
||||
}
|
||||
|
||||
// 添加工具执行
|
||||
for _, exec := range chunk.Executions {
|
||||
contentBuilder.WriteString(fmt.Sprintf("工具执行 [%s] (ID: %s):\n", exec.ToolName, exec.ID))
|
||||
contentBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments)))
|
||||
|
||||
if exec.Error != "" {
|
||||
contentBuilder.WriteString(fmt.Sprintf("错误: %s\n", exec.Error))
|
||||
}
|
||||
|
||||
if exec.Result != nil {
|
||||
var resultText string
|
||||
for _, content := range exec.Result.Content {
|
||||
if content.Type == "text" {
|
||||
resultText += content.Text + "\n"
|
||||
}
|
||||
}
|
||||
if resultText != "" {
|
||||
// 如果结果太长,截断
|
||||
if len(resultText) > 10000 {
|
||||
resultText = resultText[:10000] + "\n... [内容已截断]"
|
||||
}
|
||||
contentBuilder.WriteString(fmt.Sprintf("结果: %s\n", resultText))
|
||||
}
|
||||
}
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// 添加思考过程
|
||||
for _, details := range chunk.ProcessDetails {
|
||||
for _, detail := range details {
|
||||
if detail.EventType == "thinking" {
|
||||
thinkingText := detail.Message
|
||||
// 如果思考过程太长,截断
|
||||
if len(thinkingText) > 5000 {
|
||||
thinkingText = thinkingText[:5000] + "\n... [内容已截断]"
|
||||
}
|
||||
contentBuilder.WriteString(fmt.Sprintf("思考过程: %s\n\n", thinkingText))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content := contentBuilder.String()
|
||||
if content == "" {
|
||||
return "", fmt.Errorf("片段内容为空")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// compressLargeChunk 压缩过大的chunk(递归分割)
|
||||
func (b *Builder) compressLargeChunk(ctx context.Context, chunk *ContextChunk, maxTokens int) (string, error) {
|
||||
// 将chunk进一步分割成更小的子chunk
|
||||
// 简单策略:按消息和执行数量平均分割
|
||||
totalItems := len(chunk.Messages) + len(chunk.Executions)
|
||||
if totalItems <= 1 {
|
||||
// 如果只有一个项目,直接截断内容
|
||||
content, _ := b.buildChunkContent(chunk)
|
||||
if len(content) > maxTokens*4 { // 粗略估算:1 token ≈ 4字符
|
||||
content = content[:maxTokens*4] + "\n... [内容过大,已截断]"
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// 分成2个子chunk
|
||||
mid := totalItems / 2
|
||||
subChunk1 := &ContextChunk{
|
||||
Messages: []database.Message{},
|
||||
Executions: []*mcp.ToolExecution{},
|
||||
ProcessDetails: make(map[string][]database.ProcessDetail),
|
||||
}
|
||||
subChunk2 := &ContextChunk{
|
||||
Messages: []database.Message{},
|
||||
Executions: []*mcp.ToolExecution{},
|
||||
ProcessDetails: make(map[string][]database.ProcessDetail),
|
||||
}
|
||||
|
||||
// 分配消息
|
||||
for i, msg := range chunk.Messages {
|
||||
if i < mid {
|
||||
subChunk1.Messages = append(subChunk1.Messages, msg)
|
||||
} else {
|
||||
subChunk2.Messages = append(subChunk2.Messages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// 分配执行
|
||||
execStart := len(chunk.Messages)
|
||||
for i, exec := range chunk.Executions {
|
||||
if execStart+i < mid {
|
||||
subChunk1.Executions = append(subChunk1.Executions, exec)
|
||||
} else {
|
||||
subChunk2.Executions = append(subChunk2.Executions, exec)
|
||||
}
|
||||
}
|
||||
|
||||
// 递归压缩子chunk
|
||||
summary1, err := b.summarizeContextChunk(ctx, subChunk1)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("压缩子chunk1失败: %w", err)
|
||||
}
|
||||
|
||||
summary2, err := b.summarizeContextChunk(ctx, subChunk2)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("压缩子chunk2失败: %w", err)
|
||||
}
|
||||
|
||||
// 合并摘要
|
||||
return fmt.Sprintf("片段1摘要:\n%s\n\n---\n\n片段2摘要:\n%s", summary1, summary2), nil
|
||||
}
|
||||
|
||||
// countTextTokens 计算文本的tokens数
|
||||
func (b *Builder) countTextTokens(text string) (int, error) {
|
||||
if b.tokenCounter == nil || b.openAIConfig == nil {
|
||||
return len(text) / 4, nil
|
||||
}
|
||||
|
||||
model := b.openAIConfig.Model
|
||||
if model == "" {
|
||||
model = "gpt-4"
|
||||
}
|
||||
|
||||
count, err := b.tokenCounter.Count(model, text)
|
||||
if err != nil {
|
||||
return len(text) / 4, nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// callAIForChainGeneration 调用AI生成攻击链
|
||||
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
|
||||
requestBody := map[string]interface{}{
|
||||
@@ -985,74 +1514,143 @@ func (b *Builder) shouldFilterNode(n struct {
|
||||
return true
|
||||
}
|
||||
|
||||
// 对于action节点,检查对应的工具执行是否有效
|
||||
if n.Type == "action" {
|
||||
if n.ToolExecutionID == "" {
|
||||
// 没有关联工具执行的action节点,可能是无效的
|
||||
return true
|
||||
}
|
||||
|
||||
// 查找对应的工具执行
|
||||
var exec *mcp.ToolExecution
|
||||
for _, e := range executions {
|
||||
if e.ID == n.ToolExecutionID {
|
||||
exec = e
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if exec == nil {
|
||||
// 找不到对应的工具执行,可能是无效的
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查工具执行是否错误或失败
|
||||
if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) {
|
||||
if !hasInsightfulFailure(n.Metadata) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查工具执行结果是否为空
|
||||
if exec.Result == nil || len(exec.Result.Content) == 0 {
|
||||
if !hasInsightfulFailure(n.Metadata) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查结果文本是否为空
|
||||
var resultText string
|
||||
if exec.Result != nil {
|
||||
for _, content := range exec.Result.Content {
|
||||
if content.Type == "text" {
|
||||
resultText += content.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(resultText) == "" {
|
||||
if !hasInsightfulFailure(n.Metadata) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查节点标签是否为空或无效
|
||||
if strings.TrimSpace(n.Label) == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查标签中是否包含错误/失败的关键词
|
||||
labelLower := strings.ToLower(n.Label)
|
||||
errorKeywords := []string{"错误", "失败", "无效", "error", "failed", "invalid", "empty", "空"}
|
||||
for _, keyword := range errorKeywords {
|
||||
if strings.Contains(labelLower, keyword) {
|
||||
// 如果标签明确表示错误,但节点类型不是vulnerability,则过滤
|
||||
if n.Type != "vulnerability" {
|
||||
return true
|
||||
// 对于vulnerability节点,即使没有tool_execution_id也应该保留(漏洞可能不是直接来自工具执行)
|
||||
if n.Type == "vulnerability" {
|
||||
// 只要标签有意义就保留
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于target节点,只要标签有意义就保留
|
||||
if n.Type == "target" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于action节点,进行更宽松的检查
|
||||
if n.Type == "action" {
|
||||
// 如果executions为空(可能是压缩后的场景),只要标签有意义就保留
|
||||
if len(executions) == 0 {
|
||||
// 压缩场景下,只要标签不是明显无效就保留
|
||||
labelLower := strings.ToLower(n.Label)
|
||||
// 只过滤明显无效的标签
|
||||
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
|
||||
for _, keyword := range invalidKeywords {
|
||||
if strings.Contains(labelLower, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果有tool_execution_id,尝试查找对应的工具执行
|
||||
if n.ToolExecutionID != "" {
|
||||
var exec *mcp.ToolExecution
|
||||
for _, e := range executions {
|
||||
if e.ID == n.ToolExecutionID {
|
||||
exec = e
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if exec != nil {
|
||||
// 找到了对应的工具执行,检查是否有效
|
||||
// 检查工具执行是否错误或失败
|
||||
if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) {
|
||||
// 失败但有线索的应该保留
|
||||
if !hasInsightfulFailure(n.Metadata) {
|
||||
// 即使没有明确标记为有线索,如果标签描述了具体内容,也保留
|
||||
labelLower := strings.ToLower(n.Label)
|
||||
// 如果标签包含具体的技术信息(端口、服务、漏洞等),说明有价值
|
||||
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
|
||||
hasValuableInfo := false
|
||||
for _, keyword := range valuableKeywords {
|
||||
if strings.Contains(labelLower, keyword) {
|
||||
hasValuableInfo = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasValuableInfo {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查工具执行结果是否为空
|
||||
if exec.Result == nil || len(exec.Result.Content) == 0 {
|
||||
// 结果为空,但如果有线索或标签有意义,也保留
|
||||
if !hasInsightfulFailure(n.Metadata) {
|
||||
labelLower := strings.ToLower(n.Label)
|
||||
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
|
||||
hasValuableInfo := false
|
||||
for _, keyword := range valuableKeywords {
|
||||
if strings.Contains(labelLower, keyword) {
|
||||
hasValuableInfo = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasValuableInfo {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 检查结果文本是否为空
|
||||
var resultText string
|
||||
for _, content := range exec.Result.Content {
|
||||
if content.Type == "text" {
|
||||
resultText += content.Text
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(resultText) == "" {
|
||||
// 结果文本为空,但如果有线索或标签有意义,也保留
|
||||
if !hasInsightfulFailure(n.Metadata) {
|
||||
labelLower := strings.ToLower(n.Label)
|
||||
valuableKeywords := []string{"端口", "服务", "漏洞", "扫描", "发现", "获取", "验证", "port", "service", "vulnerability", "scan", "found", "discover"}
|
||||
hasValuableInfo := false
|
||||
for _, keyword := range valuableKeywords {
|
||||
if strings.Contains(labelLower, keyword) {
|
||||
hasValuableInfo = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasValuableInfo {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 找不到对应的工具执行,但可能是压缩后的场景
|
||||
// 只要标签有意义就保留(不要因为找不到execution就过滤掉)
|
||||
labelLower := strings.ToLower(n.Label)
|
||||
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
|
||||
for _, keyword := range invalidKeywords {
|
||||
if strings.Contains(labelLower, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 标签有意义,保留
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// 没有tool_execution_id,但可能是压缩后的场景或AI生成的节点
|
||||
// 只要标签有意义就保留
|
||||
labelLower := strings.ToLower(n.Label)
|
||||
invalidKeywords := []string{"空节点", "无效节点", "empty node", "invalid node"}
|
||||
for _, keyword := range invalidKeywords {
|
||||
if strings.Contains(labelLower, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 标签有意义,保留
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 默认保留(已经通过了所有检查)
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user