diff --git a/internal/agent/default_single_system_prompt.go b/internal/agent/default_single_system_prompt.go index fa2d25d0..0ee8468e 100644 --- a/internal/agent/default_single_system_prompt.go +++ b/internal/agent/default_single_system_prompt.go @@ -1,7 +1,6 @@ package agent import ( - "cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/project" ) @@ -108,17 +107,7 @@ func DefaultSingleAgentSystemPrompt() string { - 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。 - 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。 -## 项目黑板(事实)与漏洞记录(分离) - -当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ` + builtin.ToolGetProjectFact + `(fact_key) 获取 body,禁止凭摘要臆造细节。** - -- **环境/目标/认证等认知**(非正式漏洞条目):使用 ` + builtin.ToolUpsertProjectFact + `,fact_key 建议 ` + "`category/slug`" + `(如 target/primary_domain),同 key 覆盖更新。 -- **可交付漏洞**:使用 ` + builtin.ToolRecordVulnerability + `,含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ` + builtin.ToolListVulnerabilities + ` 查重,详情用 ` + builtin.ToolGetVulnerability + `(id)(默认仅当前项目/会话)。 -- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ` + builtin.ToolDeprecateProjectFact + ` 或漏洞状态 false_positive。 - -` + project.FactRecordingGuidanceBlock() + ` - -严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。 +` + project.FactRecordingBlackboardSection(false) + ` ## 技能库(Skills)与知识库 diff --git a/internal/database/database.go b/internal/database/database.go index 5f62cac1..eb8fe27b 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -247,6 +247,25 @@ func (db *DB) initTables() error { UNIQUE(project_id, fact_key) );` + createProjectFactVersionsTable := ` + CREATE TABLE IF NOT EXISTS project_fact_versions ( + id TEXT PRIMARY KEY, + fact_id TEXT NOT NULL, + project_id TEXT NOT NULL, + fact_key TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'note', + summary TEXT NOT NULL DEFAULT '', + body TEXT, + confidence TEXT NOT NULL DEFAULT 'tentative', + source_conversation_id TEXT, + source_message_id TEXT, + pinned INTEGER NOT NULL DEFAULT 0, + related_vulnerability_id TEXT, + archived_at DATETIME NOT NULL, + FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE + );` + // 创建漏洞表 createVulnerabilitiesTable := ` CREATE TABLE IF NOT EXISTS vulnerabilities ( @@ -483,6 +502,8 @@ func (db *DB) initTables() error { CREATE INDEX IF NOT EXISTS idx_projects_updated_at ON projects(updated_at); CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id); CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence); + CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id); + CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id); CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id); CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id); @@ -564,6 +585,10 @@ func (db *DB) initTables() error { return fmt.Errorf("创建project_facts表失败: %w", err) } + if _, err := db.Exec(createProjectFactVersionsTable); err != nil { + return fmt.Errorf("创建project_fact_versions表失败: %w", err) + } + if _, err := db.Exec(createVulnerabilitiesTable); err != nil { return fmt.Errorf("创建vulnerabilities表失败: %w", err) } @@ -634,6 +659,9 @@ func (db *DB) initTables() error { if err := db.migrateProjectsTable(); err != nil { db.logger.Warn("迁移projects相关表失败", zap.Error(err)) } + if err := db.migrateProjectFactVersionsTable(); err != nil { + db.logger.Warn("迁移project_fact_versions表失败", zap.Error(err)) + } if err := db.migrateWebshellConnectionsTable(); err != nil { db.logger.Warn("迁移webshell_connections表失败", zap.Error(err)) @@ -1030,6 +1058,34 @@ func (db *DB) migrateProjectsTable() error { return nil } +// migrateProjectFactVersionsTable 为已有库创建事实版本表。 +func (db *DB) migrateProjectFactVersionsTable() error { + ddl := ` + CREATE TABLE IF NOT EXISTS project_fact_versions ( + id TEXT PRIMARY KEY, + fact_id TEXT NOT NULL, + project_id TEXT NOT NULL, + fact_key TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'note', + summary TEXT NOT NULL DEFAULT '', + body TEXT, + confidence TEXT NOT NULL DEFAULT 'tentative', + source_conversation_id TEXT, + source_message_id TEXT, + pinned INTEGER NOT NULL DEFAULT 0, + related_vulnerability_id TEXT, + archived_at DATETIME NOT NULL, + FOREIGN KEY (fact_id) REFERENCES project_facts(id) ON DELETE CASCADE, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE + );` + if _, err := db.Exec(ddl); err != nil { + return err + } + _, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_fact_versions_fact_id ON project_fact_versions(fact_id)`) + _, _ = db.Exec(`CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id)`) + return nil +} + // migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段 func (db *DB) migrateVulnerabilitiesTable() error { columns := []struct { diff --git a/internal/database/project.go b/internal/database/project.go index 54d86e1e..9f21bc79 100644 --- a/internal/database/project.go +++ b/internal/database/project.go @@ -59,9 +59,11 @@ type ProjectFact struct { // ProjectFactListFilter 事实列表筛选。 type ProjectFactListFilter struct { - Category string - Confidence string - Search string + Category string + Confidence string + Search string + RelatedVulnerabilityID string + ExcludeDeprecated bool // 为 true 时排除 confidence=deprecated } // CreateProject 创建项目。 @@ -160,8 +162,11 @@ func (db *DB) UpdateProject(p *Project) error { return nil } -// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理)。 +// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理;漏洞 project_id 置空)。 func (db *DB) DeleteProject(id string) error { + if _, err := db.Exec(`UPDATE vulnerabilities SET project_id = NULL WHERE project_id = ?`, id); err != nil { + return fmt.Errorf("解除漏洞项目关联失败: %w", err) + } _, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id) if err != nil { return fmt.Errorf("删除项目失败: %w", err) @@ -243,6 +248,13 @@ func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, l query += " AND confidence = ?" args = append(args, c) } + if filter.ExcludeDeprecated { + query += " AND confidence != 'deprecated'" + } + if rid := strings.TrimSpace(filter.RelatedVulnerabilityID); rid != "" { + query += " AND related_vulnerability_id = ?" + args = append(args, rid) + } if s := strings.TrimSpace(filter.Search); s != "" { pat := "%" + s + "%" query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)" @@ -309,10 +321,26 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) { f.CreatedAt = existing.CreatedAt f.UpdatedAt = now f.Body = mergeFactBodyOnUpdate(f.Body, existing.Body) + if strings.TrimSpace(f.Category) == "" { + f.Category = existing.Category + } + if strings.TrimSpace(f.Confidence) == "" { + f.Confidence = existing.Confidence + } + if projectFactContentChanged(existing, f) { + versionID, verr := db.InsertProjectFactVersion(existing) + if verr != nil { + return nil, verr + } + f.SupersedesFactID = versionID + } else if f.SupersedesFactID == "" { + f.SupersedesFactID = existing.SupersedesFactID + } _, err = db.Exec( `UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?, - source_conversation_id = ?, source_message_id = ?, pinned = ?, - supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ? + source_conversation_id = COALESCE(?, source_conversation_id), + source_message_id = COALESCE(?, source_message_id), + pinned = ?, supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ? WHERE id = ?`, f.Category, f.Summary, f.Body, f.Confidence, nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned), diff --git a/internal/database/project_fact_upsert_test.go b/internal/database/project_fact_upsert_test.go index c843d508..e5ea08f6 100644 --- a/internal/database/project_fact_upsert_test.go +++ b/internal/database/project_fact_upsert_test.go @@ -135,6 +135,54 @@ func TestRestoreProjectFact(t *testing.T) { } } +func TestUpsertProjectFact_createsVersionOnContentChange(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "facts.db") + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + proj, err := db.CreateProject(&Project{Name: "version-test"}) + if err != nil { + t.Fatal(err) + } + + created, err := db.UpsertProjectFact(&ProjectFact{ + ProjectID: proj.ID, + FactKey: "finding/xss", + Category: "finding", + Summary: "v1", + Body: "body v1", + }) + if err != nil { + t.Fatal(err) + } + if created.SupersedesFactID != "" { + t.Fatalf("expected no supersedes on create, got %q", created.SupersedesFactID) + } + + updated, err := db.UpsertProjectFact(&ProjectFact{ + ProjectID: proj.ID, + FactKey: "finding/xss", + Summary: "v2", + Body: "body v2", + }) + if err != nil { + t.Fatal(err) + } + if updated.SupersedesFactID == "" { + t.Fatal("expected supersedes_fact_id after content change") + } + prev, err := db.GetProjectFactVersion(updated.SupersedesFactID) + if err != nil { + t.Fatal(err) + } + if prev.Summary != "v1" || prev.Body != "body v1" { + t.Fatalf("previous version mismatch: summary=%q body=%q", prev.Summary, prev.Body) + } +} + func TestMergeFactBodyOnUpdate(t *testing.T) { if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" { t.Fatalf("empty incoming: got %q", got) diff --git a/internal/database/project_fact_version.go b/internal/database/project_fact_version.go new file mode 100644 index 00000000..cf49eecb --- /dev/null +++ b/internal/database/project_fact_version.go @@ -0,0 +1,144 @@ +package database + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/google/uuid" +) + +// ProjectFactVersion 事实历史快照(同 fact_key 更新前归档)。 +type ProjectFactVersion struct { + ID string `json:"id"` + FactID string `json:"fact_id"` + ProjectID string `json:"project_id"` + FactKey string `json:"fact_key"` + Category string `json:"category"` + Summary string `json:"summary"` + Body string `json:"body"` + Confidence string `json:"confidence"` + SourceConversationID string `json:"source_conversation_id,omitempty"` + SourceMessageID string `json:"source_message_id,omitempty"` + Pinned bool `json:"pinned"` + RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"` + ArchivedAt time.Time `json:"archived_at"` +} + +// InsertProjectFactVersion 将当前事实行快照写入版本表。 +func (db *DB) InsertProjectFactVersion(f *ProjectFact) (string, error) { + if f == nil || f.ID == "" { + return "", fmt.Errorf("无效的事实记录") + } + id := uuid.New().String() + now := time.Now() + _, err := db.Exec( + `INSERT INTO project_fact_versions ( + id, fact_id, project_id, fact_key, category, summary, body, confidence, + source_conversation_id, source_message_id, pinned, related_vulnerability_id, archived_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + id, f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence, + nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned), + nullIfEmpty(f.RelatedVulnerabilityID), now, + ) + if err != nil { + return "", fmt.Errorf("归档事实版本失败: %w", err) + } + return id, nil +} + +// GetProjectFactVersion 按版本 ID 获取快照。 +func (db *DB) GetProjectFactVersion(versionID string) (*ProjectFactVersion, error) { + row := db.QueryRow( + `SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, + COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, + COALESCE(related_vulnerability_id,''), archived_at + FROM project_fact_versions WHERE id = ?`, versionID, + ) + return scanProjectFactVersionRow(row) +} + +// ListProjectFactVersions 列出某条事实的全部历史版本(新→旧)。 +func (db *DB) ListProjectFactVersions(factID string, limit int) ([]*ProjectFactVersion, error) { + if limit <= 0 { + limit = 20 + } + rows, err := db.Query( + `SELECT id, fact_id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, + COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, + COALESCE(related_vulnerability_id,''), archived_at + FROM project_fact_versions WHERE fact_id = ? ORDER BY archived_at DESC LIMIT ?`, + factID, limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var out []*ProjectFactVersion + for rows.Next() { + v, err := scanProjectFactVersionFromRows(rows) + if err != nil { + return nil, err + } + out = append(out, v) + } + return out, rows.Err() +} + +func projectFactContentChanged(existing, incoming *ProjectFact) bool { + if existing == nil || incoming == nil { + return false + } + mergedBody := mergeFactBodyOnUpdate(incoming.Body, existing.Body) + inCat := stringsTrimDefault(incoming.Category, existing.Category) + inConf := stringsTrimDefault(incoming.Confidence, existing.Confidence) + return existing.Summary != incoming.Summary || + existing.Body != mergedBody || + existing.Category != inCat || + existing.Confidence != inConf +} + +func stringsTrimDefault(s, fallback string) string { + if strings.TrimSpace(s) == "" { + return fallback + } + return strings.TrimSpace(s) +} + +func scanProjectFactVersionRow(row *sql.Row) (*ProjectFactVersion, error) { + var v ProjectFactVersion + var pinned int + var archivedAt string + err := row.Scan( + &v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence, + &v.SourceConversationID, &v.SourceMessageID, &pinned, + &v.RelatedVulnerabilityID, &archivedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("事实版本不存在") + } + return nil, err + } + v.Pinned = pinned != 0 + v.ArchivedAt = parseDBTime(archivedAt) + return &v, nil +} + +func scanProjectFactVersionFromRows(rows *sql.Rows) (*ProjectFactVersion, error) { + var v ProjectFactVersion + var pinned int + var archivedAt string + err := rows.Scan( + &v.ID, &v.FactID, &v.ProjectID, &v.FactKey, &v.Category, &v.Summary, &v.Body, &v.Confidence, + &v.SourceConversationID, &v.SourceMessageID, &pinned, + &v.RelatedVulnerabilityID, &archivedAt, + ) + if err != nil { + return nil, err + } + v.Pinned = pinned != 0 + v.ArchivedAt = parseDBTime(archivedAt) + return &v, nil +} diff --git a/internal/database/project_stats.go b/internal/database/project_stats.go new file mode 100644 index 00000000..b35e3787 --- /dev/null +++ b/internal/database/project_stats.go @@ -0,0 +1,121 @@ +package database + +import ( + "database/sql" + "fmt" + "strings" +) + +// ProjectStats 项目聚合统计。 +type ProjectStats struct { + FactCount int `json:"fact_count"` + VulnCount int `json:"vuln_count"` + ConversationCount int `json:"conversation_count"` + SparseFactCount int `json:"sparse_fact_count"` +} + +// GetProjectStatsCounts 统计项目下事实、漏洞、对话数量(不含 sparse,由 project 包补全)。 +func (db *DB) GetProjectStatsCounts(projectID string) (*ProjectStats, error) { + projectID = strings.TrimSpace(projectID) + if projectID == "" { + return nil, fmt.Errorf("project_id 不能为空") + } + if _, err := db.GetProject(projectID); err != nil { + return nil, err + } + stats := &ProjectStats{} + if err := db.QueryRow( + `SELECT COUNT(*) FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`, + projectID, + ).Scan(&stats.FactCount); err != nil { + return nil, fmt.Errorf("统计事实失败: %w", err) + } + if err := db.QueryRow( + `SELECT COUNT(*) FROM vulnerabilities WHERE project_id = ?`, + projectID, + ).Scan(&stats.VulnCount); err != nil { + return nil, fmt.Errorf("统计漏洞失败: %w", err) + } + if err := db.QueryRow( + `SELECT COUNT(*) FROM conversations WHERE project_id = ?`, + projectID, + ).Scan(&stats.ConversationCount); err != nil { + return nil, fmt.Errorf("统计对话失败: %w", err) + } + return stats, nil +} + +// ListProjectFactsForSparseCheck 返回用于待补全检测的事实字段(非 deprecated)。 +func (db *DB) ListProjectFactsForSparseCheck(projectID string) ([]struct { + Category string + FactKey string + Body string +}, error) { + rows, err := db.Query( + `SELECT category, fact_key, COALESCE(body,'') FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`, + projectID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var out []struct { + Category string + FactKey string + Body string + } + for rows.Next() { + var row struct { + Category string + FactKey string + Body string + } + if err := rows.Scan(&row.Category, &row.FactKey, &row.Body); err != nil { + return nil, err + } + out = append(out, row) + } + return out, rows.Err() +} + +// ListConversationsByProjectID 列出绑定到项目的对话。 +func (db *DB) ListConversationsByProjectID(projectID string, limit, offset int) ([]*Conversation, error) { + if limit <= 0 { + limit = 100 + } + rows, err := db.Query( + `SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id + FROM conversations WHERE project_id = ? ORDER BY updated_at DESC LIMIT ? OFFSET ?`, + projectID, limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("查询项目对话失败: %w", err) + } + defer rows.Close() + + var conversations []*Conversation + for rows.Next() { + var conv Conversation + var createdAt, updatedAt string + var pinned int + var pid sql.NullString + if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &pid); err != nil { + return nil, err + } + if pid.Valid { + conv.ProjectID = strings.TrimSpace(pid.String) + } + conv.CreatedAt = parseDBTime(createdAt) + conv.UpdatedAt = parseDBTime(updatedAt) + conv.Pinned = pinned != 0 + conversations = append(conversations, &conv) + } + return conversations, rows.Err() +} + +// CountConversationsByProjectID 统计项目绑定对话数。 +func (db *DB) CountConversationsByProjectID(projectID string) (int, error) { + var n int + err := db.QueryRow(`SELECT COUNT(*) FROM conversations WHERE project_id = ?`, projectID).Scan(&n) + return n, err +}