mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-27 09:42:30 +02:00
452 lines
14 KiB
Go
452 lines
14 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// factKeyPattern is the accepted shape of a fact key: a lowercase ASCII letter
// or digit, followed by any mix of lowercase letters, digits, '.', '_', '/'
// and '-'.
var factKeyPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._/-]*$`)

// ValidateFactKey checks a fact key (the identifier of a fact, unique within
// its project).
//
// Surrounding whitespace is ignored for the check. It returns an error when
// the trimmed key is empty, longer than 128 bytes, or does not match
// factKeyPattern; otherwise it returns nil.
func ValidateFactKey(key string) error {
	trimmed := strings.TrimSpace(key)
	switch {
	case trimmed == "":
		return fmt.Errorf("fact_key 不能为空")
	case len(trimmed) > 128:
		return fmt.Errorf("fact_key 过长(最多 128 字符)")
	case !factKeyPattern.MatchString(trimmed):
		return fmt.Errorf("fact_key 格式无效,仅允许小写字母、数字及 . _ / -,且须以小写字母或数字开头")
	default:
		return nil
	}
}
|
|
|
|
// Project is a penetration-testing project: a blackboard shared across
// conversations.
//
// CreateProject defaults an empty ID to a UUID and a blank Status to "active".
// ListProjects sorts pinned projects first.
type Project struct {
	ID          string    `json:"id"`
	Name        string    `json:"name"`
	Description string    `json:"description,omitempty"` // NULL in the database is read back as ""
	ScopeJSON   string    `json:"scope_json,omitempty"`  // NOTE(review): presumably a JSON document describing scope; schema not defined here
	Status      string    `json:"status"`                // active | archived
	Pinned      bool      `json:"pinned"`                // stored as 1/0
	CreatedAt   time.Time `json:"created_at"`
	UpdatedAt   time.Time `json:"updated_at"`
}
|
|
|
|
// ProjectFact is one entry ("fact") on a project's blackboard.
//
// Facts are addressed by (ProjectID, FactKey). UpsertProjectFact validates
// the key with ValidateFactKey, assigns a UUID when ID is empty, and defaults
// Category to "note" and Confidence to "tentative". Blank optional reference
// IDs are stored as NULL and read back as "".
type ProjectFact struct {
	ID                     string    `json:"id"`
	ProjectID              string    `json:"project_id"`
	FactKey                string    `json:"fact_key"` // unique within the project
	Category               string    `json:"category"`
	Summary                string    `json:"summary"`
	Body                   string    `json:"body"`       // NULL in the database is read back as ""
	Confidence             string    `json:"confidence"` // confirmed | tentative | deprecated
	SourceConversationID   string    `json:"source_conversation_id,omitempty"`
	SourceMessageID        string    `json:"source_message_id,omitempty"`
	Pinned                 bool      `json:"pinned"`                             // pinned facts sort first in the list queries
	SupersedesFactID       string    `json:"supersedes_fact_id,omitempty"`       // NOTE(review): presumably the ID of a fact this one replaces; not interpreted in this file
	RelatedVulnerabilityID string    `json:"related_vulnerability_id,omitempty"` // NOTE(review): presumably a vulnerability record ID; not interpreted in this file
	CreatedAt              time.Time `json:"created_at"`
	UpdatedAt              time.Time `json:"updated_at"`
}
|
|
|
|
// ProjectFactListFilter narrows the result of ListProjectFacts. Every field is
// trimmed first and ignored when blank.
type ProjectFactListFilter struct {
	Category   string // exact match on the category column
	Confidence string // exact match on the confidence column
	Search     string // substring match (SQL LIKE) on fact_key, summary or body
}
|
|
|
|
// CreateProject 创建项目。
|
|
func (db *DB) CreateProject(p *Project) (*Project, error) {
|
|
if p.ID == "" {
|
|
p.ID = uuid.New().String()
|
|
}
|
|
if strings.TrimSpace(p.Status) == "" {
|
|
p.Status = "active"
|
|
}
|
|
now := time.Now()
|
|
if p.CreatedAt.IsZero() {
|
|
p.CreatedAt = now
|
|
}
|
|
p.UpdatedAt = now
|
|
|
|
_, err := db.Exec(
|
|
`INSERT INTO projects (id, name, description, scope_json, status, pinned, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
p.ID, p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.CreatedAt, p.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建项目失败: %w", err)
|
|
}
|
|
return p, nil
|
|
}
|
|
|
|
// GetProject 获取项目。
|
|
func (db *DB) GetProject(id string) (*Project, error) {
|
|
var p Project
|
|
var pinned int
|
|
var createdAt, updatedAt string
|
|
err := db.QueryRow(
|
|
`SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
|
FROM projects WHERE id = ?`, id,
|
|
).Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("项目不存在")
|
|
}
|
|
return nil, fmt.Errorf("获取项目失败: %w", err)
|
|
}
|
|
p.Pinned = pinned != 0
|
|
p.CreatedAt = parseDBTime(createdAt)
|
|
p.UpdatedAt = parseDBTime(updatedAt)
|
|
return &p, nil
|
|
}
|
|
|
|
// ListProjects 列出项目。
|
|
func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) {
|
|
if limit <= 0 {
|
|
limit = 200
|
|
}
|
|
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
|
|
FROM projects WHERE 1=1`
|
|
args := []interface{}{}
|
|
if s := strings.TrimSpace(status); s != "" {
|
|
query += " AND status = ?"
|
|
args = append(args, s)
|
|
}
|
|
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
|
args = append(args, limit, offset)
|
|
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("列出项目失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []*Project
|
|
for rows.Next() {
|
|
var p Project
|
|
var pinned int
|
|
var createdAt, updatedAt string
|
|
if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
p.Pinned = pinned != 0
|
|
p.CreatedAt = parseDBTime(createdAt)
|
|
p.UpdatedAt = parseDBTime(updatedAt)
|
|
out = append(out, &p)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// UpdateProject 更新项目。
|
|
func (db *DB) UpdateProject(p *Project) error {
|
|
p.UpdatedAt = time.Now()
|
|
_, err := db.Exec(
|
|
`UPDATE projects SET name = ?, description = ?, scope_json = ?, status = ?, pinned = ?, updated_at = ? WHERE id = ?`,
|
|
p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.UpdatedAt, p.ID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("更新项目失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理)。
|
|
func (db *DB) DeleteProject(id string) error {
|
|
_, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("删除项目失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetConversationProjectID 返回对话绑定的项目 ID。
|
|
func (db *DB) GetConversationProjectID(conversationID string) (string, error) {
|
|
var pid sql.NullString
|
|
err := db.QueryRow(`SELECT project_id FROM conversations WHERE id = ?`, conversationID).Scan(&pid)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return "", fmt.Errorf("对话不存在")
|
|
}
|
|
return "", err
|
|
}
|
|
if pid.Valid {
|
|
return strings.TrimSpace(pid.String), nil
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
// SetConversationProjectID 设置对话所属项目(空字符串表示解除绑定)。
|
|
func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
|
|
projectID = strings.TrimSpace(projectID)
|
|
if projectID != "" {
|
|
if _, err := db.GetProject(projectID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
var val interface{}
|
|
if projectID == "" {
|
|
val = nil
|
|
} else {
|
|
val = projectID
|
|
}
|
|
_, err := db.Exec(`UPDATE conversations SET project_id = ?, updated_at = ? WHERE id = ?`, val, time.Now(), conversationID)
|
|
if err != nil {
|
|
return fmt.Errorf("设置对话项目失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ListProjectFactsForIndex 列出用于黑板索引注入的事实(不含 deprecated,除非 includeDeprecated)。
|
|
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
|
|
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
|
FROM project_facts WHERE project_id = ?`
|
|
args := []interface{}{projectID}
|
|
if !includeDeprecated {
|
|
query += " AND confidence != 'deprecated'"
|
|
}
|
|
query += " ORDER BY pinned DESC, updated_at DESC"
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanProjectFacts(rows)
|
|
}
|
|
|
|
// ListProjectFacts 分页列出项目事实。
|
|
func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, limit, offset int) ([]*ProjectFact, error) {
|
|
if limit <= 0 {
|
|
limit = 100
|
|
}
|
|
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
|
FROM project_facts WHERE project_id = ?`
|
|
args := []interface{}{projectID}
|
|
if c := strings.TrimSpace(filter.Category); c != "" {
|
|
query += " AND category = ?"
|
|
args = append(args, c)
|
|
}
|
|
if c := strings.TrimSpace(filter.Confidence); c != "" {
|
|
query += " AND confidence = ?"
|
|
args = append(args, c)
|
|
}
|
|
if s := strings.TrimSpace(filter.Search); s != "" {
|
|
pat := "%" + s + "%"
|
|
query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)"
|
|
args = append(args, pat, pat, pat)
|
|
}
|
|
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
|
|
args = append(args, limit, offset)
|
|
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanProjectFacts(rows)
|
|
}
|
|
|
|
// GetProjectFactByKey 按 key 获取事实。
|
|
func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, error) {
|
|
row := db.QueryRow(
|
|
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
|
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
|
|
projectID, factKey,
|
|
)
|
|
return scanProjectFactRow(row)
|
|
}
|
|
|
|
// GetProjectFact 按 ID 获取事实。
|
|
func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
|
|
row := db.QueryRow(
|
|
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
|
|
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
|
|
COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at
|
|
FROM project_facts WHERE id = ?`, id,
|
|
)
|
|
return scanProjectFactRow(row)
|
|
}
|
|
|
|
// UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。
|
|
func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
|
if err := ValidateFactKey(f.FactKey); err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(f.Category) == "" {
|
|
f.Category = "note"
|
|
}
|
|
if strings.TrimSpace(f.Confidence) == "" {
|
|
f.Confidence = "tentative"
|
|
}
|
|
now := time.Now()
|
|
|
|
existing, err := db.GetProjectFactByKey(f.ProjectID, f.FactKey)
|
|
if err == nil && existing != nil {
|
|
f.ID = existing.ID
|
|
f.CreatedAt = existing.CreatedAt
|
|
f.UpdatedAt = now
|
|
_, err = db.Exec(
|
|
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
|
|
source_conversation_id = ?, source_message_id = ?, pinned = ?,
|
|
supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ?
|
|
WHERE id = ?`,
|
|
f.Category, f.Summary, f.Body, f.Confidence,
|
|
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
|
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("更新事实失败: %w", err)
|
|
}
|
|
return f, nil
|
|
}
|
|
|
|
if f.ID == "" {
|
|
f.ID = uuid.New().String()
|
|
}
|
|
f.CreatedAt = now
|
|
f.UpdatedAt = now
|
|
_, err = db.Exec(
|
|
`INSERT INTO project_facts (
|
|
id, project_id, fact_key, category, summary, body, confidence,
|
|
source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id,
|
|
created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
|
|
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
|
|
nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID),
|
|
f.CreatedAt, f.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建事实失败: %w", err)
|
|
}
|
|
return f, nil
|
|
}
|
|
|
|
// DeprecateProjectFact 将事实标记为 deprecated。
|
|
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
|
res, err := db.Exec(
|
|
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
|
time.Now(), projectID, factKey,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return fmt.Errorf("事实不存在")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteProjectFact 删除事实。
|
|
func (db *DB) DeleteProjectFact(id string) error {
|
|
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
|
return err
|
|
}
|
|
|
|
func scanProjectFacts(rows *sql.Rows) ([]*ProjectFact, error) {
|
|
var out []*ProjectFact
|
|
for rows.Next() {
|
|
f, err := scanProjectFactFromRows(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, f)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
|
|
var f ProjectFact
|
|
var pinned int
|
|
var createdAt, updatedAt string
|
|
err := row.Scan(
|
|
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
|
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
|
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
|
)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("事实不存在")
|
|
}
|
|
return nil, err
|
|
}
|
|
f.Pinned = pinned != 0
|
|
f.CreatedAt = parseDBTime(createdAt)
|
|
f.UpdatedAt = parseDBTime(updatedAt)
|
|
return &f, nil
|
|
}
|
|
|
|
func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
|
|
var f ProjectFact
|
|
var pinned int
|
|
var createdAt, updatedAt string
|
|
err := rows.Scan(
|
|
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
|
|
&f.SourceConversationID, &f.SourceMessageID, &pinned,
|
|
&f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
f.Pinned = pinned != 0
|
|
f.CreatedAt = parseDBTime(createdAt)
|
|
f.UpdatedAt = parseDBTime(updatedAt)
|
|
return &f, nil
|
|
}
|
|
|
|
// boolToInt encodes a bool as 1 (true) or 0 (false) for integer columns such
// as pinned.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
|
|
|
|
// nullIfEmpty maps a blank (empty or all-whitespace) string to a nil query
// argument, which the driver stores as SQL NULL. Any other string is returned
// unchanged, without trimming.
func nullIfEmpty(s string) interface{} {
	if trimmed := strings.TrimSpace(s); len(trimmed) > 0 {
		return s
	}
	return nil
}
|
|
|
|
// parseDBTime converts a timestamp string read from the database into a
// time.Time. Blank or unrecognised input yields the zero time.
//
// Values read back may be RFC3339 (with a 'T') or in the space-separated form
// used when written, with or without fractional seconds and a zone offset.
// The layouts are tried in order, and the first successful parse wins. Inputs
// without an offset are interpreted as UTC by time.Parse.
func parseDBTime(s string) time.Time {
	raw := strings.TrimSpace(s)
	if raw == "" {
		return time.Time{}
	}
	candidates := [...]string{
		time.RFC3339Nano,
		time.RFC3339,
		"2006-01-02 15:04:05.999999999-07:00",
		"2006-01-02 15:04:05-07:00",
		"2006-01-02T15:04:05.999999999-07:00",
		"2006-01-02T15:04:05-07:00",
		"2006-01-02 15:04:05.999999999",
		"2006-01-02 15:04:05",
		"2006-01-02T15:04:05.999999999",
		"2006-01-02T15:04:05",
	}
	for i := 0; i < len(candidates); i++ {
		parsed, err := time.Parse(candidates[i], raw)
		if err == nil {
			return parsed
		}
	}
	return time.Time{}
}
|