From 90cd119a836f0ed474792f220ecc8e40ee77613c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sat, 20 Jun 2026 19:35:06 +0800 Subject: [PATCH] Add files via upload --- internal/app/app.go | 5 + internal/app/project_fact_tools.go | 53 +++ internal/attackchain/promote_project.go | 203 +++++++++++ internal/handler/openapi.go | 93 +++++- internal/handler/project.go | 276 +++++++++++++-- internal/project/blackboard.go | 42 ++- internal/project/fact_body_links.go | 256 ++++++++++++++ internal/project/fact_body_links_test.go | 68 ++++ internal/project/fact_edges.go | 389 ++++++++++++++++++++++ internal/project/fact_edges_apply.go | 96 ++++++ internal/project/fact_edges_test.go | 290 ++++++++++++++++ internal/project/fact_index_links.go | 231 +++++++++++++ internal/project/fact_index_links_test.go | 161 +++++++++ internal/project/fact_recording_prompt.go | 95 +----- internal/project/fact_template.go | 15 +- internal/projectprompt/blackboard.go | 132 ++++++++ 16 files changed, 2273 insertions(+), 132 deletions(-) create mode 100644 internal/attackchain/promote_project.go create mode 100644 internal/project/fact_body_links.go create mode 100644 internal/project/fact_body_links_test.go create mode 100644 internal/project/fact_edges.go create mode 100644 internal/project/fact_edges_apply.go create mode 100644 internal/project/fact_edges_test.go create mode 100644 internal/project/fact_index_links.go create mode 100644 internal/project/fact_index_links_test.go create mode 100644 internal/projectprompt/blackboard.go diff --git a/internal/app/app.go b/internal/app/app.go index 9467cb4c..cb875a2f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -1069,6 +1069,11 @@ func setupRoutes( protected.GET("/projects/:id", projectHandler.GetProject) protected.PUT("/projects/:id", projectHandler.UpdateProject) protected.DELETE("/projects/:id", projectHandler.DeleteProject) + protected.GET("/projects/:id/fact-graph", projectHandler.GetFactGraph) + protected.GET("/projects/:id/fact-edges", projectHandler.ListFactEdges) + protected.POST("/projects/:id/fact-edges", projectHandler.CreateFactEdge) + protected.DELETE("/projects/:id/fact-edges/:edgeId", projectHandler.DeleteFactEdge) + protected.POST("/projects/:id/promote-attack-chain/:conversationId", projectHandler.PromoteAttackChain) protected.GET("/projects/:id/facts", projectHandler.ListFacts) protected.POST("/projects/:id/facts", projectHandler.CreateFact) protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact) diff --git a/internal/app/project_fact_tools.go b/internal/app/project_fact_tools.go index ffbff5dc..1365dbde 100644 --- a/internal/app/project_fact_tools.go +++ b/internal/app/project_fact_tools.go @@ -89,6 +89,28 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi "type": "string", "description": "可选:关联的漏洞记录 ID", }, + "links": map[string]interface{}{ + "type": "array", + "description": "可选:关系边(from → 当前 fact)。finding 至少 1 条 {from:target/*, type:discovered_on};finding 上记录 exploit 用 {from:exploit/*, type:exploits}。省略保留已有边;传 [] 清空全部关系边。", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "from": map[string]interface{}{ + "type": "string", + "description": "来源 fact_key:存储为 from → 当前 fact", + }, + "type": map[string]interface{}{ + "type": "string", + "description": "depends_on | leads_to | enables | exploits | discovered_on | contains | part_of | supports", + }, + "confidence": map[string]interface{}{ + "type": "string", + "description": "confirmed | tentative | deprecated", + }, + }, + "required": []string{"from", "type"}, + }, + }, }, "required": []string{"fact_key", "summary"}, }, @@ -124,7 +146,26 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi if err != nil { return textResult("错误: "+err.Error(), true), nil } + if _, hasLinks := args["links"]; hasLinks { + linkInputs, err := project.ParseFactLinkInputs(args["links"]) + if err != nil { + return textResult("错误: "+err.Error(), true), nil + } + convID := agent.ConversationIDFromContext(ctx) + if err := project.PersistFactLinksFromParsed(db, projectID, created.FactKey, convID, linkInputs, true); err != nil { + return textResult("错误: 保存关系边失败: "+err.Error(), true), nil + } + created, _ = db.GetProjectFactByKey(projectID, created.FactKey) + } else if parsed := project.ParseLinksFromBody(created.Body); len(parsed) > 0 { + if err := project.PersistFactIncomingLinks(db, projectID, created.FactKey, parsed, true); err != nil { + return textResult("错误: 从 body 解析边失败: "+err.Error(), true), nil + } + created, _ = db.GetProjectFactByKey(projectID, created.FactKey) + } msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence) + if in, _ := db.ListIncomingProjectFactEdges(projectID, created.FactKey); len(in) > 0 { + msg += "\n关系边: " + project.FormatFactLinksText(in) + } if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" { msg += warn } @@ -164,6 +205,18 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi if f.SourceConversationID != "" { msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID) } + if in, _ := db.ListIncomingProjectFactEdges(projectID, f.FactKey); len(in) > 0 { + msg += "\n关系边(from → 本 fact):\n" + for _, e := range in { + msg += fmt.Sprintf("- %s ← %s (%s)\n", e.EdgeType, e.SourceFactKey, e.Confidence) + } + } + if out, _ := db.ListOutgoingProjectFactEdges(projectID, f.FactKey); len(out) > 0 { + msg += "指向其他事实:\n" + for _, e := range out { + msg += fmt.Sprintf("- %s → %s (%s)\n", e.EdgeType, e.TargetFactKey, e.Confidence) + } + } msg += "\n\n--- body ---\n" + f.Body if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" { msg += warn diff --git a/internal/attackchain/promote_project.go b/internal/attackchain/promote_project.go new file mode 100644 index 00000000..d8a9cd80 --- /dev/null +++ b/internal/attackchain/promote_project.go @@ -0,0 +1,203 @@ +package attackchain + +import ( + "fmt" + "regexp" + "strings" + + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/project" + + "github.com/google/uuid" +) + +var promoteSlugSanitizer = regexp.MustCompile(`[^a-z0-9._/-]+`) + +// PromoteToProjectResult 攻击链沉淀结果。 +type PromoteToProjectResult struct { + FactsCreated int `json:"facts_created"` + FactsUpdated int `json:"facts_updated"` + EdgesCreated int `json:"edges_created"` + FactKeys []string `json:"fact_keys"` + Graph *database.ProjectFactGraph `json:"graph,omitempty"` +} + +// PromoteToProject 将对话攻击链沉淀为项目事实与边。 +func PromoteToProject(db *database.DB, projectID, conversationID string) (*PromoteToProjectResult, error) { + if db == nil { + return nil, fmt.Errorf("database 未初始化") + } + projectID = strings.TrimSpace(projectID) + conversationID = strings.TrimSpace(conversationID) + if projectID == "" || conversationID == "" { + return nil, fmt.Errorf("project_id 与 conversation_id 必填") + } + if _, err := db.GetProject(projectID); err != nil { + return nil, fmt.Errorf("项目不存在") + } + conv, err := db.GetConversation(conversationID) + if err != nil { + return nil, fmt.Errorf("对话不存在") + } + if pid := strings.TrimSpace(conv.ProjectID); pid != "" && pid != projectID { + return nil, fmt.Errorf("对话已绑定其他项目") + } + + nodes, err := db.LoadAttackChainNodes(conversationID) + if err != nil { + return nil, err + } + edges, err := db.LoadAttackChainEdges(conversationID) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, fmt.Errorf("该对话尚无攻击链,请先在对话中生成攻击链") + } + + res := &PromoteToProjectResult{} + nodeToKey := make(map[string]string, len(nodes)) + usedKeys := map[string]int{} + + for _, node := range nodes { + key := allocatePromoteFactKey(node, usedKeys) + nodeToKey[node.ID] = key + category := mapPromoteNodeCategory(node.Type) + existing, getErr := db.GetProjectFactByKey(projectID, key) + f := &database.ProjectFact{ + ProjectID: projectID, + FactKey: key, + Category: category, + Summary: strings.TrimSpace(node.Label), + Body: formatPromotedFactBody(node, conversationID), + Confidence: "tentative", + SourceConversationID: conversationID, + } + if getErr == nil && existing != nil { + f.ID = existing.ID + f.CreatedAt = existing.CreatedAt + if strings.TrimSpace(f.Summary) == "" { + f.Summary = existing.Summary + } + if _, err := db.UpsertProjectFact(f); err != nil { + return nil, err + } + res.FactsUpdated++ + } else { + if _, err := db.UpsertProjectFact(f); err != nil { + return nil, err + } + res.FactsCreated++ + } + res.FactKeys = append(res.FactKeys, key) + } + + for _, edge := range edges { + srcKey, ok1 := nodeToKey[edge.Source] + tgtKey, ok2 := nodeToKey[edge.Target] + if !ok1 || !ok2 || srcKey == tgtKey { + continue + } + edgeType := mapPromoteEdgeType(edge.Type) + incoming, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey) + merged := project.MergeLinkFromInputsUnique(promoteFromEdgeInputsFromDB(incoming), []database.ProjectFactEdgeFromInput{{From: srcKey, Type: edgeType}}) + if err := db.ReplaceIncomingProjectFactEdges(projectID, tgtKey, merged); err != nil { + return nil, err + } + res.EdgesCreated++ + if fact, err := db.GetProjectFactByKey(projectID, tgtKey); err == nil { + in, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey) + fact.Body = project.SyncBodyLinksSection(fact.Body, in) + _, _ = db.UpsertProjectFact(fact) + } + } + + graph, _ := project.BuildProjectFactGraph(db, projectID, "full", true) + res.Graph = graph + return res, nil +} + +func promoteFromEdgeInputsFromDB(edges []*database.ProjectFactEdge) []database.ProjectFactEdgeFromInput { + out := make([]database.ProjectFactEdgeFromInput, 0, len(edges)) + for _, e := range edges { + out = append(out, database.ProjectFactEdgeFromInput{From: e.SourceFactKey, Type: e.EdgeType, Confidence: e.Confidence}) + } + return out +} + +func mapPromoteNodeCategory(nodeType string) string { + switch strings.ToLower(strings.TrimSpace(nodeType)) { + case "target": + return project.FactCategoryTarget + case "vulnerability": + return project.FactCategoryFinding + case "action": + return project.FactCategoryChain + default: + return project.FactCategoryNote + } +} + +func mapPromoteEdgeType(t string) string { + switch strings.ToLower(strings.TrimSpace(t)) { + case "discovers", "discovered_on", "targets": + return "discovered_on" + case "exploits": + return "exploits" + case "enables": + return "enables" + case "depends_on": + return "depends_on" + default: + return "leads_to" + } +} + +func allocatePromoteFactKey(node Node, used map[string]int) string { + prefix := "chain/" + switch strings.ToLower(strings.TrimSpace(node.Type)) { + case "target": + prefix = "target/" + case "vulnerability": + prefix = "finding/" + case "action": + prefix = "chain/" + } + base := promoteSlugify(node.Label) + if base == "" { + base = promoteSlugify(node.ID) + } + if base == "" { + base = uuid.New().String()[:8] + } + key := prefix + base + if n, ok := used[key]; ok { + n++ + used[key] = n + key = fmt.Sprintf("%s-%d", key, n) + } else { + used[key] = 1 + } + return key +} + +func promoteSlugify(s string) string { + s = strings.ToLower(strings.TrimSpace(s)) + s = strings.NewReplacer(" ", "-", "—", "-", "–", "-", "/", "-").Replace(s) + s = promoteSlugSanitizer.ReplaceAllString(s, "-") + s = strings.Trim(s, "-") + if len(s) > 64 { + s = s[:64] + } + return s +} + +func formatPromotedFactBody(node Node, conversationID string) string { + var b strings.Builder + b.WriteString("## 来源\n") + b.WriteString(fmt.Sprintf("- 对话攻击链沉淀\n- source_conversation_id: %s\n- node_id: %s\n- node_type: %s\n\n", conversationID, node.ID, node.Type)) + b.WriteString("## 摘要\n") + b.WriteString(strings.TrimSpace(node.Label)) + b.WriteString("\n\n## 关联\n- 结构化关系边(自动同步):\n (见项目攻击路径图)\n") + return b.String() +} diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index da07001a..70979da0 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -2464,17 +2464,108 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "parameters": []map[string]interface{}{ {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, {"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}}, + {"name": "include_links", "in": "query", "schema": map[string]interface{}{"type": "boolean"}}, + {"name": "include_link_counts", "in": "query", "schema": map[string]interface{}{"type": "boolean"}}, }, - "responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}}, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条(可含 link_counts / outgoing_links)"}}, }, "post": map[string]interface{}{ "tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST", "parameters": []map[string]interface{}{ {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "fact_key": map[string]interface{}{"type": "string"}, + "summary": map[string]interface{}{"type": "string"}, + "links": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "to": map[string]interface{}{"type": "string"}, + "type": map[string]interface{}{"type": "string"}, + }, + }, + }, + "links_text": map[string]interface{}{"type": "string", "description": "type: fact_key 每行一条"}, + }, + }, + }, + }, + }, "responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}}, }, }, + "/api/projects/{id}/fact-graph": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "获取项目事实攻击路径图", "operationId": "getProjectFactGraph", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "view", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"path", "full"}, "default": "path"}}, + {"name": "exclude_deprecated", "in": "query", "schema": map[string]interface{}{"type": "boolean", "default": true}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "nodes + edges"}}, + }, + }, + "/api/projects/{id}/fact-edges": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "列出项目全部事实边", "operationId": "listProjectFactEdges", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "边列表"}}, + }, + "post": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "添加事实边", "operationId": "createProjectFactEdge", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"source_fact_key", "target_fact_key", "edge_type"}, + "properties": map[string]interface{}{ + "source_fact_key": map[string]interface{}{"type": "string"}, + "target_fact_key": map[string]interface{}{"type": "string"}, + "edge_type": map[string]interface{}{"type": "string"}, + "confidence": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "边已创建"}}, + }, + }, + "/api/projects/{id}/fact-edges/{edgeId}": map[string]interface{}{ + "delete": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "删除事实边", "operationId": "deleteProjectFactEdge", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "edgeId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}}, + }, + }, + "/api/projects/{id}/promote-attack-chain/{conversationId}": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "将对话攻击链沉淀到项目事实图", "operationId": "promoteAttackChainToProject", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "conversationId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "沉淀结果(facts/edges/graph)"}}, + }, + }, "/api/vulnerabilities": map[string]interface{}{ "get": map[string]interface{}{ "tags": []string{"漏洞管理"}, diff --git a/internal/handler/project.go b/internal/handler/project.go index b585c57e..fb393562 100644 --- a/internal/handler/project.go +++ b/internal/handler/project.go @@ -1,10 +1,12 @@ package handler import ( + "fmt" "net/http" "strconv" "strings" + "cyberstrike-ai/internal/attackchain" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/project" @@ -223,26 +225,102 @@ func (h *ProjectHandler) DeleteProject(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"success": true}) } +type factLinkRequest struct { + From string `json:"from"` + Type string `json:"type"` + Confidence string `json:"confidence,omitempty"` +} + type upsertFactRequest struct { - FactKey string `json:"fact_key" binding:"required"` - Category string `json:"category"` - Summary string `json:"summary" binding:"required"` - Body string `json:"body"` - Confidence string `json:"confidence"` - Pinned bool `json:"pinned"` - RelatedVulnerabilityID string `json:"related_vulnerability_id"` + FactKey string `json:"fact_key" binding:"required"` + Category string `json:"category"` + Summary string `json:"summary" binding:"required"` + Body string `json:"body"` + Confidence string `json:"confidence"` + Pinned bool `json:"pinned"` + RelatedVulnerabilityID string `json:"related_vulnerability_id"` + Links []factLinkRequest `json:"links"` + LinksText *string `json:"links_text"` } // updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。 type updateFactRequest struct { - FactKey *string `json:"fact_key"` - Category *string `json:"category"` - Summary *string `json:"summary"` - Body *string `json:"body"` - Confidence *string `json:"confidence"` - Pinned *bool `json:"pinned"` - RelatedVulnerabilityID *string `json:"related_vulnerability_id"` - ClearBody bool `json:"clear_body"` + FactKey *string `json:"fact_key"` + Category *string `json:"category"` + Summary *string `json:"summary"` + Body *string `json:"body"` + Confidence *string `json:"confidence"` + Pinned *bool `json:"pinned"` + RelatedVulnerabilityID *string `json:"related_vulnerability_id"` + ClearBody bool `json:"clear_body"` + Links *[]factLinkRequest `json:"links"` + LinksText *string `json:"links_text"` +} + +func factLinksFromRequest(links []factLinkRequest, linksText *string) (*project.ParsedFactLinks, error) { + if len(links) > 0 { + parsed := &project.ParsedFactLinks{} + for i, l := range links { + from := strings.TrimSpace(l.From) + edgeType := strings.TrimSpace(l.Type) + if from == "" { + return nil, fmt.Errorf("links[%d] 须含 from", i) + } + if edgeType == "" { + return nil, fmt.Errorf("links[%d] 须含 type", i) + } + parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{ + From: from, Type: edgeType, Confidence: strings.TrimSpace(l.Confidence), + }) + } + return parsed, nil + } + if linksText != nil { + in, err := project.ParseFactLinksText(*linksText) + if err != nil { + return nil, err + } + return &project.ParsedFactLinks{Incoming: in}, nil + } + return &project.ParsedFactLinks{Incoming: []database.ProjectFactEdgeFromInput{}}, nil +} + +type factWithLinksResponse struct { + *database.ProjectFact + OutgoingLinks []*database.ProjectFactEdge `json:"outgoing_links,omitempty"` + IncomingLinks []*database.ProjectFactEdge `json:"incoming_links,omitempty"` + LinkCounts *project.LinkCounts `json:"link_counts,omitempty"` +} + +func (h *ProjectHandler) applyFactLinksAfterUpsert(projectID string, fact *database.ProjectFact, links []factLinkRequest, linksText *string, explicitLinks, parseBody bool) error { + if explicitLinks { + parsed, err := factLinksFromRequest(links, linksText) + if err != nil { + return err + } + return project.PersistFactLinksFromParsed(h.db, projectID, fact.FactKey, fact.SourceConversationID, parsed, true) + } + if parseBody { + inputs := project.ParseLinksFromBody(fact.Body) + if inputs == nil { + return nil + } + return project.PersistFactIncomingLinks(h.db, projectID, fact.FactKey, inputs, true) + } + return nil +} + +func (h *ProjectHandler) factResponseWithLinks(projectID string, f *database.ProjectFact, includeLinks bool) interface{} { + if !includeLinks || f == nil { + return f + } + out, _ := h.db.ListOutgoingProjectFactEdges(projectID, f.FactKey) + in, _ := h.db.ListIncomingProjectFactEdges(projectID, f.FactKey) + return &factWithLinksResponse{ + ProjectFact: f, + OutgoingLinks: out, + IncomingLinks: in, + } } // ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情) @@ -254,7 +332,8 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } - c.JSON(http.StatusOK, f) + includeLinks := c.Query("include_links") == "1" || c.Query("include_links") == "true" + c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, f, includeLinks)) return } limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) @@ -285,7 +364,52 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) { } list = filtered } - c.JSON(http.StatusOK, list) + includeLinkCounts := c.Query("include_link_counts") == "1" || c.Query("include_link_counts") == "true" + if !includeLinkCounts { + c.JSON(http.StatusOK, list) + return + } + counts, err := project.LoadProjectFactLinkCounts(h.db, projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + out := make([]factWithLinksResponse, 0, len(list)) + for _, f := range list { + item := factWithLinksResponse{ProjectFact: f} + if c, ok := counts[f.FactKey]; ok { + cc := c + item.LinkCounts = &cc + } + out = append(out, item) + } + c.JSON(http.StatusOK, out) +} + +// GetFactGraph GET /api/projects/:id/fact-graph?view=path|full +func (h *ProjectHandler) GetFactGraph(c *gin.Context) { + projectID := c.Param("id") + if _, err := h.db.GetProject(projectID); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"}) + return + } + view := c.DefaultQuery("view", "path") + excludeDeprecated := true + if v := c.Query("exclude_deprecated"); v == "0" || v == "false" { + excludeDeprecated = false + } + graph, err := project.BuildProjectFactGraph(h.db, projectID, view, excludeDeprecated) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if graph.Nodes == nil { + graph.Nodes = []database.ProjectFactGraphNode{} + } + if graph.Edges == nil { + graph.Edges = []database.ProjectFactGraphEdge{} + } + c.JSON(http.StatusOK, graph) } // CreateFact POST /api/projects/:id/facts @@ -295,8 +419,9 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + projectID := c.Param("id") f := &database.ProjectFact{ - ProjectID: c.Param("id"), + ProjectID: projectID, FactKey: req.FactKey, Category: req.Category, Summary: req.Summary, @@ -310,16 +435,24 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - c.JSON(http.StatusOK, created) + explicitLinks := req.Links != nil || req.LinksText != nil + if err := h.applyFactLinksAfterUpsert(projectID, created, req.Links, req.LinksText, explicitLinks, !explicitLinks); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + created, _ = h.db.GetProjectFactByKey(projectID, created.FactKey) + c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, created, true)) } // UpdateFact PUT /api/projects/:id/facts/:factId func (h *ProjectHandler) UpdateFact(c *gin.Context) { + projectID := c.Param("id") existing, err := h.db.GetProjectFact(c.Param("factId")) - if err != nil || existing.ProjectID != c.Param("id") { + if err != nil || existing.ProjectID != projectID { c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"}) return } + oldFactKey := existing.FactKey var req updateFactRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -355,7 +488,29 @@ func (h *ProjectHandler) UpdateFact(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - c.JSON(http.StatusOK, updated) + if oldFactKey != updated.FactKey { + if err := h.db.RenameProjectFactKeyEdges(projectID, oldFactKey, updated.FactKey); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + if req.Links != nil || req.LinksText != nil { + var links []factLinkRequest + if req.Links != nil { + links = *req.Links + } + if err := h.applyFactLinksAfterUpsert(projectID, updated, links, req.LinksText, true, false); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + } else if req.ClearBody || req.Body != nil { + if err := h.applyFactLinksAfterUpsert(projectID, updated, nil, nil, false, true); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + } + updated, _ = h.db.GetProjectFactByKey(projectID, updated.FactKey) + c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, updated, true)) } // DeleteFact DELETE /api/projects/:id/facts/:factId @@ -408,3 +563,82 @@ func (h *ProjectHandler) RestoreFact(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{"success": true}) } + +type createFactEdgeRequest struct { + SourceFactKey string `json:"source_fact_key" binding:"required"` + TargetFactKey string `json:"target_fact_key" binding:"required"` + EdgeType string `json:"edge_type" binding:"required"` + Confidence string `json:"confidence"` +} + +// ListFactEdges GET /api/projects/:id/fact-edges +func (h *ProjectHandler) ListFactEdges(c *gin.Context) { + projectID := c.Param("id") + edges, err := h.db.ListProjectFactEdgesByProject(projectID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if edges == nil { + edges = []*database.ProjectFactEdge{} + } + c.JSON(http.StatusOK, edges) +} + +// CreateFactEdge POST /api/projects/:id/fact-edges +func (h *ProjectHandler) CreateFactEdge(c *gin.Context) { + projectID := c.Param("id") + var req createFactEdgeRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + edge, err := h.db.AddProjectFactEdge(projectID, database.ProjectFactEdgeInput{ + To: req.TargetFactKey, + Type: req.EdgeType, + Confidence: req.Confidence, + }, req.SourceFactKey, "") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if f, err := h.db.GetProjectFactByKey(projectID, req.TargetFactKey); err == nil { + in, _ := h.db.ListIncomingProjectFactEdges(projectID, req.TargetFactKey) + f.Body = project.SyncBodyLinksSection(f.Body, in) + _, _ = h.db.UpsertProjectFact(f) + } + c.JSON(http.StatusOK, edge) +} + +// DeleteFactEdge DELETE /api/projects/:id/fact-edges/:edgeId +func (h *ProjectHandler) DeleteFactEdge(c *gin.Context) { + projectID := c.Param("id") + edgeID := c.Param("edgeId") + edge, err := h.db.GetProjectFactEdge(edgeID) + if err != nil || edge.ProjectID != projectID { + c.JSON(http.StatusNotFound, gin.H{"error": "边不存在"}) + return + } + if err := h.db.DeleteProjectFactEdge(edgeID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if f, err := h.db.GetProjectFactByKey(projectID, edge.TargetFactKey); err == nil { + in, _ := h.db.ListIncomingProjectFactEdges(projectID, edge.TargetFactKey) + f.Body = project.SyncBodyLinksSection(f.Body, in) + _, _ = h.db.UpsertProjectFact(f) + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +// PromoteAttackChain POST /api/projects/:id/promote-attack-chain/:conversationId +func (h *ProjectHandler) PromoteAttackChain(c *gin.Context) { + projectID := c.Param("id") + conversationID := c.Param("conversationId") + result, err := attackchain.PromoteToProject(h.db, projectID, conversationID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, result) +} diff --git a/internal/project/blackboard.go b/internal/project/blackboard.go index 255f39a2..d1e2aec9 100644 --- a/internal/project/blackboard.go +++ b/internal/project/blackboard.go @@ -2,7 +2,6 @@ package project import ( "fmt" - "sort" "strings" "cyberstrike-ai/internal/config" @@ -24,11 +23,11 @@ func AppendSystemPromptBlock(base, block string) string { const ( factIndexFooterGetDetail = "需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。" - factIndexFooterWriteHint = "写入事实时:summary 写「什么+在哪+如何验证」;body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。" + factIndexFooterWriteHint = "写入事实 links 时用 from(来源 fact_key → 当前 fact),如 finding 上 {from:target/*, type:discovered_on};body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。" factIndexFooterEmpty = "需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。" ) -// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。 +// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(key + summary + 关系边 + 攻击路径,不含 body)。 func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) { if db == nil || !cfg.Enabled { return "", nil @@ -47,27 +46,38 @@ func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectCo if err != nil { return "", err } + allEdges, _ := db.ListProjectFactEdgesByProject(projectID) + _, incomingByTarget := indexEdgeGroupMaps(allEdges) + if len(facts) == 0 { return wrapFactIndexBlock(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n%s", proj.Name, proj.ID, factIndexFooterEmpty)), nil } - sort.SliceStable(facts, func(i, j int) bool { - if facts[i].Pinned != facts[j].Pinned { - return facts[i].Pinned - } - return facts[i].UpdatedAt.After(facts[j].UpdatedAt) - }) + sortFactsForIndex(facts) maxRunes := cfg.FactIndexMaxRunesEffective() + pathMaxRunes := cfg.FactIndexPathMaxRunesEffective() + footer := factIndexFooterGetDetail + "\n" + factIndexFooterWriteHint + footerRunes := len([]rune(footer)) + factsBudget := maxRunes - pathMaxRunes - footerRunes + if factsBudget < 800 { + factsBudget = maxRunes - footerRunes + pathMaxRunes = 0 + } + + indexedKeys := make(map[string]struct{}, len(facts)) var b strings.Builder b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID)) used := len([]rune(b.String())) omitted := 0 for _, f := range facts { - line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence) + indexedKeys[f.FactKey] = struct{}{} + line := fmt.Sprintf("- [%s] %s — %s (%s)", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence) + line += FormatFactIndexLinksHint(f.FactKey, incomingByTarget[f.FactKey]) + line += "\n" lineRunes := len([]rune(line)) - if used+lineRunes > maxRunes { + if used+lineRunes > factsBudget { omitted++ continue } @@ -78,8 +88,12 @@ func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectCo if omitted > 0 { b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted)) } - b.WriteString(factIndexFooterGetDetail) - b.WriteByte('\n') - b.WriteString(factIndexFooterWriteHint) + + if pathSection := BuildFactPathOverviewSection(allEdges, indexedKeys, pathMaxRunes); pathSection != "" { + b.WriteString("\n") + b.WriteString(pathSection) + } + + b.WriteString(footer) return wrapFactIndexBlock(b.String()), nil } diff --git a/internal/project/fact_body_links.go b/internal/project/fact_body_links.go new file mode 100644 index 00000000..0ec69328 --- /dev/null +++ b/internal/project/fact_body_links.go @@ -0,0 +1,256 @@ +package project + +import ( + "fmt" + "regexp" + "strings" + + "cyberstrike-ai/internal/database" +) + +var ( + bodyDepFactLine = regexp.MustCompile(`(?im)^[\s\-*]*依赖事实\s*[::]\s*([a-z0-9][a-z0-9._/-]*)`) + bodyRelFactLine = regexp.MustCompile(`(?im)^[\s\-*]*相关\s*fact_key\s*[::]\s*([a-z0-9][a-z0-9._/-]*)`) + bodyAssocSection = regexp.MustCompile(`(?im)^##\s*关联\s*$`) + bodySyncLinksHead = "结构化关系边(自动同步)" +) + +// ParseLinksFromBody 从 body「关联」段落解析 from 语义的关系边(无显式 links 时的兜底)。 +func ParseLinksFromBody(body string) []database.ProjectFactEdgeFromInput { + body = strings.TrimSpace(body) + if body == "" { + return nil + } + seen := map[string]struct{}{} + var out []database.ProjectFactEdgeFromInput + add := func(key, edgeType string) { + key = strings.TrimSpace(key) + if key == "" { + return + } + if err := database.ValidateFactKey(key); err != nil { + return + } + sig := edgeType + "\x00" + key + if _, ok := seen[sig]; ok { + return + } + seen[sig] = struct{}{} + out = append(out, database.ProjectFactEdgeFromInput{From: key, Type: edgeType}) + } + for _, m := range bodyDepFactLine.FindAllStringSubmatch(body, -1) { + if len(m) > 1 { + add(m[1], "depends_on") + } + } + for _, m := range bodyRelFactLine.FindAllStringSubmatch(body, -1) { + if len(m) > 1 { + add(m[1], "supports") + } + } + // 自动同步块:type: key + syncBlock := extractBodySyncLinksBlock(body) + for _, line := range strings.Split(syncBlock, "\n") { + line = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line), "-")) + if line == "" { + continue + } + edgeType, source, ok := strings.Cut(line, ":") + if !ok { + continue + } + edgeType = strings.TrimSpace(edgeType) + source = strings.TrimSpace(source) + if err := database.ValidateProjectFactEdgeType(edgeType); err != nil { + continue + } + add(source, edgeType) + } + if len(out) == 0 { + return nil + } + return out +} + +func extractBodySyncLinksBlock(body string) string { + lines := strings.Split(body, "\n") + var b strings.Builder + inAssoc := false + inSync := false + for _, line := range lines { + trim := strings.TrimSpace(line) + if bodyAssocSection.MatchString(trim) { + inAssoc = true + inSync = false + continue + } + if inAssoc && strings.HasPrefix(trim, "## ") && !strings.HasPrefix(trim, "## 关联") { + break + } + if inAssoc && strings.Contains(trim, bodySyncLinksHead) { + inSync = true + continue + } + if inSync { + if trim == "" || strings.HasPrefix(trim, "-") || strings.Contains(trim, ":") { + if strings.HasPrefix(trim, "-") || (strings.Contains(trim, ":") && !strings.Contains(trim, "related_vulnerability")) { + b.WriteString(trim) + b.WriteByte('\n') + } + } else if strings.HasPrefix(trim, "##") { + break + } + } + } + return b.String() +} + +// SyncBodyLinksSection 将入边镜像写入 body 的「关联」段(人读用;结构化以 links 为准)。 +func SyncBodyLinksSection(body string, edges []*database.ProjectFactEdge) string { + body = strings.TrimSpace(body) + block := formatBodySyncLinksBlock(edges) + if block == "" { + return body + } + if body == "" { + return "## 关联\n" + block + } + lines := strings.Split(body, "\n") + var out []string + inAssoc := false + replaced := false + for i := 0; i < len(lines); i++ { + trim := strings.TrimSpace(lines[i]) + if bodyAssocSection.MatchString(trim) { + inAssoc = true + out = append(out, lines[i]) + // 跳过旧同步块 + j := i + 1 + for j < len(lines) { + t := strings.TrimSpace(lines[j]) + if strings.HasPrefix(t, "## ") { + break + } + if strings.Contains(t, bodySyncLinksHead) { + for j < len(lines) { + t2 := strings.TrimSpace(lines[j]) + if t2 != "" && !strings.HasPrefix(t2, "-") && !strings.Contains(t2, ":") && !strings.Contains(t2, bodySyncLinksHead) { + if strings.HasPrefix(t2, "##") { + break + } + } + j++ + if j < len(lines) && strings.HasPrefix(strings.TrimSpace(lines[j]), "## ") { + break + } + if j >= len(lines) { + break + } + if j > i+1 && strings.TrimSpace(lines[j-1]) == "" && strings.HasPrefix(strings.TrimSpace(lines[j]), "## ") { + break + } + } + break + } + j++ + } + out = append(out, block) + i = j - 1 + replaced = true + continue + } + out = append(out, lines[i]) + } + if !replaced { + if !inAssoc { + out = append(out, "", "## 关联", block) + } else { + out = append(out, block) + } + } + return strings.TrimSpace(strings.Join(out, "\n")) +} + +func formatBodySyncLinksBlock(edges []*database.ProjectFactEdge) string { + if len(edges) == 0 { + return fmt.Sprintf("- %s:\n (暂无)", bodySyncLinksHead) + } + var b strings.Builder + b.WriteString("- ") + b.WriteString(bodySyncLinksHead) + b.WriteString(":\n") + for _, e := range edges { + b.WriteString(fmt.Sprintf(" - %s: %s\n", e.EdgeType, e.SourceFactKey)) + } + return strings.TrimRight(b.String(), "\n") +} + +// ResolveFactLinksForUpsert 合并显式 links、links_text 与 body 解析结果。 +func ResolveFactLinksForUpsert(explicit []database.ProjectFactEdgeFromInput, linksText *string, body string, explicitSet bool) ([]database.ProjectFactEdgeFromInput, bool, error) { + if explicitSet { + if len(explicit) > 0 { + return explicit, true, nil + } + if linksText != nil { + parsed, err := ParseFactLinksText(*linksText) + if err != nil { + return nil, true, err + } + if parsed == nil { + return []database.ProjectFactEdgeFromInput{}, true, nil + } + return parsed, true, nil + } + return []database.ProjectFactEdgeFromInput{}, true, nil + } + if parsed := ParseLinksFromBody(body); len(parsed) > 0 { + return parsed, true, nil + } + return nil, false, nil +} + +// MergeLinkFromInputsUnique 合并多组 from 入边输入并去重。 +func MergeLinkFromInputsUnique(groups ...[]database.ProjectFactEdgeFromInput) []database.ProjectFactEdgeFromInput { + seen := map[string]struct{}{} + var out []database.ProjectFactEdgeFromInput + for _, g := range groups { + for _, in := range g { + sig := in.Type + "\x00" + in.From + if _, ok := seen[sig]; ok { + continue + } + if err := database.ValidateProjectFactEdgeType(in.Type); err != nil { + continue + } + if err := database.ValidateFactKey(in.From); err != nil { + continue + } + seen[sig] = struct{}{} + out = append(out, in) + } + } + return out +} + +// MergeLinkInputsUnique 合并多组 link 输入并去重(内部出边写入用)。 +func MergeLinkInputsUnique(groups ...[]database.ProjectFactEdgeInput) []database.ProjectFactEdgeInput { + seen := map[string]struct{}{} + var out []database.ProjectFactEdgeInput + for _, g := range groups { + for _, in := range g { + sig := in.Type + "\x00" + in.To + if _, ok := seen[sig]; ok { + continue + } + if err := database.ValidateProjectFactEdgeType(in.Type); err != nil { + continue + } + if err := database.ValidateFactKey(in.To); err != nil { + continue + } + seen[sig] = struct{}{} + out = append(out, in) + } + } + return out +} diff --git a/internal/project/fact_body_links_test.go b/internal/project/fact_body_links_test.go new file mode 100644 index 00000000..1b5daa95 --- /dev/null +++ b/internal/project/fact_body_links_test.go @@ -0,0 +1,68 @@ +package project + +import ( + "path/filepath" + "strings" + "testing" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +func TestParseLinksFromBodyDependsOn(t *testing.T) { + t.Parallel() + body := "## 关联\n- 依赖事实: target/api\n- 相关 fact_key: auth/session" + links := ParseLinksFromBody(body) + if len(links) != 2 { + t.Fatalf("want 2 links, got %d", len(links)) + } +} + +func TestSyncBodyLinksSection(t *testing.T) { + t.Parallel() + body := "## 结论\nx\n\n## 关联\n- 依赖事实: old/key" + edges := []*database.ProjectFactEdge{{EdgeType: "discovered_on", SourceFactKey: "target/a"}} + out := SyncBodyLinksSection(body, edges) + if !strings.Contains(out, "discovered_on: target/a") { + t.Fatalf("missing synced edge: %q", out) + } +} + +func TestFactGraphIntegration(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + db, err := database.NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + p, err := db.CreateProject(&database.Project{Name: "g"}) + if err != nil { + t.Fatal(err) + } + for _, spec := range []struct{ key, cat, summary string }{ + {"target/root", "target", "root"}, + {"finding/x", "finding", "finding x"}, + } { + _, err := db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.summary, Confidence: "confirmed", + }) + if err != nil { + t.Fatal(err) + } + } + if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/x", []database.ProjectFactEdgeFromInput{ + {From: "target/root", Type: "discovered_on"}, + }); err != nil { + t.Fatal(err) + } + graph, err := BuildProjectFactGraph(db, p.ID, "path", true) + if err != nil { + t.Fatal(err) + } + if len(graph.Nodes) < 2 || len(graph.Edges) < 1 { + t.Fatalf("expected graph nodes/edges, got %d/%d", len(graph.Nodes), len(graph.Edges)) + } +} diff --git a/internal/project/fact_edges.go b/internal/project/fact_edges.go new file mode 100644 index 00000000..2d11e8fb --- /dev/null +++ b/internal/project/fact_edges.go @@ -0,0 +1,389 @@ +package project + +import ( + "fmt" + "strings" + + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/projectprompt" +) + +// PathGraphCategories 攻击路径视图包含的事实分类。 +var PathGraphCategories = map[string]struct{}{ + FactCategoryTarget: {}, + FactCategoryFinding: {}, + FactCategoryChain: {}, + FactCategoryExploit: {}, + FactCategoryPOC: {}, + "vuln": {}, +} + +// GraphNodeType 将 fact category 映射为图节点类型(供前端样式与 ELK 分层)。 +func GraphNodeType(category, factKey string) string { + key := strings.ToLower(strings.TrimSpace(factKey)) + switch { + case strings.HasPrefix(key, "target/"): + return "target" + case strings.HasPrefix(key, "exploit/"), strings.HasPrefix(key, "evidence/"): + return "exploit" + case strings.HasPrefix(key, "poc/"): + return "poc" + case strings.HasPrefix(key, "chain/"): + return "chain" + case strings.HasPrefix(key, "finding/"): + return "finding" + case strings.HasPrefix(key, "auth/"): + return "auth" + case strings.HasPrefix(key, "infra/"), strings.HasPrefix(key, "business/"): + return "infra" + case strings.HasPrefix(key, "vuln:"): + return "vulnerability" + } + c := strings.ToLower(strings.TrimSpace(category)) + switch c { + case FactCategoryTarget: + return "target" + case FactCategoryExploit: + return "exploit" + case FactCategoryPOC: + return "poc" + case FactCategoryChain: + return "chain" + case FactCategoryFinding, "vuln": + return "finding" + case "auth": + return "auth" + case "infra", "business": + return "infra" + default: + return "note" + } +} + +func truncateGraphLabel(summary string, maxRunes int) string { + summary = strings.TrimSpace(summary) + if summary == "" { + return "—" + } + r := []rune(summary) + if len(r) <= maxRunes { + return summary + } + return string(r[:maxRunes]) + "…" +} + +// BuildProjectFactGraph 构建项目事实图(nodes + edges)。 +func BuildProjectFactGraph(db *database.DB, projectID string, view string, excludeDeprecated bool) (*database.ProjectFactGraph, error) { + if db == nil { + return nil, fmt.Errorf("database 未初始化") + } + projectID = strings.TrimSpace(projectID) + if projectID == "" { + return nil, fmt.Errorf("project_id 不能为空") + } + + view = strings.TrimSpace(strings.ToLower(view)) + if view == "" { + view = "path" + } + + filter := database.ProjectFactListFilter{} + if excludeDeprecated { + filter.ExcludeDeprecated = true + } + facts, err := db.ListProjectFacts(projectID, filter, 1000, 0) + if err != nil { + return nil, err + } + + edges, err := db.ListProjectFactEdgesByProject(projectID) + if err != nil { + return nil, err + } + if excludeDeprecated { + edges = filterDeprecatedEdges(edges) + } + + factByKey := make(map[string]*database.ProjectFact, len(facts)) + for _, f := range facts { + factByKey[f.FactKey] = f + } + + pathMode := view == "path" + nodeKeys := make(map[string]struct{}) + + if pathMode { + for _, f := range facts { + if isPathGraphFact(f.Category, f.FactKey) { + nodeKeys[f.FactKey] = struct{}{} + } + } + // 路径视图中保留作为依赖目标的 auth/infra 节点 + for _, e := range edges { + if _, ok := nodeKeys[e.SourceFactKey]; !ok { + continue + } + if f, ok := factByKey[e.TargetFactKey]; ok && isDependencyGraphFact(f.Category, f.FactKey) { + nodeKeys[e.TargetFactKey] = struct{}{} + } + } + } else { + for _, f := range facts { + nodeKeys[f.FactKey] = struct{}{} + } + } + + // 边上引用的 endpoint 纳入节点集 + for _, e := range edges { + if pathMode { + if _, ok := nodeKeys[e.SourceFactKey]; !ok { + continue + } + if _, ok := nodeKeys[e.TargetFactKey]; ok { + // already included + } else if f, ok := factByKey[e.TargetFactKey]; !ok { + nodeKeys[e.TargetFactKey] = struct{}{} // 占位节点 + } else if isPathGraphFact(f.Category, f.FactKey) || isDependencyGraphFact(f.Category, f.FactKey) { + nodeKeys[e.TargetFactKey] = struct{}{} + } else { + continue + } + } else { + nodeKeys[e.SourceFactKey] = struct{}{} + nodeKeys[e.TargetFactKey] = struct{}{} + } + } + + nodes := make([]database.ProjectFactGraphNode, 0, len(nodeKeys)) + for key := range nodeKeys { + if f, ok := factByKey[key]; ok { + nodes = append(nodes, database.ProjectFactGraphNode{ + ID: f.FactKey, + FactKey: f.FactKey, + Category: f.Category, + Label: truncateGraphLabel(f.Summary, 48), + Summary: strings.TrimSpace(f.Summary), + Confidence: f.Confidence, + Type: GraphNodeType(f.Category, f.FactKey), + Pinned: f.Pinned, + }) + continue + } + nodes = append(nodes, database.ProjectFactGraphNode{ + ID: key, + FactKey: key, + Category: "missing", + Label: key, + Confidence: "tentative", + Type: "missing", + Pinned: false, + }) + } + + graphEdges := make([]database.ProjectFactGraphEdge, 0, len(edges)) + for _, e := range edges { + if pathMode { + if _, ok := nodeKeys[e.SourceFactKey]; !ok { + continue + } + if _, ok := nodeKeys[e.TargetFactKey]; !ok { + continue + } + } else { + if _, ok := nodeKeys[e.SourceFactKey]; !ok { + continue + } + if _, ok := nodeKeys[e.TargetFactKey]; !ok { + continue + } + } + graphEdges = append(graphEdges, database.ProjectFactGraphEdge{ + ID: e.ID, + Source: e.SourceFactKey, + Target: e.TargetFactKey, + Type: e.EdgeType, + Confidence: e.Confidence, + }) + } + + // related_vulnerability_id 合成边(source=fact → target=vuln:) + for _, f := range facts { + if _, ok := nodeKeys[f.FactKey]; !ok { + continue + } + vid := strings.TrimSpace(f.RelatedVulnerabilityID) + if vid == "" { + continue + } + vulnNodeID := "vuln:" + vid + if _, exists := nodeKeys[vulnNodeID]; !exists { + nodeKeys[vulnNodeID] = struct{}{} + label := "漏洞" + if len(vid) >= 8 { + label += " " + vid[:8] + "…" + } else { + label += " " + vid + } + nodes = append(nodes, database.ProjectFactGraphNode{ + ID: vulnNodeID, + FactKey: vulnNodeID, + Category: "vuln", + Label: label, + Confidence: f.Confidence, + Type: "vulnerability", + Pinned: false, + }) + } + graphEdges = append(graphEdges, database.ProjectFactGraphEdge{ + ID: "vuln-link:" + f.FactKey + ":" + vid, + Source: f.FactKey, + Target: vulnNodeID, + Type: "links_vuln", + Confidence: f.Confidence, + }) + } + + return &database.ProjectFactGraph{Nodes: nodes, Edges: graphEdges}, nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func isPathGraphFact(category, factKey string) bool { + c := strings.ToLower(strings.TrimSpace(category)) + if _, ok := PathGraphCategories[c]; ok { + return true + } + key := strings.ToLower(strings.TrimSpace(factKey)) + for _, p := range []string{"target/", "finding/", "chain/", "exploit/", "poc/", "evidence/"} { + if strings.HasPrefix(key, p) { + return true + } + } + return false +} + +func isDependencyGraphFact(category, factKey string) bool { + c := strings.ToLower(strings.TrimSpace(category)) + if c == "auth" || c == "infra" || c == "business" { + return true + } + key := strings.ToLower(strings.TrimSpace(factKey)) + return strings.HasPrefix(key, "auth/") || strings.HasPrefix(key, "infra/") || strings.HasPrefix(key, "business/") +} + +func filterDeprecatedEdges(edges []*database.ProjectFactEdge) []*database.ProjectFactEdge { + out := make([]*database.ProjectFactEdge, 0, len(edges)) + for _, e := range edges { + if strings.EqualFold(strings.TrimSpace(e.Confidence), "deprecated") { + continue + } + out = append(out, e) + } + return out +} + +// ParsedFactLinks 解析 links 参数(from → 当前 fact)。 +type ParsedFactLinks struct { + Incoming []database.ProjectFactEdgeFromInput +} + +// ParseFactLinkInputs 从 MCP links 参数解析;空数组表示清空全部入边。 +func ParseFactLinkInputs(raw interface{}) (*ParsedFactLinks, error) { + if raw == nil { + return nil, nil + } + items, ok := raw.([]interface{}) + if !ok { + return nil, fmt.Errorf("links 须为数组") + } + if len(items) == 0 { + return &ParsedFactLinks{ + Incoming: []database.ProjectFactEdgeFromInput{}, + }, nil + } + parsed := &ParsedFactLinks{} + for i, item := range items { + m, ok := item.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("links[%d] 格式无效", i) + } + from, _ := m["from"].(string) + edgeType, _ := m["type"].(string) + from = strings.TrimSpace(from) + edgeType = strings.TrimSpace(edgeType) + if from == "" { + return nil, fmt.Errorf("links[%d] 须含 from", i) + } + if edgeType == "" { + return nil, fmt.Errorf("links[%d] 须含 type", i) + } + conf, _ := m["confidence"].(string) + parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{ + From: from, Type: edgeType, Confidence: strings.TrimSpace(conf), + }) + } + return parsed, nil +} + +// ParseFactLinksText 解析 UI 文本:`type: source_fact_key` 每行一条(from 语义)。 +func ParseFactLinksText(text string) ([]database.ProjectFactEdgeFromInput, error) { + return ParseFactIncomingLinksText(text) +} + +// FormatFactLinksText 将入边格式化为 UI 文本。 +func FormatFactLinksText(edges []*database.ProjectFactEdge) string { + return FormatFactIncomingLinksText(edges) +} + +// ParseFactIncomingLinksText 解析 UI 入边文本:`type: source_fact_key` 每行一条。 +func ParseFactIncomingLinksText(text string) ([]database.ProjectFactEdgeFromInput, error) { + text = strings.TrimSpace(text) + if text == "" { + return nil, nil + } + var out []database.ProjectFactEdgeFromInput + for i, line := range strings.Split(text, "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + edgeType, source, ok := strings.Cut(line, ":") + if !ok { + return nil, fmt.Errorf("第 %d 行格式无效,应为 type: fact_key", i+1) + } + edgeType = strings.TrimSpace(edgeType) + source = strings.TrimSpace(source) + if edgeType == "" || source == "" { + return nil, fmt.Errorf("第 %d 行 type 或 fact_key 为空", i+1) + } + out = append(out, database.ProjectFactEdgeFromInput{From: source, Type: edgeType}) + } + return out, nil +} + +// FormatFactIncomingLinksText 将入边格式化为 UI 文本。 +func FormatFactIncomingLinksText(edges []*database.ProjectFactEdge) string { + if len(edges) == 0 { + return "" + } + var b strings.Builder + for i, e := range edges { + if i > 0 { + b.WriteByte('\n') + } + b.WriteString(e.EdgeType) + b.WriteString(": ") + b.WriteString(e.SourceFactKey) + } + return b.String() +} + +// FactEdgeRecordingGuidance 写入边时的 Agent 规范。 +func FactEdgeRecordingGuidance() string { + return projectprompt.FactEdgeRecordingGuidance() +} diff --git a/internal/project/fact_edges_apply.go b/internal/project/fact_edges_apply.go new file mode 100644 index 00000000..870861e4 --- /dev/null +++ b/internal/project/fact_edges_apply.go @@ -0,0 +1,96 @@ +package project + +import ( + "cyberstrike-ai/internal/database" +) + +// ApplyFactOutgoingLinks 替换某事实的出边(links 为 nil 时不修改)。 +func ApplyFactOutgoingLinks(db *database.DB, projectID, sourceFactKey, sourceConversationID string, links []database.ProjectFactEdgeInput) error { + if links == nil { + return nil + } + return db.ReplaceOutgoingProjectFactEdges(projectID, sourceFactKey, sourceConversationID, links) +} + +// ResolveFactLinkInputs 合并 links 数组与 links_text 文本(数组优先)。 +func ResolveFactLinkInputs(links []database.ProjectFactEdgeFromInput, linksText string) ([]database.ProjectFactEdgeFromInput, error) { + if len(links) > 0 { + return links, nil + } + return ParseFactLinksText(linksText) +} + +// ApplyFactIncomingLinks 替换某事实的入边(links 为 nil 时不修改)。 +func ApplyFactIncomingLinks(db *database.DB, projectID, targetFactKey string, links []database.ProjectFactEdgeFromInput) error { + if links == nil { + return nil + } + return db.ReplaceIncomingProjectFactEdges(projectID, targetFactKey, links) +} + +// PersistFactIncomingLinks 写入入边并可选同步当前事实 body「关联」段。 +func PersistFactIncomingLinks(db *database.DB, projectID, targetFactKey string, links []database.ProjectFactEdgeFromInput, syncBody bool) error { + if links == nil { + return nil + } + if err := ApplyFactIncomingLinks(db, projectID, targetFactKey, links); err != nil { + return err + } + if !syncBody { + return nil + } + f, err := db.GetProjectFactByKey(projectID, targetFactKey) + if err != nil { + return nil + } + in, err := db.ListIncomingProjectFactEdges(projectID, targetFactKey) + if err != nil { + return err + } + f.Body = SyncBodyLinksSection(f.Body, in) + _, err = db.UpsertProjectFact(f) + return err +} + +// PersistFactLinksFromParsed 写入解析后的 links(parsed 为 nil 表示不修改)。 +func PersistFactLinksFromParsed(db *database.DB, projectID, factKey, sourceConversationID string, parsed *ParsedFactLinks, syncBody bool) error { + if parsed == nil || parsed.Incoming == nil { + return nil + } + return PersistFactIncomingLinks(db, projectID, factKey, parsed.Incoming, syncBody) +} + +// PersistFactOutgoingLinks 写入出边(图连线等低层 API;body 同步请用 PersistFactIncomingLinks)。 +func PersistFactOutgoingLinks(db *database.DB, projectID, sourceFactKey, sourceConversationID string, links []database.ProjectFactEdgeInput, syncBody bool) error { + if links == nil { + return nil + } + return ApplyFactOutgoingLinks(db, projectID, sourceFactKey, sourceConversationID, links) +} + +// LinkCountMap 项目内各 fact 的入/出边计数。 +type LinkCountMap map[string]LinkCounts + +// LinkCounts 单 fact 的入/出边数。 +type LinkCounts struct { + Outgoing int `json:"outgoing"` + Incoming int `json:"incoming"` +} + +// LoadProjectFactLinkCounts 批量加载边计数。 +func LoadProjectFactLinkCounts(db *database.DB, projectID string) (LinkCountMap, error) { + edges, err := db.ListProjectFactEdgesByProject(projectID) + if err != nil { + return nil, err + } + m := LinkCountMap{} + for _, e := range edges { + c := m[e.SourceFactKey] + c.Outgoing++ + m[e.SourceFactKey] = c + c = m[e.TargetFactKey] + c.Incoming++ + m[e.TargetFactKey] = c + } + return m, nil +} diff --git a/internal/project/fact_edges_test.go b/internal/project/fact_edges_test.go new file mode 100644 index 00000000..eae1a435 --- /dev/null +++ b/internal/project/fact_edges_test.go @@ -0,0 +1,290 @@ +package project + +import ( + "path/filepath" + "testing" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +func TestParseFactLinksText(t *testing.T) { + t.Parallel() + inputs, err := ParseFactLinksText("discovered_on: target/api\nleads_to: finding/swagger") + if err != nil { + t.Fatal(err) + } + if len(inputs) != 2 { + t.Fatalf("want 2 links, got %d", len(inputs)) + } + if inputs[0].Type != "discovered_on" || inputs[0].From != "target/api" { + t.Fatalf("unexpected first link: %+v", inputs[0]) + } +} + +func TestParseFactIncomingLinksText(t *testing.T) { + t.Parallel() + inputs, err := ParseFactIncomingLinksText("leads_to: finding/swagger\ndepends_on: target/api") + if err != nil { + t.Fatal(err) + } + if len(inputs) != 2 { + t.Fatalf("want 2 links, got %d", len(inputs)) + } + if inputs[0].Type != "leads_to" || inputs[0].From != "finding/swagger" { + t.Fatalf("unexpected first link: %+v", inputs[0]) + } +} + +func TestFormatFactIncomingLinksText(t *testing.T) { + t.Parallel() + text := FormatFactIncomingLinksText([]*database.ProjectFactEdge{ + {EdgeType: "leads_to", SourceFactKey: "finding/a"}, + {EdgeType: "depends_on", SourceFactKey: "target/b"}, + }) + want := "leads_to: finding/a\ndepends_on: target/b" + if text != want { + t.Fatalf("got %q want %q", text, want) + } +} + +func TestParseFactLinkInputsEmptyClears(t *testing.T) { + t.Parallel() + parsed, err := ParseFactLinkInputs([]interface{}{}) + if err != nil { + t.Fatal(err) + } + if parsed == nil || parsed.Incoming == nil || len(parsed.Incoming) != 0 { + t.Fatalf("empty array should clear incoming links, got %v", parsed) + } +} + +func TestParseFactLinkInputsFrom(t *testing.T) { + t.Parallel() + raw := []interface{}{ + map[string]interface{}{ + "from": "target/primary_domain", + "type": "discovered_on", + }, + } + parsed, err := ParseFactLinkInputs(raw) + if err != nil { + t.Fatal(err) + } + if len(parsed.Incoming) != 1 || parsed.Incoming[0].From != "target/primary_domain" { + t.Fatalf("unexpected incoming: %+v", parsed.Incoming) + } +} + +func TestParseFactLinkInputsRequiresFrom(t *testing.T) { + t.Parallel() + raw := []interface{}{ + map[string]interface{}{ + "to": "target/primary_domain", + "type": "discovered_on", + }, + } + _, err := ParseFactLinkInputs(raw) + if err == nil { + t.Fatal("expected error when from is missing") + } +} + +func TestGraphNodeType(t *testing.T) { + t.Parallel() + if GraphNodeType("chain", "chain/x") != "chain" { + t.Fatal("chain category") + } + if GraphNodeType("finding", "finding/x") != "finding" { + t.Fatal("finding category") + } + if GraphNodeType("exploit", "exploit/x") != "exploit" { + t.Fatal("exploit category") + } + if GraphNodeType("finding", "evidence/x") != "exploit" { + t.Fatal("evidence prefix") + } + if GraphNodeType("note", "target/x") != "target" { + t.Fatal("target prefix") + } +} + +func TestBuildProjectFactGraphPreservesStoredEdgeDirection(t *testing.T) { + dir := t.TempDir() + db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + p, err := db.CreateProject(&database.Project{Name: "path-edges"}) + if err != nil { + t.Fatal(err) + } + for _, spec := range []struct{ key, cat string }{ + {"target/primary_domain", "target"}, + {"chain/full_attack_path", "chain"}, + {"finding/mysql_public", "finding"}, + {"exploit/mysql_creds_extract", "exploit"}, + } { + if _, err := db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed", + }); err != nil { + t.Fatal(err) + } + } + if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{ + {From: "target/primary_domain", Type: "discovered_on"}, + }); err != nil { + t.Fatal(err) + } + if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{ + {From: "target/primary_domain", Type: "discovered_on"}, + {From: "exploit/mysql_creds_extract", Type: "exploits"}, + }); err != nil { + t.Fatal(err) + } + if err := db.ReplaceIncomingProjectFactEdges(p.ID, "chain/full_attack_path", []database.ProjectFactEdgeFromInput{ + {From: "target/primary_domain", Type: "discovered_on"}, + }); err != nil { + t.Fatal(err) + } + if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/mysql_creds_extract", []database.ProjectFactEdgeFromInput{ + {From: "chain/full_attack_path", Type: "leads_to"}, + }); err != nil { + t.Fatal(err) + } + + graph, err := BuildProjectFactGraph(db, p.ID, "path", true) + if err != nil { + t.Fatal(err) + } + want := map[string]struct{}{ + "target/primary_domain|discovered_on|finding/mysql_public": {}, + "exploit/mysql_creds_extract|exploits|finding/mysql_public": {}, + "target/primary_domain|discovered_on|chain/full_attack_path": {}, + "chain/full_attack_path|leads_to|exploit/mysql_creds_extract": {}, + } + for _, e := range graph.Edges { + key := e.Source + "|" + e.Type + "|" + e.Target + delete(want, key) + } + if len(want) > 0 { + t.Fatalf("missing expected stored-direction edges: %v", want) + } + countInOut := func(factKey string) (out, in int) { + for _, e := range graph.Edges { + if e.Source == factKey { + out++ + } + if e.Target == factKey { + in++ + } + } + return out, in + } + if out, in := countInOut("chain/full_attack_path"); out != 1 || in != 1 { + t.Fatalf("chain/full_attack_path want out=1 in=1 got out=%d in=%d", out, in) + } + if out, in := countInOut("exploit/mysql_creds_extract"); out != 1 || in != 1 { + t.Fatalf("exploit/mysql_creds_extract want out=1 in=1 got out=%d in=%d", out, in) + } +} + +func TestPersistFactLinksFromUsesFromAsIncoming(t *testing.T) { + dir := t.TempDir() + db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + p, err := db.CreateProject(&database.Project{Name: "from-links"}) + if err != nil { + t.Fatal(err) + } + for _, spec := range []struct{ key, cat string }{ + {"target/primary_domain", "target"}, + {"finding/sqli", "finding"}, + } { + if _, err := db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed", + }); err != nil { + t.Fatal(err) + } + } + parsed := &ParsedFactLinks{ + Incoming: []database.ProjectFactEdgeFromInput{ + {From: "target/primary_domain", Type: "discovered_on"}, + }, + } + if err := PersistFactLinksFromParsed(db, p.ID, "finding/sqli", "", parsed, false); err != nil { + t.Fatal(err) + } + graph, err := BuildProjectFactGraph(db, p.ID, "path", true) + if err != nil { + t.Fatal(err) + } + want := "target/primary_domain|discovered_on|finding/sqli" + for _, e := range graph.Edges { + key := e.Source + "|" + e.Type + "|" + e.Target + if key == want { + return + } + } + t.Fatalf("expected edge %s, got %+v", want, graph.Edges) +} + +func TestFormatOutgoingLinksHint(t *testing.T) { + t.Parallel() + hint := FormatOutgoingLinksHint([]*database.ProjectFactEdge{ + {EdgeType: "discovered_on", TargetFactKey: "target/a"}, + }) + if hint == "" || hint[0] != ' ' { + t.Fatalf("unexpected hint: %q", hint) + } +} + +func TestReplaceIncomingAllowsNotYetCreatedSource(t *testing.T) { + dir := t.TempDir() + db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + p, err := db.CreateProject(&database.Project{Name: "parallel-links"}) + if err != nil { + t.Fatal(err) + } + if _, err := db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: p.ID, FactKey: "exploit/sqli", Category: "exploit", Summary: "exploit", Confidence: "confirmed", + }); err != nil { + t.Fatal(err) + } + if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/sqli", []database.ProjectFactEdgeFromInput{ + {From: "finding/sqli_endpoint", Type: "exploits"}, + }); err != nil { + t.Fatalf("incoming edge should not require source fact to exist yet: %v", err) + } + if _, err := db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: p.ID, FactKey: "finding/sqli_endpoint", Category: "finding", Summary: "finding", Confidence: "confirmed", + }); err != nil { + t.Fatal(err) + } + in, err := db.ListIncomingProjectFactEdges(p.ID, "exploit/sqli") + if err != nil || len(in) != 1 || in[0].SourceFactKey != "finding/sqli_endpoint" { + t.Fatalf("expected persisted edge from finding, got %+v err=%v", in, err) + } +} + +func TestValidateProjectFactEdgeType(t *testing.T) { + t.Parallel() + if err := database.ValidateProjectFactEdgeType("leads_to"); err != nil { + t.Fatal(err) + } + if err := database.ValidateProjectFactEdgeType("invalid"); err == nil { + t.Fatal("expected error") + } +} diff --git a/internal/project/fact_index_links.go b/internal/project/fact_index_links.go new file mode 100644 index 00000000..32732894 --- /dev/null +++ b/internal/project/fact_index_links.go @@ -0,0 +1,231 @@ +package project + +import ( + "fmt" + "sort" + "strings" + + "cyberstrike-ai/internal/database" +) + +var factIndexEdgeTypeOrder = []string{ + "discovered_on", "leads_to", "enables", "depends_on", "exploits", "contains", "part_of", "supports", +} + +func filterIndexEdges(edges []*database.ProjectFactEdge) []*database.ProjectFactEdge { + if len(edges) == 0 { + return nil + } + out := make([]*database.ProjectFactEdge, 0, len(edges)) + for _, e := range edges { + if e == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(e.Confidence), "deprecated") { + continue + } + edgeType := strings.ToLower(strings.TrimSpace(e.EdgeType)) + if _, ok := database.ValidProjectFactEdgeTypes[edgeType]; !ok { + continue + } + out = append(out, e) + } + return out +} + +func edgeConfidenceSuffix(confidence string) string { + c := strings.ToLower(strings.TrimSpace(confidence)) + if c == "" || c == "confirmed" { + return "" + } + return " (" + c + ")" +} + +func formatRelationHintPart(e *database.ProjectFactEdge) string { + return fmt.Sprintf("%s←%s%s", e.EdgeType, e.SourceFactKey, edgeConfidenceSuffix(e.Confidence)) +} + +func formatOutgoingHintPart(e *database.ProjectFactEdge) string { + return fmt.Sprintf("%s→%s%s", e.EdgeType, e.TargetFactKey, edgeConfidenceSuffix(e.Confidence)) +} + +func formatIncomingHintPart(e *database.ProjectFactEdge) string { + return formatRelationHintPart(e) +} + +func joinEdgeHintParts(edges []*database.ProjectFactEdge, formatter func(*database.ProjectFactEdge) string) string { + parts := make([]string, 0, len(edges)) + for _, e := range edges { + parts = append(parts, formatter(e)) + } + return strings.Join(parts, ", ") +} + +// FormatOutgoingLinksHint 黑板索引用出边摘要(全部有效边类型,不截断)。 +func FormatOutgoingLinksHint(edges []*database.ProjectFactEdge) string { + edges = filterIndexEdges(edges) + if len(edges) == 0 { + return "" + } + return " {出边: " + joinEdgeHintParts(edges, formatOutgoingHintPart) + "}" +} + +// FormatIncomingLinksHint 黑板索引用入边摘要(全部有效边类型,不截断)。 +func FormatIncomingLinksHint(edges []*database.ProjectFactEdge) string { + edges = filterIndexEdges(edges) + if len(edges) == 0 { + return "" + } + return " {入边: " + joinEdgeHintParts(edges, formatIncomingHintPart) + "}" +} + +// FormatFactIndexLinksHint 黑板索引行内关系边(from → 当前 fact,与 upsert links 一致)。 +func FormatFactIndexLinksHint(_ string, incoming []*database.ProjectFactEdge) string { + in := filterIndexEdges(incoming) + if len(in) == 0 { + return "" + } + return " {关系边: " + joinEdgeHintParts(in, formatRelationHintPart) + "}" +} + +func indexEdgeGroupMaps(edges []*database.ProjectFactEdge) (outgoing, incoming map[string][]*database.ProjectFactEdge) { + outgoing = map[string][]*database.ProjectFactEdge{} + incoming = map[string][]*database.ProjectFactEdge{} + for _, e := range filterIndexEdges(edges) { + outgoing[e.SourceFactKey] = append(outgoing[e.SourceFactKey], e) + incoming[e.TargetFactKey] = append(incoming[e.TargetFactKey], e) + } + return outgoing, incoming +} + +func relationOverviewLine(e *database.ProjectFactEdge) string { + return fmt.Sprintf("- %s → %s%s · %s", e.SourceFactKey, e.TargetFactKey, edgeConfidenceSuffix(e.Confidence), e.EdgeType) +} + +func indexEdgeSortKey(e *database.ProjectFactEdge) (int, int, string) { + confRank := 0 + if strings.EqualFold(strings.TrimSpace(e.Confidence), "tentative") { + confRank = 1 + } + typeRank := len(factIndexEdgeTypeOrder) + 1 + for i, t := range factIndexEdgeTypeOrder { + if strings.EqualFold(e.EdgeType, t) { + typeRank = i + break + } + } + return confRank, typeRank, e.SourceFactKey + ">" + e.TargetFactKey + ">" + e.EdgeType +} + +func sortIndexOverviewEdges(edges []*database.ProjectFactEdge) { + sort.SliceStable(edges, func(i, j int) bool { + ci, ti, ki := indexEdgeSortKey(edges[i]) + cj, tj, kj := indexEdgeSortKey(edges[j]) + if ci != cj { + return ci < cj + } + if ti != tj { + return ti < tj + } + return ki < kj + }) +} + +// BuildFactPathOverviewSection 生成事实关系速览(全部有效边类型,不含 body)。 +func BuildFactPathOverviewSection(edges []*database.ProjectFactEdge, indexedKeys map[string]struct{}, maxRunes int) string { + if maxRunes <= 0 { + return "" + } + candidates := filterIndexEdges(edges) + if len(candidates) == 0 { + return "" + } + filtered := make([]*database.ProjectFactEdge, 0, len(candidates)) + for _, e := range candidates { + if len(indexedKeys) > 0 { + if _, ok := indexedKeys[e.SourceFactKey]; !ok { + continue + } + if _, ok := indexedKeys[e.TargetFactKey]; !ok { + continue + } + } + filtered = append(filtered, e) + } + if len(filtered) == 0 { + return "" + } + sortIndexOverviewEdges(filtered) + + header := "### 攻击路径(事实关系)\n" + header += "source → target · type(与攻击路径图/库中方向一致;写入时在目标 fact 的 links 用 from 声明来源)\n" + var b strings.Builder + b.WriteString(header) + used := len([]rune(header)) + omitted := 0 + + for _, e := range filtered { + line := relationOverviewLine(e) + "\n" + lineRunes := len([]rune(line)) + if used+lineRunes > maxRunes { + omitted++ + continue + } + b.WriteString(line) + used += lineRunes + } + if omitted > 0 { + extra := fmt.Sprintf("(另有 %d 条关系边未列入,请 get_project_fact 查看完整关系。)\n", omitted) + if used+len([]rune(extra)) <= maxRunes { + b.WriteString(extra) + } + } + if used <= len([]rune(header)) { + return "" + } + return b.String() +} + +func factIndexSortPriority(f *database.ProjectFact) int { + if f == nil { + return 0 + } + score := 0 + if f.Pinned { + score += 1000 + } + c := strings.ToLower(strings.TrimSpace(f.Category)) + switch c { + case FactCategoryTarget: + score += 400 + case FactCategoryFinding, FactCategoryChain: + score += 300 + case FactCategoryExploit, FactCategoryPOC: + score += 250 + case "auth", "infra", "business": + score += 200 + case "note": + score += 50 + default: + key := strings.ToLower(strings.TrimSpace(f.FactKey)) + if strings.HasPrefix(key, "target/") { + score += 400 + } else if strings.HasPrefix(key, "finding/") || strings.HasPrefix(key, "chain/") { + score += 300 + } + } + if strings.EqualFold(strings.TrimSpace(f.Confidence), "confirmed") { + score += 80 + } + return score +} + +func sortFactsForIndex(facts []*database.ProjectFact) { + sort.SliceStable(facts, func(i, j int) bool { + pi, pj := factIndexSortPriority(facts[i]), factIndexSortPriority(facts[j]) + if pi != pj { + return pi > pj + } + return facts[i].UpdatedAt.After(facts[j].UpdatedAt) + }) +} diff --git a/internal/project/fact_index_links_test.go b/internal/project/fact_index_links_test.go new file mode 100644 index 00000000..a5794b9d --- /dev/null +++ b/internal/project/fact_index_links_test.go @@ -0,0 +1,161 @@ +package project + +import ( + "fmt" + "path/filepath" + "strings" + "testing" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +func TestFormatIncomingLinksHint(t *testing.T) { + t.Parallel() + hint := FormatIncomingLinksHint([]*database.ProjectFactEdge{ + {EdgeType: "discovered_on", SourceFactKey: "finding/x", Confidence: "tentative"}, + }) + if !strings.Contains(hint, "入边:") { + t.Fatalf("expected 入边 label: %q", hint) + } + if !strings.Contains(hint, "discovered_on←finding/x") { + t.Fatalf("unexpected hint: %q", hint) + } + if !strings.Contains(hint, "tentative") { + t.Fatalf("expected tentative in hint: %q", hint) + } +} + +func TestFormatIncomingLinksHint_allEdges(t *testing.T) { + t.Parallel() + edges := make([]*database.ProjectFactEdge, 0, 5) + for i := 1; i <= 5; i++ { + edges = append(edges, &database.ProjectFactEdge{ + EdgeType: "discovered_on", + SourceFactKey: fmt.Sprintf("finding/f%d", i), + Confidence: "tentative", + }) + } + hint := FormatIncomingLinksHint(edges) + if strings.Contains(hint, "+") { + t.Fatalf("should not truncate with +N: %q", hint) + } + for i := 1; i <= 5; i++ { + if !strings.Contains(hint, fmt.Sprintf("finding/f%d", i)) { + t.Fatalf("missing edge f%d in hint: %q", i, hint) + } + } +} + +func TestFormatFactIndexLinksHint_incomingOnly(t *testing.T) { + t.Parallel() + in := []*database.ProjectFactEdge{ + {EdgeType: "discovered_on", SourceFactKey: "target/dev", Confidence: "tentative"}, + {EdgeType: "exploits", SourceFactKey: "exploit/rce", Confidence: "confirmed"}, + } + hint := FormatFactIndexLinksHint("finding/sqli", in) + if !strings.Contains(hint, "关系边:") { + t.Fatalf("missing 关系边 label: %q", hint) + } + if !strings.Contains(hint, "discovered_on←target/dev") { + t.Fatalf("missing discovered_on: %q", hint) + } + if !strings.Contains(hint, "exploits←exploit/rce") { + t.Fatalf("missing exploits: %q", hint) + } + if strings.Contains(hint, "出边") || strings.Contains(hint, "入边") { + t.Fatalf("should not use legacy 出边/入边 labels: %q", hint) + } +} + +func TestFormatFactIndexLinksHint_includesAuxiliaryEdgeTypes(t *testing.T) { + t.Parallel() + in := []*database.ProjectFactEdge{{EdgeType: "supports", SourceFactKey: "note/log"}} + hint := FormatFactIndexLinksHint("finding/x", in) + if !strings.Contains(hint, "supports←note/log") { + t.Fatalf("supports edge should be included: %q", hint) + } +} + +func TestBuildFactPathOverviewSection(t *testing.T) { + t.Parallel() + edges := []*database.ProjectFactEdge{ + {EdgeType: "discovered_on", SourceFactKey: "target/dev", TargetFactKey: "finding/sqli", Confidence: "tentative"}, + {EdgeType: "exploits", SourceFactKey: "exploit/rce", TargetFactKey: "finding/sqli", Confidence: "confirmed"}, + {EdgeType: "supports", SourceFactKey: "note/log", TargetFactKey: "finding/sqli"}, + } + keys := map[string]struct{}{ + "target/dev": {}, "finding/sqli": {}, "exploit/rce": {}, "note/log": {}, + } + section := BuildFactPathOverviewSection(edges, keys, 800) + if !strings.Contains(section, "### 攻击路径(事实关系)") { + t.Fatalf("missing header: %q", section) + } + if !strings.Contains(section, "target/dev → finding/sqli") { + t.Fatalf("missing discovered_on line: %q", section) + } + if !strings.Contains(section, "exploit/rce → finding/sqli") { + t.Fatalf("missing exploits line: %q", section) + } + if !strings.Contains(section, "note/log → finding/sqli") { + t.Fatalf("supports edge should be included: %q", section) + } +} + +func TestBuildFactIndexBlock_withLinksAndPathOverview(t *testing.T) { + t.Parallel() + dbPath := filepath.Join(t.TempDir(), "facts.db") + db, err := database.NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + proj, err := db.CreateProject(&database.Project{Name: "path-proj"}) + if err != nil { + t.Fatal(err) + } + _, err = db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: proj.ID, + FactKey: "target/dev", + Category: "target", + Summary: "dev 子域", + Confidence: "confirmed", + }) + if err != nil { + t.Fatal(err) + } + _, err = db.UpsertProjectFact(&database.ProjectFact{ + ProjectID: proj.ID, + FactKey: "finding/sqli", + Category: "finding", + Summary: "时间盲注", + Confidence: "tentative", + }) + if err != nil { + t.Fatal(err) + } + _, err = db.AddProjectFactEdge(proj.ID, database.ProjectFactEdgeInput{ + To: "finding/sqli", + Type: "discovered_on", + }, "target/dev", "") + if err != nil { + t.Fatal(err) + } + + block, err := BuildFactIndexBlock(db, proj.ID, config.ProjectConfig{Enabled: true, FactIndexMaxRunes: 6500, FactIndexPathMaxRunes: 1000}) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(block, "关系边: discovered_on←target/dev") { + t.Fatalf("finding line should include relation hint: %q", block) + } + if !strings.Contains(block, "### 攻击路径(事实关系)") { + t.Fatalf("missing relation overview: %q", block) + } + if !strings.Contains(block, "target/dev → finding/sqli") { + t.Fatalf("missing overview edge: %q", block) + } +} diff --git a/internal/project/fact_recording_prompt.go b/internal/project/fact_recording_prompt.go index 1e02e650..7d986a46 100644 --- a/internal/project/fact_recording_prompt.go +++ b/internal/project/fact_recording_prompt.go @@ -1,100 +1,23 @@ package project -import ( - "strings" +import "cyberstrike-ai/internal/projectprompt" - "cyberstrike-ai/internal/mcp/builtin" -) - -// 边渗透边记录:统一节奏文案(agents/*.md 须与 FactRecordingIncrementalRhythmMarkdown 保持一致)。 -const ( - factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。" - factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。" - factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。" -) - -// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。 +// FactRecordingIncrementalRhythmMarkdown 见 projectprompt。 func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string { - var b strings.Builder - b.WriteString("- **边渗透边记录(强制节奏)**:") - b.WriteString(factRhythmCore) - if coordinator { - b.WriteString(factRhythmCoordinatorSuffix) - } - if subAgent { - b.WriteString(factRhythmSubAgentSuffix) - } - return b.String() + return projectprompt.FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent) } -func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string { - var b strings.Builder - b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ") - b.WriteString(builtin.ToolUpsertProjectFact) - b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ") - b.WriteString(builtin.ToolRecordVulnerability) - b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。") - if coordinator { - b.WriteString(factRhythmCoordinatorSuffix) - } - if subAgent { - b.WriteString(factRhythmSubAgentSuffix) - } - return b.String() -} - -// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。 -// coordinatorDelegate 为 true 时追加「协调者代子代理落库」说明(Deep / plan_execute / supervisor)。 +// FactRecordingBlackboardSection 见 projectprompt。 func FactRecordingBlackboardSection(coordinatorDelegate bool) string { - var b strings.Builder - b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n") - b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ") - b.WriteString(builtin.ToolGetProjectFact) - b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n") - b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false)) - b.WriteString("\n\n") - b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ") - b.WriteString(builtin.ToolUpsertProjectFact) - b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n") - b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n") - b.WriteString("- **可交付漏洞**:使用 ") - b.WriteString(builtin.ToolRecordVulnerability) - b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ") - b.WriteString(builtin.ToolListVulnerabilities) - b.WriteString(" 查重,详情用 ") - b.WriteString(builtin.ToolGetVulnerability) - b.WriteString("(id)(默认仅当前项目/会话)。\n") - b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ") - b.WriteString(builtin.ToolDeprecateProjectFact) - b.WriteString(" 或漏洞状态 false_positive。\n") - b.WriteString("- 事实多时用 ") - b.WriteString(builtin.ToolListProjectFacts) - b.WriteString(" / ") - b.WriteString(builtin.ToolSearchProjectFacts) - b.WriteString(" 检索。\n\n") - b.WriteString(FactRecordingGuidanceBlock()) - b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。") - return b.String() + return projectprompt.FactRecordingBlackboardSection(coordinatorDelegate) } -// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。 +// FactRecordingSubAgentSection 见 projectprompt。 func FactRecordingSubAgentSection() string { - return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n" + return projectprompt.FactRecordingSubAgentSection() } -// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。 +// FactRecordingBlackboardSectionMarkdown 见 projectprompt。 func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string { - var b strings.Builder - b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n") - b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n") - b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false)) - b.WriteString("\n\n") - b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n") - b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n") - b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n") - b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n") - b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n") - b.WriteString(FactRecordingGuidanceBlock()) - b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。") - return b.String() + return projectprompt.FactRecordingBlackboardSectionMarkdown(coordinatorDelegate) } diff --git a/internal/project/fact_template.go b/internal/project/fact_template.go index b3856b17..c94c819c 100644 --- a/internal/project/fact_template.go +++ b/internal/project/fact_template.go @@ -3,6 +3,8 @@ package project import ( "fmt" "strings" + + "cyberstrike-ai/internal/projectprompt" ) // 事实 category 常量(写入 upsert_project_fact 的 category 字段)。 @@ -90,7 +92,8 @@ const attackChainFactBodyTemplate = `## 结论(可验证,一句话) ## 关联 - related_vulnerability_id: <可选,对应 record_vulnerability 的 id> -- 依赖事实: +- links(upsert 参数): [{ "from": "", "type": "discovered_on|..." }](from → 当前 fact) +- 依赖事实(body 可读镜像): ## 备注与不确定性 <待验证假设、环境差异、绕过尝试记录>` @@ -109,15 +112,7 @@ const envFactBodyTemplate = `## 摘要 // FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。 func FactRecordingGuidanceBlock() string { - return `### 事实写入规范(审计复现 / 知识沉淀) - -- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。 -- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。 -- **category / fact_key 建议**: - - 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可) - - 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID) -- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。 -- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。` + return projectprompt.FactRecordingGuidanceBlock() } // SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。 diff --git a/internal/projectprompt/blackboard.go b/internal/projectprompt/blackboard.go new file mode 100644 index 00000000..d3e3ae76 --- /dev/null +++ b/internal/projectprompt/blackboard.go @@ -0,0 +1,132 @@ +// Package projectprompt 提供项目黑板相关的系统提示文本(纯字符串,无 database 依赖)。 +// 供 agent / multiagent 等包引用,避免 agent → project 导入环导致 gopls 元数据失败。 +package projectprompt + +import ( + "strings" + + "cyberstrike-ai/internal/mcp/builtin" +) + +const ( + factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。" + factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。" + factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。" +) + +// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。 +func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string { + var b strings.Builder + b.WriteString("- **边渗透边记录(强制节奏)**:") + b.WriteString(factRhythmCore) + if coordinator { + b.WriteString(factRhythmCoordinatorSuffix) + } + if subAgent { + b.WriteString(factRhythmSubAgentSuffix) + } + return b.String() +} + +func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string { + var b strings.Builder + b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ") + b.WriteString(builtin.ToolUpsertProjectFact) + b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ") + b.WriteString(builtin.ToolRecordVulnerability) + b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。") + if coordinator { + b.WriteString(factRhythmCoordinatorSuffix) + } + if subAgent { + b.WriteString(factRhythmSubAgentSuffix) + } + return b.String() +} + +func factEdgeRecordingGuidance() string { + return `### 事实关系边(links) + +- 写入 **finding / chain / exploit / poc** 时,**必须**在 ` + "`upsert_project_fact`" + ` 中提供 ` + "`links`" + `(**推荐 ` + "`from`" + `**:来源 fact 指向当前 fact,即 ` + "`from`" + ` → 当前 ` + "`fact_key`" + `)。 +- **最少要求**:finding 类至少 1 条 from=target/* + type=discovered_on(即 target → finding);在 finding 上记录 exploit 用 from=exploit/* + type=exploits(即 exploit → finding)。 +- **常用 type**:` + "`discovered_on`" + `(发现在哪)、` + "`depends_on`" + `(复现前置)、` + "`leads_to`" + `(认知推进)、` + "`enables`" + `(扩大攻击面)、` + "`exploits`" + `(利用关系)、` + "`contains`" + `(资产包含)、` + "`part_of`" + `(属于链/组)、` + "`supports`" + `(证据支撑)。 +- 更新时:**省略 links 保留已有边**;传入 links 则**替换**全部关系边(from → 当前 fact)。 +- body 中「依赖事实」段落可与 links 并存(人读);结构化关系以 links 为准。` +} + +func factRecordingGuidanceBlock() string { + return `### 事实写入规范(审计复现 / 知识沉淀) + +- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。 +- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。 +- **category / fact_key 建议**: + - 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可) + - 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID) +- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。 +- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。` +} + +// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。 +func FactRecordingBlackboardSection(coordinatorDelegate bool) string { + var b strings.Builder + b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n") + b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ") + b.WriteString(builtin.ToolGetProjectFact) + b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n") + b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false)) + b.WriteString("\n\n") + b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ") + b.WriteString(builtin.ToolUpsertProjectFact) + b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n") + b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n") + b.WriteString("- **可交付漏洞**:使用 ") + b.WriteString(builtin.ToolRecordVulnerability) + b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ") + b.WriteString(builtin.ToolListVulnerabilities) + b.WriteString(" 查重,详情用 ") + b.WriteString(builtin.ToolGetVulnerability) + b.WriteString("(id)(默认仅当前项目/会话)。\n") + b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ") + b.WriteString(builtin.ToolDeprecateProjectFact) + b.WriteString(" 或漏洞状态 false_positive。\n") + b.WriteString("- 事实多时用 ") + b.WriteString(builtin.ToolListProjectFacts) + b.WriteString(" / ") + b.WriteString(builtin.ToolSearchProjectFacts) + b.WriteString(" 检索。\n\n") + b.WriteString(factEdgeRecordingGuidance()) + b.WriteString("\n\n") + b.WriteString(factRecordingGuidanceBlock()) + b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。") + return b.String() +} + +// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。 +func FactRecordingSubAgentSection() string { + return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n" +} + +// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。 +func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string { + var b strings.Builder + b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n") + b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n") + b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false)) + b.WriteString("\n\n") + b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n") + b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n") + b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n") + b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n") + b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n") + b.WriteString(factEdgeRecordingGuidance()) + b.WriteString("\n\n") + b.WriteString(factRecordingGuidanceBlock()) + b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。") + return b.String() +} + +// FactEdgeRecordingGuidance 写入边时的 Agent 规范(供 project 包复用)。 +func FactEdgeRecordingGuidance() string { return factEdgeRecordingGuidance() } + +// FactRecordingGuidanceBlock 事实写入规范块(供 project 包复用)。 +func FactRecordingGuidanceBlock() string { return factRecordingGuidanceBlock() }