diff --git a/internal/database/project.go b/internal/database/project.go index f300d82e..54d86e1e 100644 --- a/internal/database/project.go +++ b/internal/database/project.go @@ -282,6 +282,14 @@ func (db *DB) GetProjectFact(id string) (*ProjectFact, error) { return scanProjectFactRow(row) } +// mergeFactBodyOnUpdate 更新时若 incoming body 为空则保留已有内容,避免仅改 summary 时丢失攻击链。 +func mergeFactBodyOnUpdate(incoming, existing string) string { + if strings.TrimSpace(incoming) == "" { + return existing + } + return incoming +} + // UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。 func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) { if err := ValidateFactKey(f.FactKey); err != nil { @@ -300,6 +308,7 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) { f.ID = existing.ID f.CreatedAt = existing.CreatedAt f.UpdatedAt = now + f.Body = mergeFactBodyOnUpdate(f.Body, existing.Body) _, err = db.Exec( `UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?, source_conversation_id = ?, source_message_id = ?, pinned = ?, @@ -353,6 +362,31 @@ func (db *DB) DeprecateProjectFact(projectID, factKey string) error { return nil } +// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。 +func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error { + confidence = strings.TrimSpace(strings.ToLower(confidence)) + if confidence == "" { + confidence = "tentative" + } + if confidence != "confirmed" && confidence != "tentative" { + return fmt.Errorf("confidence 须为 confirmed 或 tentative") + } + + existing, err := db.GetProjectFactByKey(projectID, factKey) + if err != nil { + return fmt.Errorf("事实不存在") + } + if strings.ToLower(strings.TrimSpace(existing.Confidence)) != "deprecated" { + return fmt.Errorf("事实未处于废弃状态") + } + + _, err = db.Exec( + `UPDATE project_facts SET confidence = ?, updated_at = ? WHERE project_id = ? AND fact_key = ?`, + confidence, time.Now(), projectID, factKey, + ) + return err +} + // DeleteProjectFact 删除事实。 func (db *DB) DeleteProjectFact(id string) error { _, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id) diff --git a/internal/database/project_fact_upsert_test.go b/internal/database/project_fact_upsert_test.go new file mode 100644 index 00000000..c843d508 --- /dev/null +++ b/internal/database/project_fact_upsert_test.go @@ -0,0 +1,148 @@ +package database + +import ( + "path/filepath" + "testing" + + "go.uber.org/zap" +) + +func TestUpsertProjectFact_preservesBodyOnEmptyUpdate(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "facts.db") + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + proj, err := db.CreateProject(&Project{Name: "test-facts"}) + if err != nil { + t.Fatal(err) + } + + const body = "## 攻击链\n1. step\n```http\nGET / HTTP/1.1\n```\n" + _, err = db.UpsertProjectFact(&ProjectFact{ + ProjectID: proj.ID, + FactKey: "finding/sqli-login", + Category: "finding", + Summary: "SQLi on /login", + Body: body, + }) + if err != nil { + t.Fatal(err) + } + + updated, err := db.UpsertProjectFact(&ProjectFact{ + ProjectID: proj.ID, + FactKey: "finding/sqli-login", + Summary: "SQLi on /login (confirmed)", + Body: "", + }) + if err != nil { + t.Fatal(err) + } + if updated.Summary != "SQLi on /login (confirmed)" { + t.Fatalf("summary=%q", updated.Summary) + } + if updated.Body != body { + t.Fatalf("returned body=%q want preserved attack chain", updated.Body) + } + + fromDB, err := db.GetProjectFactByKey(proj.ID, "finding/sqli-login") + if err != nil { + t.Fatal(err) + } + if fromDB.Body != body { + t.Fatalf("stored body=%q want preserved", fromDB.Body) + } +} + +func TestUpsertProjectFact_replacesBodyWhenProvided(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "facts.db") + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + proj, err := db.CreateProject(&Project{Name: "test-facts"}) + if err != nil { + t.Fatal(err) + } + + _, err = db.UpsertProjectFact(&ProjectFact{ + ProjectID: proj.ID, + FactKey: "target/primary", + Summary: "v1", + Body: "old body", + }) + if err != nil { + t.Fatal(err) + } + + const newBody = "new body with evidence" + updated, err := db.UpsertProjectFact(&ProjectFact{ + ProjectID: proj.ID, + FactKey: "target/primary", + Summary: "v2", + Body: newBody, + }) + if err != nil { + t.Fatal(err) + } + if updated.Body != newBody { + t.Fatalf("body=%q want %q", updated.Body, newBody) + } +} + +func TestRestoreProjectFact(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "facts.db") + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + proj, err := db.CreateProject(&Project{Name: "restore-test"}) + if err != nil { + t.Fatal(err) + } + key := "target/restore-me" + _, err = db.UpsertProjectFact(&ProjectFact{ + ProjectID: proj.ID, + FactKey: key, + Summary: "s", + Confidence: "confirmed", + }) + if err != nil { + t.Fatal(err) + } + if err := db.DeprecateProjectFact(proj.ID, key); err != nil { + t.Fatal(err) + } + if err := db.RestoreProjectFact(proj.ID, key, "confirmed"); err != nil { + t.Fatal(err) + } + f, err := db.GetProjectFactByKey(proj.ID, key) + if err != nil { + t.Fatal(err) + } + if f.Confidence != "confirmed" { + t.Fatalf("confidence=%q want confirmed", f.Confidence) + } + if err := db.RestoreProjectFact(proj.ID, key, ""); err == nil { + t.Fatal("expected error when not deprecated") + } +} + +func TestMergeFactBodyOnUpdate(t *testing.T) { + if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" { + t.Fatalf("empty incoming: got %q", got) + } + if got := mergeFactBodyOnUpdate(" ", "keep"); got != "keep" { + t.Fatalf("whitespace incoming: got %q", got) + } + if got := mergeFactBodyOnUpdate("new", "old"); got != "new" { + t.Fatalf("non-empty incoming: got %q", got) + } +} diff --git a/internal/mcp/builtin/constants.go b/internal/mcp/builtin/constants.go index bc178049..11dc1bba 100644 --- a/internal/mcp/builtin/constants.go +++ b/internal/mcp/builtin/constants.go @@ -14,6 +14,7 @@ const ( ToolListProjectFacts = "list_project_facts" ToolSearchProjectFacts = "search_project_facts" ToolDeprecateProjectFact = "deprecate_project_fact" + ToolRestoreProjectFact = "restore_project_fact" // 知识库工具 ToolListKnowledgeRiskTypes = "list_knowledge_risk_types" @@ -69,6 +70,7 @@ func IsBuiltinTool(toolName string) bool { ToolListProjectFacts, ToolSearchProjectFacts, ToolDeprecateProjectFact, + ToolRestoreProjectFact, ToolListKnowledgeRiskTypes, ToolSearchKnowledgeBase, ToolWebshellExec, @@ -119,6 +121,7 @@ func GetAllBuiltinTools() []string { ToolListProjectFacts, ToolSearchProjectFacts, ToolDeprecateProjectFact, + ToolRestoreProjectFact, ToolListKnowledgeRiskTypes, ToolSearchKnowledgeBase, ToolWebshellExec,