mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-06 17:46:46 +02:00
Add files via upload
This commit is contained in:
@@ -57,19 +57,30 @@ func newEinoSummarizationMiddleware(
|
||||
if modelName == "" {
|
||||
modelName = "gpt-4o"
|
||||
}
|
||||
tokenCounter := einoSummarizationTokenCounter(modelName)
|
||||
recentTrailMax := trigger / 4
|
||||
if recentTrailMax < 2048 {
|
||||
recentTrailMax = 2048
|
||||
}
|
||||
if recentTrailMax > trigger/2 {
|
||||
recentTrailMax = trigger / 2
|
||||
}
|
||||
|
||||
mw, err := summarization.New(ctx, &summarization.Config{
|
||||
Model: summaryModel,
|
||||
Trigger: &summarization.TriggerCondition{
|
||||
ContextTokens: trigger,
|
||||
},
|
||||
TokenCounter: einoSummarizationTokenCounter(modelName),
|
||||
TokenCounter: tokenCounter,
|
||||
UserInstruction: einoSummarizeUserInstruction,
|
||||
EmitInternalEvents: false,
|
||||
PreserveUserMessages: &summarization.PreserveUserMessages{
|
||||
Enabled: true,
|
||||
MaxTokens: preserveMax,
|
||||
},
|
||||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||||
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
||||
},
|
||||
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
||||
if logger == nil {
|
||||
return nil
|
||||
@@ -89,6 +100,108 @@ func newEinoSummarizationMiddleware(
|
||||
return mw, nil
|
||||
}
|
||||
|
||||
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
||||
func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
ctx context.Context,
|
||||
originalMessages []adk.Message,
|
||||
summary adk.Message,
|
||||
tokenCounter summarization.TokenCounterFunc,
|
||||
recentTrailTokenBudget int,
|
||||
) ([]adk.Message, error) {
|
||||
systemMsgs := make([]adk.Message, 0, len(originalMessages))
|
||||
nonSystem := make([]adk.Message, 0, len(originalMessages))
|
||||
for _, msg := range originalMessages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.System {
|
||||
systemMsgs = append(systemMsgs, msg)
|
||||
continue
|
||||
}
|
||||
nonSystem = append(nonSystem, msg)
|
||||
}
|
||||
|
||||
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
selectedReverse := make([]adk.Message, 0, 8)
|
||||
seen := make(map[adk.Message]struct{})
|
||||
totalTokens := 0
|
||||
assistantToolKept := 0
|
||||
const minAssistantToolTrail = 4
|
||||
|
||||
tryKeep := func(msg adk.Message) (bool, error) {
|
||||
if msg == nil {
|
||||
return false, nil
|
||||
}
|
||||
if _, ok := seen[msg]; ok {
|
||||
return false, nil
|
||||
}
|
||||
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: []adk.Message{msg}})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
if totalTokens+n > recentTrailTokenBudget {
|
||||
return false, nil
|
||||
}
|
||||
totalTokens += n
|
||||
selectedReverse = append(selectedReverse, msg)
|
||||
seen[msg] = struct{}{}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 优先保留最近 assistant/tool,确保执行轨迹可续跑。
|
||||
for i := len(nonSystem) - 1; i >= 0; i-- {
|
||||
msg := nonSystem[i]
|
||||
if msg.Role != schema.Assistant && msg.Role != schema.Tool {
|
||||
continue
|
||||
}
|
||||
ok, err := tryKeep(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
assistantToolKept++
|
||||
}
|
||||
if assistantToolKept >= minAssistantToolTrail {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 在预算内回填更多最近消息,保持短链路上下文。
|
||||
for i := len(nonSystem) - 1; i >= 0; i-- {
|
||||
_, exists := seen[nonSystem[i]]
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
ok, err := tryKeep(nonSystem[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
selected := make([]adk.Message, 0, len(selectedReverse))
|
||||
for i := len(selectedReverse) - 1; i >= 0; i-- {
|
||||
selected = append(selected, selectedReverse[i])
|
||||
}
|
||||
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selected))
|
||||
out = append(out, systemMsgs...)
|
||||
out = append(out, summary)
|
||||
out = append(out, selected...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
|
||||
Reference in New Issue
Block a user