mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-11 00:27:53 +02:00
Add files via upload
This commit is contained in:
@@ -146,20 +146,27 @@ func newEinoSummarizationMiddleware(
|
||||
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
||||
},
|
||||
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
||||
if logger == nil {
|
||||
return nil
|
||||
if transcriptPath != "" && len(before.Messages) > 0 {
|
||||
if werr := writeSummarizationTranscript(transcriptPath, before.Messages); werr != nil && logger != nil {
|
||||
logger.Warn("eino summarization transcript 写入失败",
|
||||
zap.String("path", transcriptPath),
|
||||
zap.Error(werr),
|
||||
)
|
||||
}
|
||||
}
|
||||
if logger != nil {
|
||||
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
|
||||
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
|
||||
logger.Info("eino summarization 已压缩上下文",
|
||||
zap.Int("messages_before", len(before.Messages)),
|
||||
zap.Int("messages_after", len(after.Messages)),
|
||||
zap.Int("tokens_before_estimated", beforeTokens),
|
||||
zap.Int("tokens_after_estimated", afterTokens),
|
||||
zap.Int("max_total_tokens", maxTotal),
|
||||
zap.Int("trigger_context_tokens", trigger),
|
||||
zap.String("transcript_file", transcriptPath),
|
||||
)
|
||||
}
|
||||
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
|
||||
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
|
||||
logger.Info("eino summarization 已压缩上下文",
|
||||
zap.Int("messages_before", len(before.Messages)),
|
||||
zap.Int("messages_after", len(after.Messages)),
|
||||
zap.Int("tokens_before_estimated", beforeTokens),
|
||||
zap.Int("tokens_after_estimated", afterTokens),
|
||||
zap.Int("max_total_tokens", maxTotal),
|
||||
zap.Int("trigger_context_tokens", trigger),
|
||||
zap.String("transcript_file", transcriptPath),
|
||||
)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -335,6 +342,23 @@ func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
|
||||
return rounds
|
||||
}
|
||||
|
||||
// writeSummarizationTranscript persists pre-compaction history for read_file after summarization.
|
||||
// Eino TranscriptFilePath only embeds the path in summary text; the file must be written by the host app.
|
||||
func writeSummarizationTranscript(path string, msgs []adk.Message) error {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
body := formatSummarizationTranscript(msgs)
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return fmt.Errorf("mkdir transcript dir: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(body), 0o600); err != nil {
|
||||
return fmt.Errorf("write transcript: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
|
||||
@@ -2,6 +2,9 @@ package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -343,3 +346,91 @@ func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSummarizationTranscript(t *testing.T) {
|
||||
t.Parallel()
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "summarization", "transcript.txt")
|
||||
msgs := []adk.Message{
|
||||
schema.UserMessage("scan target"),
|
||||
assistantToolCallsMsg("", "tc1"),
|
||||
schema.ToolMessage("nmap output", "tc1"),
|
||||
}
|
||||
if err := writeSummarizationTranscript(path, msgs); err != nil {
|
||||
t.Fatalf("writeSummarizationTranscript: %v", err)
|
||||
}
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read transcript: %v", err)
|
||||
}
|
||||
text := string(body)
|
||||
if !strings.Contains(text, "Pre-compaction session record") {
|
||||
t.Fatalf("missing transcript header: %q", text)
|
||||
}
|
||||
if !strings.Contains(text, "[user]") || !strings.Contains(text, "scan target") {
|
||||
t.Fatalf("missing user section: %q", text)
|
||||
}
|
||||
if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") {
|
||||
t.Fatalf("missing tool round: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
||||
t.Parallel()
|
||||
system := strings.Join([]string{
|
||||
"以下是当前会话绑定的工具名称索引(仅名称,无参数 JSON Schema)。",
|
||||
"- nmap",
|
||||
"- nuclei",
|
||||
"",
|
||||
"使用规则:",
|
||||
"1) 上表仅为名称索引",
|
||||
"5) 不要臆造不存在的工具名。",
|
||||
"",
|
||||
"你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。",
|
||||
"高强度扫描要求:全力出击",
|
||||
"",
|
||||
"## 项目黑板索引(project: 123, id: abc)",
|
||||
"(暂无事实)",
|
||||
"需要写入请使用 upsert_project_fact。",
|
||||
"",
|
||||
"# Skills System",
|
||||
"**How to Use Skills**",
|
||||
"Remember: Skills make you more capable",
|
||||
}, "\n")
|
||||
|
||||
out := sanitizeSystemContentForTranscript(system)
|
||||
if strings.Contains(out, "以下是当前会话绑定的工具名称索引") {
|
||||
t.Fatalf("tool index should be stripped: %q", out)
|
||||
}
|
||||
if strings.Contains(out, "- nmap") || strings.Contains(out, "高强度扫描要求") {
|
||||
t.Fatalf("static persona should be stripped: %q", out)
|
||||
}
|
||||
if strings.Contains(out, "# Skills System") || strings.Contains(out, "How to Use Skills") {
|
||||
t.Fatalf("skills boilerplate should be stripped: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, transcriptStaticSystemOmitNote) {
|
||||
t.Fatalf("missing omission note: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "## 项目黑板索引(project: 123, id: abc)") {
|
||||
t.Fatalf("project blackboard should be kept: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n# Skills System\nboiler"),
|
||||
schema.UserMessage("hello"),
|
||||
schema.AssistantMessage("reply", nil),
|
||||
}
|
||||
out := formatSummarizationTranscript(msgs)
|
||||
if strings.Contains(out, "- nmap") {
|
||||
t.Fatalf("tool list leaked into transcript: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "hello") || !strings.Contains(out, "reply") {
|
||||
t.Fatalf("conversation turns missing: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "## 项目黑板索引(project: p1, id: x)") {
|
||||
t.Fatalf("dynamic blackboard missing: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
const (
|
||||
transcriptFileHeader = `# CyberStrikeAI summarization transcript
|
||||
# Pre-compaction session record for read_file after context compression.
|
||||
# Omits static system/tool-index/skills boilerplate; full user/assistant/tool turns below.
|
||||
|
||||
`
|
||||
transcriptStaticSystemOmitNote = "[static system prompt omitted — unchanged in live context after compaction]"
|
||||
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
|
||||
transcriptPersonaStartMarker = "你是CyberStrikeAI"
|
||||
transcriptSkillsSystemMarker = "# Skills System"
|
||||
transcriptProjectBlackboardMarker = "## 项目黑板索引"
|
||||
)
|
||||
|
||||
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
|
||||
// Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only.
|
||||
func formatSummarizationTranscript(msgs []adk.Message) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(transcriptFileHeader)
|
||||
wrote := false
|
||||
for _, msg := range msgs {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
switch msg.Role {
|
||||
case schema.System:
|
||||
body := sanitizeSystemContentForTranscript(msg.Content)
|
||||
if strings.TrimSpace(body) == "" {
|
||||
continue
|
||||
}
|
||||
if wrote {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
appendTranscriptSection(&sb, schema.System, body)
|
||||
wrote = true
|
||||
default:
|
||||
if wrote {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
appendTranscriptMessage(&sb, msg)
|
||||
wrote = true
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func sanitizeSystemContentForTranscript(content string) string {
|
||||
content = stripToolNamesIndexFromSystem(content)
|
||||
content = stripSkillsSystemBoilerplate(content)
|
||||
blackboard := extractProjectBlackboardSection(content)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(transcriptStaticSystemOmitNote)
|
||||
if bb := strings.TrimSpace(blackboard); bb != "" {
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(bb)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func stripToolNamesIndexFromSystem(s string) string {
|
||||
if !strings.Contains(s, transcriptToolIndexStartMarker) {
|
||||
return s
|
||||
}
|
||||
idx := strings.Index(s, transcriptPersonaStartMarker)
|
||||
if idx < 0 {
|
||||
return s
|
||||
}
|
||||
return strings.TrimSpace(s[idx:])
|
||||
}
|
||||
|
||||
func stripSkillsSystemBoilerplate(s string) string {
|
||||
idx := strings.Index(s, transcriptSkillsSystemMarker)
|
||||
if idx < 0 {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
return strings.TrimSpace(s[:idx])
|
||||
}
|
||||
|
||||
func extractProjectBlackboardSection(s string) string {
|
||||
idx := strings.Index(s, transcriptProjectBlackboardMarker)
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s[idx:])
|
||||
}
|
||||
|
||||
func appendTranscriptSection(sb *strings.Builder, role schema.RoleType, body string) {
|
||||
sb.WriteString("--- [")
|
||||
sb.WriteString(string(role))
|
||||
sb.WriteString("] ---\n")
|
||||
sb.WriteString(body)
|
||||
if !strings.HasSuffix(body, "\n") {
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) {
|
||||
sb.WriteString("--- [")
|
||||
sb.WriteString(string(msg.Role))
|
||||
sb.WriteString("] ---\n")
|
||||
if msg.Content != "" {
|
||||
sb.WriteString(msg.Content)
|
||||
if !strings.HasSuffix(msg.Content, "\n") {
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
if msg.ReasoningContent != "" {
|
||||
sb.WriteString("[reasoning]\n")
|
||||
sb.WriteString(msg.ReasoningContent)
|
||||
if !strings.HasSuffix(msg.ReasoningContent, "\n") {
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
for _, part := range msg.UserInputMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText && strings.TrimSpace(part.Text) != "" {
|
||||
sb.WriteString(part.Text)
|
||||
if !strings.HasSuffix(part.Text, "\n") {
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
|
||||
sb.WriteString("tool_calls: ")
|
||||
sb.Write(b)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
if msg.ToolCallID != "" {
|
||||
sb.WriteString("tool_call_id: ")
|
||||
sb.WriteString(msg.ToolCallID)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user