mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-02 07:45:24 +02:00
146 lines
4.4 KiB
Go
146 lines
4.4 KiB
Go
package multiagent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
|
|
"cyberstrike-ai/internal/agent"
|
|
|
|
"github.com/cloudwego/eino/adk"
|
|
"github.com/cloudwego/eino/components/tool"
|
|
)
|
|
|
|
const defaultSubAgentUserContextMaxRunes = 2000
|
|
|
|
// taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator
|
|
// and appends the user's original conversation messages to the task description.
|
|
// This ensures sub-agents always receive the full user intent (target URLs,
|
|
// scope, etc.) even when the orchestrator forgets to include them.
|
|
//
|
|
// Design: user context is injected into the task description (per-task), NOT
|
|
// into the sub-agent's Instruction (system prompt). This keeps sub-agent
|
|
// Instructions clean as pure role definitions while attaching context to the
|
|
// specific delegation — aligned with Claude Code's agent design philosophy.
|
|
type taskContextEnrichMiddleware struct {
|
|
adk.BaseChatModelAgentMiddleware
|
|
supplement string // pre-built user context block
|
|
}
|
|
|
|
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
|
// descriptions with user conversation context. Returns nil if disabled
|
|
// (maxRunes < 0) or no user messages exist.
|
|
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int) adk.ChatModelAgentMiddleware {
|
|
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
|
if supplement == "" {
|
|
return nil
|
|
}
|
|
return &taskContextEnrichMiddleware{supplement: supplement}
|
|
}
|
|
|
|
func (m *taskContextEnrichMiddleware) WrapInvokableToolCall(
|
|
ctx context.Context,
|
|
endpoint adk.InvokableToolCallEndpoint,
|
|
tCtx *adk.ToolContext,
|
|
) (adk.InvokableToolCallEndpoint, error) {
|
|
if tCtx == nil || !strings.EqualFold(strings.TrimSpace(tCtx.Name), "task") {
|
|
return endpoint, nil
|
|
}
|
|
return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
|
enriched := m.enrichTaskDescription(argumentsInJSON)
|
|
return endpoint(ctx, enriched, opts...)
|
|
}, nil
|
|
}
|
|
|
|
// enrichTaskDescription parses the task JSON arguments, appends user context
|
|
// to the "description" field, and re-serializes. Falls back to the original
|
|
// JSON if parsing fails or no description field exists.
|
|
func (m *taskContextEnrichMiddleware) enrichTaskDescription(argsJSON string) string {
|
|
var raw map[string]interface{}
|
|
if err := json.Unmarshal([]byte(argsJSON), &raw); err != nil {
|
|
return argsJSON
|
|
}
|
|
desc, ok := raw["description"].(string)
|
|
if !ok {
|
|
return argsJSON
|
|
}
|
|
raw["description"] = desc + m.supplement
|
|
enriched, err := json.Marshal(raw)
|
|
if err != nil {
|
|
return argsJSON
|
|
}
|
|
return string(enriched)
|
|
}
|
|
|
|
// buildUserContextSupplement collects user messages from conversation history
|
|
// and the current message, returning a formatted block to append to task
|
|
// descriptions. Returns "" if disabled or no user messages exist.
|
|
func buildUserContextSupplement(userMessage string, history []agent.ChatMessage, maxRunes int) string {
|
|
if maxRunes < 0 {
|
|
return ""
|
|
}
|
|
if maxRunes == 0 {
|
|
maxRunes = defaultSubAgentUserContextMaxRunes
|
|
}
|
|
|
|
var userMsgs []string
|
|
for _, h := range history {
|
|
if h.Role == "user" {
|
|
if m := strings.TrimSpace(h.Content); m != "" {
|
|
userMsgs = append(userMsgs, m)
|
|
}
|
|
}
|
|
}
|
|
if um := strings.TrimSpace(userMessage); um != "" {
|
|
if len(userMsgs) == 0 || userMsgs[len(userMsgs)-1] != um {
|
|
userMsgs = append(userMsgs, um)
|
|
}
|
|
}
|
|
if len(userMsgs) == 0 {
|
|
return ""
|
|
}
|
|
|
|
joined := strings.Join(userMsgs, "\n---\n")
|
|
if len([]rune(joined)) > maxRunes {
|
|
joined = truncateKeepFirstLast(userMsgs, maxRunes)
|
|
}
|
|
|
|
return "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" + joined
|
|
}
|
|
|
|
// truncateKeepFirstLast keeps the first and last user messages, giving each
|
|
// half the rune budget. The first message typically contains target info;
|
|
// the last contains the current instruction.
|
|
func truncateKeepFirstLast(msgs []string, maxRunes int) string {
|
|
if len(msgs) == 1 {
|
|
return truncateRunes(msgs[0], maxRunes)
|
|
}
|
|
|
|
first := msgs[0]
|
|
last := msgs[len(msgs)-1]
|
|
sep := "\n---\n...(中间对话省略)...\n---\n"
|
|
sepLen := len([]rune(sep))
|
|
|
|
budget := maxRunes - sepLen
|
|
if budget <= 0 {
|
|
return truncateRunes(first+"\n---\n"+last, maxRunes)
|
|
}
|
|
|
|
halfBudget := budget / 2
|
|
firstTrunc := truncateRunes(first, halfBudget)
|
|
lastTrunc := truncateRunes(last, budget-len([]rune(firstTrunc)))
|
|
|
|
return firstTrunc + sep + lastTrunc
|
|
}
|
|
|
|
func truncateRunes(s string, max int) string {
|
|
rs := []rune(s)
|
|
if len(rs) <= max {
|
|
return s
|
|
}
|
|
if max <= 0 {
|
|
return ""
|
|
}
|
|
return string(rs[:max])
|
|
}
|