diff --git a/internal/multiagent/eino_orchestration.go b/internal/multiagent/eino_orchestration.go index 8461225f..fa387137 100644 --- a/internal/multiagent/eino_orchestration.go +++ b/internal/multiagent/eino_orchestration.go @@ -7,6 +7,7 @@ import ( "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/adk" @@ -29,7 +30,9 @@ type PlanExecuteRootArgs struct { MwCfg *config.MultiAgentEinoMiddlewareConfig // ConversationID is used for transcript/isolation paths in middleware. ConversationID string - Logger *zap.Logger + DB *database.DB + ProjectID string + Logger *zap.Logger // ModelName is used for model input token estimation logs. ModelName string // ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask), @@ -93,7 +96,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma } // 4. summarization(最后,与 Deep/Supervisor 一致) if a.AppCfg != nil { - sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.Logger) + sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger) if sumErr != nil { return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr) } diff --git a/internal/multiagent/eino_single_runner.go b/internal/multiagent/eino_single_runner.go index 96b9df91..c38e508a 100644 --- a/internal/multiagent/eino_single_runner.go +++ b/internal/multiagent/eino_single_runner.go @@ -11,6 +11,7 @@ import ( "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/project" @@ -32,6 +33,7 @@ func RunEinoSingleChatModelAgent( appCfg *config.Config, ma *config.MultiAgentConfig, ag *agent.Agent, + db *database.DB, logger *zap.Logger, conversationID string, projectID string, @@ -121,7 +123,7 @@ func RunEinoSingleChatModelAgent( return nil, fmt.Errorf("eino single 模型: %w", err) } - mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger) + mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger) if err != nil { return nil, fmt.Errorf("eino single summarization: %w", err) } diff --git a/internal/multiagent/eino_summarize.go b/internal/multiagent/eino_summarize.go index 5dc358b8..702eb6e1 100644 --- a/internal/multiagent/eino_summarize.go +++ b/internal/multiagent/eino_summarize.go @@ -9,7 +9,9 @@ import ( "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" copenai "cyberstrike-ai/internal/openai" + "cyberstrike-ai/internal/project" "github.com/bytedance/sonic" "github.com/cloudwego/eino/adk" @@ -40,6 +42,8 @@ func newEinoSummarizationMiddleware( appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig, conversationID string, + db *database.DB, + projectID string, logger *zap.Logger, ) (adk.ChatModelAgentMiddleware, error) { if summaryModel == nil || appCfg == nil { @@ -143,7 +147,14 @@ func newEinoSummarizationMiddleware( }, }, Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) { - return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax) + out, ferr := summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax) + if ferr != nil { + return nil, ferr + } + if appCfg != nil { + out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger) + } + return out, nil }, Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error { if transcriptPath != "" && len(before.Messages) > 0 { @@ -176,6 +187,50 @@ func newEinoSummarizationMiddleware( return mw, nil } +// refreshFactIndexInMessages 在 summarization 压缩后,用 DB 最新索引替换 system 中已有的项目黑板索引段。 +func refreshFactIndexInMessages(msgs []adk.Message, db *database.DB, projectID string, cfg config.ProjectConfig, logger *zap.Logger) []adk.Message { + if db == nil || !cfg.Enabled { + return msgs + } + projectID = strings.TrimSpace(projectID) + if projectID == "" { + return msgs + } + freshIndex, err := project.BuildFactIndexBlock(db, projectID, cfg) + if err != nil { + if logger != nil { + logger.Warn("summarization: 刷新项目黑板索引失败", zap.String("projectId", projectID), zap.Error(err)) + } + return msgs + } + freshIndex = strings.TrimSpace(freshIndex) + if freshIndex == "" { + return msgs + } + + changed := false + out := make([]adk.Message, len(msgs)) + for i, msg := range msgs { + if msg == nil || msg.Role != schema.System { + out[i] = msg + continue + } + newContent, ok := project.ReplaceFactIndexSection(msg.Content, freshIndex) + if !ok { + out[i] = msg + continue + } + cloned := *msg + cloned.Content = newContent + out[i] = &cloned + changed = true + } + if changed && logger != nil { + logger.Info("summarization: 已刷新项目黑板索引", zap.String("projectId", projectID)) + } + return out +} + // summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。 // // 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。 diff --git a/internal/multiagent/eino_summarize_test.go b/internal/multiagent/eino_summarize_test.go index 7197f672..94fdefda 100644 --- a/internal/multiagent/eino_summarize_test.go +++ b/internal/multiagent/eino_summarize_test.go @@ -7,9 +7,14 @@ import ( "strings" "testing" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/project" + "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk/middlewares/summarization" "github.com/cloudwego/eino/schema" + "go.uber.org/zap" ) // fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。 @@ -389,9 +394,11 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) { "你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。", "高强度扫描要求:全力出击", "", + project.FactIndexSectionStartMarker, "## 项目黑板索引(project: 123, id: abc)", "(暂无事实)", "需要写入请使用 upsert_project_fact。", + project.FactIndexSectionEndMarker, "", "# Skills System", "**How to Use Skills**", @@ -419,7 +426,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) { func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) { t.Parallel() msgs := []adk.Message{ - schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n# Skills System\nboiler"), + schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n# Skills System\nboiler"), schema.UserMessage("hello"), schema.AssistantMessage("reply", nil), } @@ -434,3 +441,51 @@ func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) { t.Fatalf("dynamic blackboard missing: %q", out) } } + +func TestRefreshFactIndexInMessages(t *testing.T) { + t.Parallel() + dbPath := filepath.Join(t.TempDir(), "summarize-facts.db") + db, err := database.NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + proj, err := db.CreateProject(&database.Project{Name: "summarize-proj"}) + if err != nil { + t.Fatal(err) + } + + cfg := config.ProjectConfig{Enabled: true} + oldIndex, err := project.BuildFactIndexBlock(db, proj.ID, cfg) + if err != nil { + t.Fatal(err) + } + + _, err = db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: proj.ID, + FactKey: "target/host", + Category: "target", + Summary: "fresh host fact", + }) + if err != nil { + t.Fatal(err) + } + + msgs := []adk.Message{ + schema.SystemMessage("instruction\n\n" + oldIndex), + schema.UserMessage("hi"), + } + + out := refreshFactIndexInMessages(msgs, db, proj.ID, cfg, nil) + sys := out[0].Content + if strings.Contains(sys, "(暂无事实)") { + t.Fatalf("expected refreshed index, got: %q", sys) + } + if !strings.Contains(sys, "fresh host fact") { + t.Fatalf("expected new fact in index: %q", sys) + } + if !strings.Contains(sys, "instruction") { + t.Fatalf("non-index system content should be preserved: %q", sys) + } +} diff --git a/internal/multiagent/eino_summarize_transcript.go b/internal/multiagent/eino_summarize_transcript.go index 7c31f040..fcb7e2c3 100644 --- a/internal/multiagent/eino_summarize_transcript.go +++ b/internal/multiagent/eino_summarize_transcript.go @@ -6,6 +6,8 @@ import ( "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/schema" + "cyberstrike-ai/internal/project" + "github.com/bytedance/sonic" ) @@ -19,7 +21,6 @@ const ( transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引" transcriptPersonaStartMarker = "你是CyberStrikeAI" transcriptSkillsSystemMarker = "# Skills System" - transcriptProjectBlackboardMarker = "## 项目黑板索引" ) // formatSummarizationTranscript renders pre-compaction messages for transcript.txt. @@ -88,11 +89,17 @@ func stripSkillsSystemBoilerplate(s string) string { } func extractProjectBlackboardSection(s string) string { - idx := strings.Index(s, transcriptProjectBlackboardMarker) - if idx < 0 { + start := strings.Index(s, project.FactIndexSectionStartMarker) + if start < 0 { return "" } - return strings.TrimSpace(s[idx:]) + section := s[start:] + end := strings.Index(section, project.FactIndexSectionEndMarker) + if end < 0 { + return "" + } + section = section[:end+len(project.FactIndexSectionEndMarker)] + return strings.TrimSpace(section) } func appendTranscriptSection(sb *strings.Builder, role schema.RoleType, body string) { diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index 70279edc..2cc2c32d 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -15,6 +15,7 @@ import ( "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/project" @@ -56,6 +57,7 @@ func RunDeepAgent( appCfg *config.Config, ma *config.MultiAgentConfig, ag *agent.Agent, + db *database.DB, logger *zap.Logger, conversationID string, projectID string, @@ -210,7 +212,7 @@ func RunDeepAgent( subMax := resolveMaxIterations(appCfg, sub.MaxIterations) - subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, logger) + subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger) if err != nil { return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err) } @@ -281,7 +283,7 @@ func RunDeepAgent( return nil, fmt.Errorf("多代理主模型: %w", err) } - mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger) + mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger) if err != nil { return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err) } @@ -441,6 +443,8 @@ func RunDeepAgent( AppCfg: appCfg, MwCfg: &ma.EinoMiddleware, ConversationID: conversationID, + DB: db, + ProjectID: projectID, Logger: logger, ModelName: appCfg.OpenAI.Model, ExecPreMiddlewares: mainOrchestratorPre,