From 38169abc4b265a5df04ed0b561cdd52b8005e797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Wed, 22 Apr 2026 13:59:17 +0800 Subject: [PATCH] Add files via upload --- internal/config/config.go | 2 +- .../multiagent/orchestrator_instruction.go | 2 +- internal/multiagent/runner.go | 6 +- internal/multiagent/sub_agent_context.go | 84 +++++++-- internal/multiagent/sub_agent_context_test.go | 173 ++++++++++++------ 5 files changed, 193 insertions(+), 74 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 3e4a912a..99fb4c6a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -54,7 +54,7 @@ type MultiAgentConfig struct { // OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。 OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"` SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"` - // SubAgentUserContextMaxRunes caps the user-context supplement injected into sub-agent instructions. + // SubAgentUserContextMaxRunes caps the user-context supplement appended to task descriptions for sub-agents. // 0 (default) uses the built-in default of 2000 runes; negative value disables injection entirely. SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"` // EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent. diff --git a/internal/multiagent/orchestrator_instruction.go b/internal/multiagent/orchestrator_instruction.go index 3b8f9a10..a1fd01d3 100644 --- a/internal/multiagent/orchestrator_instruction.go +++ b/internal/multiagent/orchestrator_instruction.go @@ -210,7 +210,7 @@ func DefaultSupervisorOrchestratorInstruction() string { ## transfer 交接与防重复劳动 -- 每次 transfer 前,在**本条助手正文**中写清交接包:已知主域、关键子域或主机短表、已识别端口与服务、上轮已达成共识的结论要点;勿仅依赖历史里的超长工具原始输出(上下文摘要后专家可能看不到细节)。 +- **把专家当作刚走进房间的同事——它没看过你的对话,不知道你做了什么,也不了解这个任务为什么重要。** 每次 transfer 前,在**本条助手正文**中写清交接包:已知主域、关键子域或主机短表、已识别端口与服务、上轮已达成共识的结论要点;勿仅依赖历史里的超长工具原始输出(上下文摘要后专家可能看不到细节)。 - 写清本轮**唯一子目标**与**禁止项**(例如:不得再做全量子域枚举;仅对下列目标做 MQTT 或认证验证)。 - 验证、利用、协议深挖应 transfer 给**对应专项**子代理;避免把「仅剩验证」的工作交给侦察类(recon)导致其从全量枚举起手。 - 同一目标多次串行 transfer 时,每一次交接包都要带上**截至当前的共识事实**增量,勿假设专家已读过上一轮专家的隐性推理。 diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index 8efea50b..f94c4303 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -205,9 +205,6 @@ func RunDeepAgent( if instr == "" { instr = "你是 CyberStrikeAI 中的专业子代理,在授权渗透测试场景下协助完成用户委托的子任务。优先使用可用工具获取证据,回答简洁专业。" } - if supplement := buildUserContextForSubAgent(userMessage, history, ma.SubAgentUserContextMaxRunes); supplement != "" { - instr += supplement - } roleTools := sub.RoleTools bind := strings.TrimSpace(sub.BindRole) @@ -344,6 +341,9 @@ func RunDeepAgent( // noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。 deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()} + if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes); mw != nil { + deepHandlers = append(deepHandlers, mw) + } if len(mainOrchestratorPre) > 0 { deepHandlers = append(deepHandlers, mainOrchestratorPre...) } diff --git a/internal/multiagent/sub_agent_context.go b/internal/multiagent/sub_agent_context.go index bd9e11eb..d2ec73cb 100644 --- a/internal/multiagent/sub_agent_context.go +++ b/internal/multiagent/sub_agent_context.go @@ -1,26 +1,81 @@ package multiagent import ( + "context" + "encoding/json" "strings" "cyberstrike-ai/internal/agent" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/tool" ) const defaultSubAgentUserContextMaxRunes = 2000 -// buildUserContextForSubAgent collects all user messages from conversation -// history plus the current user message, and returns a formatted string to -// append to sub-agent instructions. This ensures sub-agents always have -// access to the full user intent (target URLs, scope, etc.) even when the -// orchestrator forgets to include them in the task description. +// taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator +// and appends the user's original conversation messages to the task description. +// This ensures sub-agents always receive the full user intent (target URLs, +// scope, etc.) even when the orchestrator forgets to include them. // -// maxRunes controls the character budget for the user-context body: -// - 0 uses defaultSubAgentUserContextMaxRunes -// - negative disables injection (returns "") -// -// When truncation is needed, the first and last user messages are each -// allocated half the budget so neither is lost entirely. -func buildUserContextForSubAgent(userMessage string, history []agent.ChatMessage, maxRunes int) string { +// Design: user context is injected into the task description (per-task), NOT +// into the sub-agent's Instruction (system prompt). This keeps sub-agent +// Instructions clean as pure role definitions while attaching context to the +// specific delegation — aligned with Claude Code's agent design philosophy. +type taskContextEnrichMiddleware struct { + adk.BaseChatModelAgentMiddleware + supplement string // pre-built user context block +} + +// newTaskContextEnrichMiddleware returns a middleware that enriches task +// descriptions with user conversation context. Returns nil if disabled +// (maxRunes < 0) or no user messages exist. +func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int) adk.ChatModelAgentMiddleware { + supplement := buildUserContextSupplement(userMessage, history, maxRunes) + if supplement == "" { + return nil + } + return &taskContextEnrichMiddleware{supplement: supplement} +} + +func (m *taskContextEnrichMiddleware) WrapInvokableToolCall( + ctx context.Context, + endpoint adk.InvokableToolCallEndpoint, + tCtx *adk.ToolContext, +) (adk.InvokableToolCallEndpoint, error) { + if tCtx == nil || !strings.EqualFold(strings.TrimSpace(tCtx.Name), "task") { + return endpoint, nil + } + return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + enriched := m.enrichTaskDescription(argumentsInJSON) + return endpoint(ctx, enriched, opts...) + }, nil +} + +// enrichTaskDescription parses the task JSON arguments, appends user context +// to the "description" field, and re-serializes. Falls back to the original +// JSON if parsing fails or no description field exists. +func (m *taskContextEnrichMiddleware) enrichTaskDescription(argsJSON string) string { + var raw map[string]interface{} + if err := json.Unmarshal([]byte(argsJSON), &raw); err != nil { + return argsJSON + } + desc, ok := raw["description"].(string) + if !ok { + return argsJSON + } + raw["description"] = desc + m.supplement + enriched, err := json.Marshal(raw) + if err != nil { + return argsJSON + } + return string(enriched) +} + +// buildUserContextSupplement collects user messages from conversation history +// and the current message, returning a formatted block to append to task +// descriptions. Returns "" if disabled or no user messages exist. +func buildUserContextSupplement(userMessage string, history []agent.ChatMessage, maxRunes int) string { if maxRunes < 0 { return "" } @@ -46,17 +101,16 @@ func buildUserContextForSubAgent(userMessage string, history []agent.ChatMessage } joined := strings.Join(userMsgs, "\n---\n") - if len([]rune(joined)) > maxRunes { joined = truncateKeepFirstLast(userMsgs, maxRunes) } - return "\n\n## 本次会话用户原始请求(自动注入,确保你了解完整上下文)\n" + joined + return "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" + joined } // truncateKeepFirstLast keeps the first and last user messages, giving each // half the rune budget. The first message typically contains target info; -// the last is the current instruction. +// the last contains the current instruction. func truncateKeepFirstLast(msgs []string, maxRunes int) string { if len(msgs) == 1 { return truncateRunes(msgs[0], maxRunes) diff --git a/internal/multiagent/sub_agent_context_test.go b/internal/multiagent/sub_agent_context_test.go index c2ca21eb..72e10762 100644 --- a/internal/multiagent/sub_agent_context_test.go +++ b/internal/multiagent/sub_agent_context_test.go @@ -1,71 +1,87 @@ package multiagent import ( + "context" + "encoding/json" "strings" "testing" "cyberstrike-ai/internal/agent" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/components/tool" ) -func TestBuildUserContextForSubAgent_SingleMessage(t *testing.T) { - result := buildUserContextForSubAgent("http://8.163.32.73:8081 测试命令执行", nil, 0) +// --- buildUserContextSupplement tests --- + +func TestBuildUserContextSupplement_SingleMessage(t *testing.T) { + result := buildUserContextSupplement("http://8.163.32.73:8081 测试命令执行", nil, 0) if result == "" { - t.Fatal("expected non-empty context") + t.Fatal("expected non-empty supplement") } if !strings.Contains(result, "http://8.163.32.73:8081") { - t.Error("expected URL in context") + t.Error("expected URL in supplement") } } -func TestBuildUserContextForSubAgent_MultiTurn(t *testing.T) { +func TestBuildUserContextSupplement_MultiTurn(t *testing.T) { history := []agent.ChatMessage{ {Role: "user", Content: "http://8.163.32.73:8081 这是一个pikachu靶场,尝试测试命令执行"}, {Role: "assistant", Content: "好的,我来测试..."}, {Role: "user", Content: "继续,并持久化webshell"}, {Role: "assistant", Content: "正在处理..."}, } - result := buildUserContextForSubAgent("你好", history, 0) + result := buildUserContextSupplement("你好", history, 0) if !strings.Contains(result, "http://8.163.32.73:8081") { t.Error("expected first turn URL to be preserved") } if !strings.Contains(result, "你好") { - t.Error("expected current message in context") + t.Error("expected current message") } } -func TestBuildUserContextForSubAgent_EmptyMessages(t *testing.T) { - result := buildUserContextForSubAgent("", nil, 0) - if result != "" { - t.Errorf("expected empty context, got %q", result) +func TestBuildUserContextSupplement_Empty(t *testing.T) { + if result := buildUserContextSupplement("", nil, 0); result != "" { + t.Errorf("expected empty, got %q", result) } } -func TestBuildUserContextForSubAgent_DeduplicateCurrentMessage(t *testing.T) { - history := []agent.ChatMessage{ - {Role: "user", Content: "你好"}, - } - result := buildUserContextForSubAgent("你好", history, 0) +func TestBuildUserContextSupplement_Deduplicate(t *testing.T) { + history := []agent.ChatMessage{{Role: "user", Content: "你好"}} + result := buildUserContextSupplement("你好", history, 0) if strings.Count(result, "你好") != 1 { - t.Errorf("expected '你好' exactly once, got: %s", result) + t.Errorf("expected '你好' once, got: %s", result) } } -func TestBuildUserContextForSubAgent_SkipsNonUserMessages(t *testing.T) { +func TestBuildUserContextSupplement_SkipsNonUser(t *testing.T) { history := []agent.ChatMessage{ {Role: "user", Content: "目标是 10.0.0.1"}, - {Role: "assistant", Content: "这个不应该出现"}, - {Role: "user", Content: "开始扫描"}, + {Role: "assistant", Content: "不应该出现"}, } - result := buildUserContextForSubAgent("确认", history, 0) - if strings.Contains(result, "这个不应该出现") { + result := buildUserContextSupplement("确认", history, 0) + if strings.Contains(result, "不应该出现") { t.Error("assistant message should not be included") } - if !strings.Contains(result, "10.0.0.1") { - t.Error("expected IP from first user message") - } } -func TestBuildUserContextForSubAgent_TruncatesLongConversation(t *testing.T) { +func TestBuildUserContextSupplement_DisabledByNegative(t *testing.T) { + if result := buildUserContextSupplement("test", nil, -1); result != "" { + t.Errorf("expected empty when disabled, got %q", result) + } +} + +func TestBuildUserContextSupplement_CustomMaxRunes(t *testing.T) { + msg := strings.Repeat("A", 200) + result := buildUserContextSupplement(msg, nil, 50) + header := "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" + body := strings.TrimPrefix(result, header) + if len([]rune(body)) > 50 { + t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body))) + } +} + +func TestBuildUserContextSupplement_TruncateKeepsFirstAndLast(t *testing.T) { first := "http://target.com " + strings.Repeat("A", 500) var history []agent.ChatMessage history = append(history, agent.ChatMessage{Role: "user", Content: first}) @@ -73,45 +89,94 @@ func TestBuildUserContextForSubAgent_TruncatesLongConversation(t *testing.T) { history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)}) } last := "最后一条指令" - result := buildUserContextForSubAgent(last, history, 0) - + result := buildUserContextSupplement(last, history, 0) if !strings.Contains(result, "http://target.com") { - t.Error("first message (target URL) should be preserved after truncation") + t.Error("first message (target URL) should survive truncation") } if !strings.Contains(result, last) { - t.Error("last message should be preserved after truncation") + t.Error("last message should survive truncation") } } -func TestBuildUserContextForSubAgent_DisabledByNegativeMax(t *testing.T) { - result := buildUserContextForSubAgent("http://example.com test", nil, -1) - if result != "" { - t.Errorf("expected empty when disabled, got %q", result) +// --- middleware integration tests --- + +func TestTaskContextEnrichMiddleware_EnrichesTaskDescription(t *testing.T) { + mw := newTaskContextEnrichMiddleware( + "继续测试", + []agent.ChatMessage{{Role: "user", Content: "http://8.163.32.73:8081 pikachu靶场"}}, + 0, + ) + if mw == nil { + t.Fatal("expected non-nil middleware") + } + + called := false + var capturedArgs string + fakeEndpoint := func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + called = true + capturedArgs = args + return "ok", nil + } + + wrapped, err := mw.(interface { + WrapInvokableToolCall(context.Context, adk.InvokableToolCallEndpoint, *adk.ToolContext) (adk.InvokableToolCallEndpoint, error) + }).WrapInvokableToolCall(context.Background(), fakeEndpoint, &adk.ToolContext{Name: "task"}) + if err != nil { + t.Fatal(err) + } + + taskArgs := `{"subagent_type":"recon","description":"扫描目标端口"}` + wrapped(context.Background(), taskArgs) + + if !called { + t.Fatal("endpoint was not called") + } + + var parsed map[string]interface{} + if err := json.Unmarshal([]byte(capturedArgs), &parsed); err != nil { + t.Fatalf("enriched args not valid JSON: %v", err) + } + desc := parsed["description"].(string) + if !strings.Contains(desc, "扫描目标端口") { + t.Error("original description should be preserved") + } + if !strings.Contains(desc, "http://8.163.32.73:8081") { + t.Error("user context should be appended to description") + } + if !strings.Contains(desc, "继续测试") { + t.Error("current user message should be in description") } } -func TestBuildUserContextForSubAgent_CustomMaxRunes(t *testing.T) { - msg := strings.Repeat("A", 200) - result := buildUserContextForSubAgent(msg, nil, 50) - body := strings.TrimPrefix(result, "\n\n## 本次会话用户原始请求(自动注入,确保你了解完整上下文)\n") - if len([]rune(body)) > 50 { - t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body))) +func TestTaskContextEnrichMiddleware_IgnoresNonTaskTools(t *testing.T) { + mw := newTaskContextEnrichMiddleware("test", nil, 0) + if mw == nil { + t.Fatal("expected non-nil middleware") + } + + original := `{"command":"nmap -sV target"}` + var capturedArgs string + fakeEndpoint := func(ctx context.Context, args string, opts ...tool.Option) (string, error) { + capturedArgs = args + return "ok", nil + } + + wrapped, err := mw.(interface { + WrapInvokableToolCall(context.Context, adk.InvokableToolCallEndpoint, *adk.ToolContext) (adk.InvokableToolCallEndpoint, error) + }).WrapInvokableToolCall(context.Background(), fakeEndpoint, &adk.ToolContext{Name: "nmap_scan"}) + if err != nil { + t.Fatal(err) + } + + wrapped(context.Background(), original) + if capturedArgs != original { + t.Errorf("non-task tool args should not be modified, got %q", capturedArgs) } } -func TestTruncateKeepFirstLast_BothPreserved(t *testing.T) { - first := strings.Repeat("F", 100) - last := strings.Repeat("L", 100) - msgs := []string{first, "middle1", "middle2", last} - result := truncateKeepFirstLast(msgs, 250) - - if !strings.HasPrefix(result, "FFFF") { - t.Error("first message should be at the start") - } - if !strings.HasSuffix(result, "LLLL") { - t.Error("last message should be at the end") - } - if !strings.Contains(result, "中间对话省略") { - t.Error("should contain truncation marker") +func TestTaskContextEnrichMiddleware_NilWhenDisabled(t *testing.T) { + mw := newTaskContextEnrichMiddleware("test", nil, -1) + if mw != nil { + t.Error("middleware should be nil when disabled") } }