From d1dc15fa44dcc59d0a90da3020a28164beb3efee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 19 May 2026 16:27:29 +0800 Subject: [PATCH] Add files via upload --- internal/attackchain/builder.go | 151 +++++++++------- internal/attackchain/truncate.go | 248 ++++++++++++++++++++++++++ internal/attackchain/truncate_test.go | 63 +++++++ 3 files changed, 396 insertions(+), 66 deletions(-) create mode 100644 internal/attackchain/truncate.go create mode 100644 internal/attackchain/truncate_test.go diff --git a/internal/attackchain/builder.go b/internal/attackchain/builder.go index 7543a4e6..f257f5d9 100644 --- a/internal/attackchain/builder.go +++ b/internal/attackchain/builder.go @@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap. } } -// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出) +// BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。 func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) { b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID)) @@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID var reactInputFinal string var dataSource string // 记录数据来源 - // 如果成功获取到保存的ReAct数据,直接使用 - if reactInputJSON != "" && modelOutput != "" { - // 计算 ReAct 输入的哈希值,用于追踪 - hash := sha256.Sum256([]byte(reactInputJSON)) - reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识 + // 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」 + if reactInputJSON != "" { + trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON) + hash := sha256.Sum256([]byte(trimmedJSON)) + reactInputHash := hex.EncodeToString(hash[:])[:16] - // 统计消息数量 var messageCount int - var tempMessages []interface{} - if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil { - messageCount = len(tempMessages) + if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil { + messageCount = len(msgs) + msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput) + reactInputFinal = b.formatAgentTraceFromChatMessages(msgs) + } else { + b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr)) + reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON) + if strings.TrimSpace(modelOutput) != "" { + reactInputFinal += "\n\n## 助手结论(last_react_output)\n\n" + modelOutput + } } - dataSource = "database_last_agent_trace" - b.logger.Info("使用保存的ReAct数据构建攻击链", + dataSource = "last_user_turn_agent_trace" + b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)", zap.String("conversationId", conversationID), zap.String("dataSource", dataSource), - zap.Int("reactInputSize", len(reactInputJSON)), + zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)), + zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)), zap.Int("messageCount", messageCount), zap.String("reactInputHash", reactInputHash), zap.Int("modelOutputSize", len(modelOutput))) - - // 从保存的ReAct输入(JSON格式)中提取用户输入 - // userInput = b.extractUserInputFromReActInput(reactInputJSON) - - // 将JSON格式的messages转换为可读格式 - reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON) } else { // 2. 如果没有保存的ReAct数据,从对话消息构建 dataSource = "messages_table" @@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID } } - // 3. 构建简化的prompt,一次性传递给大模型 - prompt := b.buildSimplePrompt(reactInputFinal, modelOutput) + // 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文) + reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput) + + // 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入) + promptAssistantOut := modelOutput + if reactInputJSON != "" { + promptAssistantOut = "" + } + prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut) // fmt.Println(prompt) // 6. 调用AI生成攻击链(一次性,不做任何处理) chainJSON, err := b.callAIForChainGeneration(ctx, prompt) @@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD return strings.TrimSpace(sb.String()) } -// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入) +// buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。 func (b *Builder) buildAgentTraceInput(messages []database.Message) string { + start := 0 + for i := len(messages) - 1; i >= 0; i-- { + if strings.EqualFold(messages[i].Role, "user") { + start = i + break + } + } var builder strings.Builder - for _, msg := range messages { + for _, msg := range messages[start:] { builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content)) } return builder.String() @@ -396,67 +411,66 @@ func (b *Builder) buildAgentTraceInput(messages []database.Message) string { // return "" // } -// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式 +// formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。 func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string { - var messages []map[string]interface{} - if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { + trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON) + msgs, err := agent.ParseTraceMessages(trimmed) + if err != nil { b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) - return reactInputJSON // 如果解析失败,返回原始JSON + return trimmed } + return b.formatAgentTraceFromChatMessages(msgs) +} +// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。 +func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string { var builder strings.Builder - for _, msg := range messages { - role, _ := msg["role"].(string) - content, _ := msg["content"].(string) + for _, msg := range msgs { + role := msg.Role + content := msg.Content - // 处理assistant消息:提取tool_calls信息 - if role == "assistant" { - if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { - // 如果有文本内容,先显示 - if content != "" { - builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content)) - } - // 详细显示每个工具调用 - builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls))) - for i, toolCall := range toolCalls { - if tc, ok := toolCall.(map[string]interface{}); ok { - toolCallID, _ := tc["id"].(string) - if funcData, ok := tc["function"].(map[string]interface{}); ok { - toolName, _ := funcData["name"].(string) - arguments, _ := funcData["arguments"].(string) - builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1)) - builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID)) - builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName)) - builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments)) - } + if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 { + if content != "" { + builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content)) + } + builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls))) + for i, tc := range msg.ToolCalls { + args := "" + if tc.Function.Arguments != nil { + if b, err := json.Marshal(tc.Function.Arguments); err == nil { + args = string(b) } } - builder.WriteString("\n") - continue + builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1)) + builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID)) + builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name)) + builder.WriteString(fmt.Sprintf(" 参数: %s\n", args)) } + builder.WriteString("\n") + continue } - // 处理tool消息:显示tool_call_id和完整内容 - if role == "tool" { - toolCallID, _ := msg["tool_call_id"].(string) - if toolCallID != "" { - builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content)) + if strings.EqualFold(role, "tool") { + if msg.ToolCallID != "" { + builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content)) } else { builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) } continue } - // 其他消息类型(system, user等)正常显示 builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) } - return builder.String() } // buildSimplePrompt 构建简化的prompt func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string { - return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。 + return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据**当前任务轮次**的对话记录和工具执行结果,一次性输出攻击链 JSON(不要分多轮追问)。 + +## 输入范围(与「继续对话」续跑一致) +- 下方「ReAct 轨迹」仅包含**最后一次用户提问之后**的消息与工具结果(last_react 当前任务轮次),不含更早的用户提问轮次。 +- 「助手结论」为同轮任务的最终输出摘要(last_react_output);节点须与轨迹中的实际工具执行一致,严禁编造。 ## 核心目标 @@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string { 5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability) 6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径) -## 最后一轮ReAct输入 +## 当前任务 ReAct 轨迹(含工具执行;助手结论见轨迹末尾 assistant) %s - -## 大模型输出 - %s ## 输出格式 @@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string { 9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。 10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。 -现在开始分析并构建攻击链:`, reactInput, modelOutput) +现在开始分析并构建攻击链:`, reactInput, assistantOutSection(modelOutput)) +} + +func assistantOutSection(modelOutput string) string { + modelOutput = strings.TrimSpace(modelOutput) + if modelOutput == "" { + return "" + } + return "\n## 助手结论(补充)\n\n" + modelOutput + "\n" } // saveChain 保存攻击链到数据库 @@ -812,7 +831,7 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) ( }, }, "temperature": 0.3, - "max_completion_tokens": 80000, + "max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens), } var apiResponse struct { diff --git a/internal/attackchain/truncate.go b/internal/attackchain/truncate.go new file mode 100644 index 00000000..ba379b3b --- /dev/null +++ b/internal/attackchain/truncate.go @@ -0,0 +1,248 @@ +package attackchain + +import ( + "strings" + "unicode/utf8" + + "go.uber.org/zap" +) + +const ( + attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n" + attackChainSystemReserve = 256 + attackChainSafetyReserve = 2048 +) + +// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。 +func attackChainMaxCompletionTokens(maxTotal int) int { + const capTokens = 16384 + if maxTotal <= 0 { + return 8192 + } + v := maxTotal / 8 + if v < 4096 { + v = 4096 + } + if v > capTokens { + v = capTokens + } + return v +} + +func (b *Builder) modelName() string { + if b.openAIConfig != nil && b.openAIConfig.Model != "" { + return b.openAIConfig.Model + } + return "gpt-4" +} + +func (b *Builder) countTokens(text string) int { + if text == "" { + return 0 + } + n, err := b.tokenCounter.Count(b.modelName(), text) + if err != nil { + return utf8.RuneCountInString(text) / 4 + } + return n +} + +// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。 +func (b *Builder) attackChainPayloadTokenBudget() int { + maxTotal := b.maxTokens + if maxTotal <= 0 { + maxTotal = 100000 + } + templateTok := b.countTokens(b.buildSimplePrompt("", "")) + completion := attackChainMaxCompletionTokens(maxTotal) + reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve + budget := maxTotal - reserve + minBudget := maxTotal * 35 / 100 + if budget < minBudget { + budget = minBudget + } + if budget < 4096 { + budget = 4096 + } + return budget +} + +// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。 +func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) { + budget := b.attackChainPayloadTokenBudget() + modelBudget := budget * 15 / 100 + if modelBudget < 512 { + modelBudget = 512 + } + reactBudget := budget - modelBudget + + origReactTok := b.countTokens(reactInput) + origModelTok := b.countTokens(modelOutput) + truncated := false + + outModel := modelOutput + if origModelTok > modelBudget { + outModel = truncateTextByTokens(b, modelOutput, modelBudget) + truncated = true + } + + outReact := reactInput + perToolLimits := []int{12000, 6000, 3000, 1500, 800} + for _, lim := range perToolLimits { + compact := compactFormattedToolBodies(outReact, lim) + if compact != outReact { + outReact = compact + truncated = true + } + if b.countTokens(outReact) <= reactBudget { + break + } + } + + if b.countTokens(outReact) > reactBudget { + outReact = truncateTextByTokens(b, outReact, reactBudget) + truncated = true + } + + if truncated { + b.logger.Info("攻击链输入已按 token 预算截断", + zap.Int("maxTotalTokens", b.maxTokens), + zap.Int("payloadBudget", budget), + zap.Int("reactBudget", reactBudget), + zap.Int("modelBudget", modelBudget), + zap.Int("reactInputTokensBefore", origReactTok), + zap.Int("reactInputTokensAfter", b.countTokens(outReact)), + zap.Int("modelOutputTokensBefore", origModelTok), + zap.Int("modelOutputTokensAfter", b.countTokens(outModel)), + zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)), + ) + } + + return outReact, outModel, truncated +} + +// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。 +func compactFormattedToolBodies(s string, maxRunesPerBody int) string { + if maxRunesPerBody <= 0 || s == "" { + return s + } + const marker = "[tool]" + var out strings.Builder + remaining := s + changed := false + for { + idx := strings.Index(remaining, marker) + if idx < 0 { + out.WriteString(remaining) + break + } + out.WriteString(remaining[:idx]) + remaining = remaining[idx:] + nl := strings.IndexByte(remaining, '\n') + if nl < 0 { + out.WriteString(remaining) + break + } + header := remaining[:nl+1] + remaining = remaining[nl+1:] + bodyEnd := strings.Index(remaining, "\n\n[") + var body, rest string + if bodyEnd < 0 { + body = remaining + rest = "" + } else { + body = remaining[:bodyEnd] + rest = remaining[bodyEnd:] + } + if runeLen(body) > maxRunesPerBody { + body = truncateRunesWithNotice(body, maxRunesPerBody) + changed = true + } + out.WriteString(header) + out.WriteString(body) + remaining = rest + if rest == "" { + break + } + } + if !changed { + return s + } + return out.String() +} + +func truncateTextByTokens(b *Builder, text string, maxTokens int) string { + if maxTokens <= 0 || text == "" { + return "" + } + if b.countTokens(text) <= maxTokens { + return text + } + markerTok := b.countTokens(attackChainTruncationMarker) + usable := maxTokens - markerTok + if usable < 256 { + usable = maxTokens / 2 + } + headBudget := usable * 60 / 100 + tailBudget := usable - headBudget + head := takeTokensFromStart(b, text, headBudget) + tail := takeTokensFromEnd(b, text, tailBudget) + return head + attackChainTruncationMarker + tail +} + +func takeTokensFromStart(b *Builder, text string, maxTokens int) string { + rs := []rune(text) + if len(rs) == 0 || maxTokens <= 0 { + return "" + } + lo, hi := 0, len(rs) + for lo < hi { + mid := (lo + hi + 1) / 2 + if b.countTokens(string(rs[:mid])) <= maxTokens { + lo = mid + } else { + hi = mid - 1 + } + } + return string(rs[:lo]) +} + +func takeTokensFromEnd(b *Builder, text string, maxTokens int) string { + rs := []rune(text) + if len(rs) == 0 || maxTokens <= 0 { + return "" + } + lo, hi := 0, len(rs) + for lo < hi { + mid := (lo + hi) / 2 + if b.countTokens(string(rs[mid:])) <= maxTokens { + hi = mid + } else { + lo = mid + 1 + } + } + return string(rs[lo:]) +} + +func truncateRunesWithNotice(s string, maxRunes int) string { + rs := []rune(s) + if len(rs) <= maxRunes { + return s + } + const notice = "\n...[工具输出已截断 / tool output truncated]...\n" + noticeRunes := []rune(notice) + keep := maxRunes - len(noticeRunes) + if keep < 200 { + keep = maxRunes * 2 / 3 + } + if keep < 1 { + return notice + } + head := keep * 70 / 100 + tail := keep - head + return string(rs[:head]) + notice + string(rs[len(rs)-tail:]) +} + +func runeLen(s string) int { + return len([]rune(s)) +} diff --git a/internal/attackchain/truncate_test.go b/internal/attackchain/truncate_test.go new file mode 100644 index 00000000..2cb4563c --- /dev/null +++ b/internal/attackchain/truncate_test.go @@ -0,0 +1,63 @@ +package attackchain + +import ( + "strings" + "testing" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/config" + + "go.uber.org/zap" +) + +func testBuilder(maxTotal int) *Builder { + return &Builder{ + logger: zap.NewNop(), + openAIConfig: &config.OpenAIConfig{Model: "gpt-4"}, + tokenCounter: agent.NewTikTokenCounter(), + maxTokens: maxTotal, + } +} + +func TestCompactFormattedToolBodies(t *testing.T) { + long := strings.Repeat("x", 20000) + in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n" + out := compactFormattedToolBodies(in, 500) + if strings.Contains(out, strings.Repeat("x", 10000)) { + t.Fatal("expected tool body to be truncated") + } + if !strings.Contains(out, "[user]: hi") { + t.Fatal("expected user header preserved") + } + if !strings.Contains(out, "[assistant]: done") { + t.Fatal("expected assistant header preserved") + } +} + +func TestFitAttackChainPayloadWithinBudget(t *testing.T) { + b := testBuilder(32000) + react := strings.Repeat("scan ", 50000) + model := strings.Repeat("result ", 10000) + r, m, truncated := b.fitAttackChainPayload(react, model) + if !truncated { + t.Fatal("expected truncation for large payload") + } + prompt := b.buildSimplePrompt(r, m) + total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve + if total > b.maxTokens+attackChainSafetyReserve { + t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens) + } + _ = m +} + +func TestAttackChainMaxCompletionTokens(t *testing.T) { + if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 { + // 120000/8 = 15000 + if got < 4096 || got > 16384 { + t.Fatalf("unexpected completion cap: %d", got) + } + } + if got := attackChainMaxCompletionTokens(0); got != 8192 { + t.Fatalf("expected default 8192, got %d", got) + } +}