diff --git a/internal/config/config.go b/internal/config/config.go index 9d75e160..3e4a912a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -54,6 +54,9 @@ type MultiAgentConfig struct { // OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。 OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"` SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"` + // SubAgentUserContextMaxRunes caps the user-context supplement injected into sub-agent instructions. + // 0 (default) uses the built-in default of 2000 runes; negative value disables injection entirely. + SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"` // EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent. EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"` // EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras. diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index ddfbce9c..8efea50b 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -205,6 +205,9 @@ func RunDeepAgent( if instr == "" { instr = "你是 CyberStrikeAI 中的专业子代理,在授权渗透测试场景下协助完成用户委托的子任务。优先使用可用工具获取证据,回答简洁专业。" } + if supplement := buildUserContextForSubAgent(userMessage, history, ma.SubAgentUserContextMaxRunes); supplement != "" { + instr += supplement + } roleTools := sub.RoleTools bind := strings.TrimSpace(sub.BindRole) diff --git a/internal/multiagent/sub_agent_context.go b/internal/multiagent/sub_agent_context.go new file mode 100644 index 00000000..bd9e11eb --- /dev/null +++ b/internal/multiagent/sub_agent_context.go @@ -0,0 +1,91 @@ +package multiagent + +import ( + "strings" + + "cyberstrike-ai/internal/agent" +) + +const defaultSubAgentUserContextMaxRunes = 2000 + +// buildUserContextForSubAgent collects all user messages from conversation +// history plus the current user message, and returns a formatted string to +// append to sub-agent instructions. This ensures sub-agents always have +// access to the full user intent (target URLs, scope, etc.) even when the +// orchestrator forgets to include them in the task description. +// +// maxRunes controls the character budget for the user-context body: +// - 0 uses defaultSubAgentUserContextMaxRunes +// - negative disables injection (returns "") +// +// When truncation is needed, the first and last user messages are each +// allocated half the budget so neither is lost entirely. +func buildUserContextForSubAgent(userMessage string, history []agent.ChatMessage, maxRunes int) string { + if maxRunes < 0 { + return "" + } + if maxRunes == 0 { + maxRunes = defaultSubAgentUserContextMaxRunes + } + + var userMsgs []string + for _, h := range history { + if h.Role == "user" { + if m := strings.TrimSpace(h.Content); m != "" { + userMsgs = append(userMsgs, m) + } + } + } + if um := strings.TrimSpace(userMessage); um != "" { + if len(userMsgs) == 0 || userMsgs[len(userMsgs)-1] != um { + userMsgs = append(userMsgs, um) + } + } + if len(userMsgs) == 0 { + return "" + } + + joined := strings.Join(userMsgs, "\n---\n") + + if len([]rune(joined)) > maxRunes { + joined = truncateKeepFirstLast(userMsgs, maxRunes) + } + + return "\n\n## 本次会话用户原始请求(自动注入,确保你了解完整上下文)\n" + joined +} + +// truncateKeepFirstLast keeps the first and last user messages, giving each +// half the rune budget. The first message typically contains target info; +// the last is the current instruction. +func truncateKeepFirstLast(msgs []string, maxRunes int) string { + if len(msgs) == 1 { + return truncateRunes(msgs[0], maxRunes) + } + + first := msgs[0] + last := msgs[len(msgs)-1] + sep := "\n---\n...(中间对话省略)...\n---\n" + sepLen := len([]rune(sep)) + + budget := maxRunes - sepLen + if budget <= 0 { + return truncateRunes(first+"\n---\n"+last, maxRunes) + } + + halfBudget := budget / 2 + firstTrunc := truncateRunes(first, halfBudget) + lastTrunc := truncateRunes(last, budget-len([]rune(firstTrunc))) + + return firstTrunc + sep + lastTrunc +} + +func truncateRunes(s string, max int) string { + rs := []rune(s) + if len(rs) <= max { + return s + } + if max <= 0 { + return "" + } + return string(rs[:max]) +} diff --git a/internal/multiagent/sub_agent_context_test.go b/internal/multiagent/sub_agent_context_test.go new file mode 100644 index 00000000..c2ca21eb --- /dev/null +++ b/internal/multiagent/sub_agent_context_test.go @@ -0,0 +1,117 @@ +package multiagent + +import ( + "strings" + "testing" + + "cyberstrike-ai/internal/agent" +) + +func TestBuildUserContextForSubAgent_SingleMessage(t *testing.T) { + result := buildUserContextForSubAgent("http://8.163.32.73:8081 测试命令执行", nil, 0) + if result == "" { + t.Fatal("expected non-empty context") + } + if !strings.Contains(result, "http://8.163.32.73:8081") { + t.Error("expected URL in context") + } +} + +func TestBuildUserContextForSubAgent_MultiTurn(t *testing.T) { + history := []agent.ChatMessage{ + {Role: "user", Content: "http://8.163.32.73:8081 这是一个pikachu靶场,尝试测试命令执行"}, + {Role: "assistant", Content: "好的,我来测试..."}, + {Role: "user", Content: "继续,并持久化webshell"}, + {Role: "assistant", Content: "正在处理..."}, + } + result := buildUserContextForSubAgent("你好", history, 0) + if !strings.Contains(result, "http://8.163.32.73:8081") { + t.Error("expected first turn URL to be preserved") + } + if !strings.Contains(result, "你好") { + t.Error("expected current message in context") + } +} + +func TestBuildUserContextForSubAgent_EmptyMessages(t *testing.T) { + result := buildUserContextForSubAgent("", nil, 0) + if result != "" { + t.Errorf("expected empty context, got %q", result) + } +} + +func TestBuildUserContextForSubAgent_DeduplicateCurrentMessage(t *testing.T) { + history := []agent.ChatMessage{ + {Role: "user", Content: "你好"}, + } + result := buildUserContextForSubAgent("你好", history, 0) + if strings.Count(result, "你好") != 1 { + t.Errorf("expected '你好' exactly once, got: %s", result) + } +} + +func TestBuildUserContextForSubAgent_SkipsNonUserMessages(t *testing.T) { + history := []agent.ChatMessage{ + {Role: "user", Content: "目标是 10.0.0.1"}, + {Role: "assistant", Content: "这个不应该出现"}, + {Role: "user", Content: "开始扫描"}, + } + result := buildUserContextForSubAgent("确认", history, 0) + if strings.Contains(result, "这个不应该出现") { + t.Error("assistant message should not be included") + } + if !strings.Contains(result, "10.0.0.1") { + t.Error("expected IP from first user message") + } +} + +func TestBuildUserContextForSubAgent_TruncatesLongConversation(t *testing.T) { + first := "http://target.com " + strings.Repeat("A", 500) + var history []agent.ChatMessage + history = append(history, agent.ChatMessage{Role: "user", Content: first}) + for i := 0; i < 10; i++ { + history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)}) + } + last := "最后一条指令" + result := buildUserContextForSubAgent(last, history, 0) + + if !strings.Contains(result, "http://target.com") { + t.Error("first message (target URL) should be preserved after truncation") + } + if !strings.Contains(result, last) { + t.Error("last message should be preserved after truncation") + } +} + +func TestBuildUserContextForSubAgent_DisabledByNegativeMax(t *testing.T) { + result := buildUserContextForSubAgent("http://example.com test", nil, -1) + if result != "" { + t.Errorf("expected empty when disabled, got %q", result) + } +} + +func TestBuildUserContextForSubAgent_CustomMaxRunes(t *testing.T) { + msg := strings.Repeat("A", 200) + result := buildUserContextForSubAgent(msg, nil, 50) + body := strings.TrimPrefix(result, "\n\n## 本次会话用户原始请求(自动注入,确保你了解完整上下文)\n") + if len([]rune(body)) > 50 { + t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body))) + } +} + +func TestTruncateKeepFirstLast_BothPreserved(t *testing.T) { + first := strings.Repeat("F", 100) + last := strings.Repeat("L", 100) + msgs := []string{first, "middle1", "middle2", last} + result := truncateKeepFirstLast(msgs, 250) + + if !strings.HasPrefix(result, "FFFF") { + t.Error("first message should be at the start") + } + if !strings.HasSuffix(result, "LLLL") { + t.Error("last message should be at the end") + } + if !strings.Contains(result, "中间对话省略") { + t.Error("should contain truncation marker") + } +}