From 4661862a1abce4766f7c44b29c9d464bd75bd5de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:03:09 +0800 Subject: [PATCH] Add files via upload --- internal/database/conversation.go | 20 ++- .../conversation_vulnerability_test.go | 69 ++++++++++ internal/database/database.go | 118 +++++++++++++++++- internal/database/vulnerability.go | 8 +- internal/handler/openapi.go | 2 +- 5 files changed, 205 insertions(+), 12 deletions(-) create mode 100644 internal/database/conversation_vulnerability_test.go diff --git a/internal/database/conversation.go b/internal/database/conversation.go index 7427f478..ccff1e0e 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -543,18 +543,28 @@ func (db *DB) UpdateConversationTime(id string) error { return nil } -// DeleteConversation 删除对话及其所有相关数据 +// DeleteConversation 删除对话及其会话相关数据。 // 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除: // - messages(消息) // - process_details(过程详情) // - attack_chain_nodes(攻击链节点) // - attack_chain_edges(攻击链边) -// - vulnerabilities(漏洞) // - conversation_group_mappings(分组映射) -// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL +// 漏洞记录会保留:vulnerabilities.conversation_id 使用 ON DELETE SET NULL,仅解除与会话的关联。 +// 注意:knowledge_retrieval_logs 在删除前会被显式清理。 func (db *DB) DeleteConversation(id string) error { + // 删除对话前补全漏洞来源标签,便于在漏洞库中追溯已删除会话的发现。 + _, err := db.Exec(` + UPDATE vulnerabilities + SET conversation_tag = COALESCE(NULLIF(TRIM(conversation_tag), ''), (SELECT title FROM conversations WHERE id = ?)) + WHERE conversation_id = ? + `, id, id) + if err != nil { + db.logger.Warn("更新漏洞来源标签失败", zap.String("conversationId", id), zap.Error(err)) + } + // 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除) - _, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id) + _, err = db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id) if err != nil { db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err)) // 不返回错误,继续删除对话 @@ -567,7 +577,7 @@ func (db *DB) DeleteConversation(id string) error { } db.removeConversationScopedDirs(id) - db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) + db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id)) return nil } diff --git a/internal/database/conversation_vulnerability_test.go b/internal/database/conversation_vulnerability_test.go new file mode 100644 index 00000000..f173d5ab --- /dev/null +++ b/internal/database/conversation_vulnerability_test.go @@ -0,0 +1,69 @@ +package database + +import ( + "path/filepath" + "testing" + + "go.uber.org/zap" +) + +func TestDeleteConversationPreservesVulnerabilities(t *testing.T) { + tmp := t.TempDir() + dbPath := filepath.Join(tmp, "vuln-preserve.db") + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatalf("NewDB: %v", err) + } + defer db.Close() + + conv, err := db.CreateConversation("vuln source chat", ConversationCreateMeta{}) + if err != nil { + t.Fatalf("CreateConversation: %v", err) + } + + vuln, err := db.CreateVulnerability(&Vulnerability{ + ConversationID: conv.ID, + Title: "SQL Injection", + Severity: "high", + Status: "open", + }) + if err != nil { + t.Fatalf("CreateVulnerability: %v", err) + } + + if err := db.DeleteConversation(conv.ID); err != nil { + t.Fatalf("DeleteConversation: %v", err) + } + + got, err := db.GetVulnerability(vuln.ID) + if err != nil { + t.Fatalf("GetVulnerability after delete: %v", err) + } + if got.Title != "SQL Injection" { + t.Fatalf("title = %q, want SQL Injection", got.Title) + } + if got.ConversationID != "" { + t.Fatalf("conversation_id = %q, want empty after conversation delete", got.ConversationID) + } + if got.ConversationTag != "vuln source chat" { + t.Fatalf("conversation_tag = %q, want vuln source chat", got.ConversationTag) + } +} + +func TestMigrateVulnerabilitiesConversationFK(t *testing.T) { + tmp := t.TempDir() + dbPath := filepath.Join(tmp, "vuln-fk-migrate.db") + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatalf("NewDB: %v", err) + } + defer db.Close() + + ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB) + if err != nil { + t.Fatalf("vulnerabilitiesConversationFKOnDeleteSetNull: %v", err) + } + if !ok { + t.Fatal("expected vulnerabilities.conversation_id FK to use ON DELETE SET NULL") + } +} diff --git a/internal/database/database.go b/internal/database/database.go index 78c2108a..4be5b95e 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -357,7 +357,7 @@ func (db *DB) initTables() error { createVulnerabilitiesTable := ` CREATE TABLE IF NOT EXISTS vulnerabilities ( id TEXT PRIMARY KEY, - conversation_id TEXT NOT NULL, + conversation_id TEXT, conversation_tag TEXT, task_tag TEXT, title TEXT NOT NULL, @@ -371,7 +371,8 @@ func (db *DB) initTables() error { recommendation TEXT, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + project_id TEXT, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL );` // 创建批量任务队列表 @@ -737,6 +738,9 @@ func (db *DB) initTables() error { db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err)) // 不返回错误,允许继续运行 } + if err := db.migrateVulnerabilitiesConversationFK(); err != nil { + db.logger.Warn("迁移vulnerabilities会话外键失败", zap.Error(err)) + } if err := db.migrateProjectsTable(); err != nil { db.logger.Warn("迁移projects相关表失败", zap.Error(err)) @@ -1146,6 +1150,116 @@ func (db *DB) dropProjectFactVersionsTable() error { return err } +// migrateVulnerabilitiesConversationFK 将 vulnerabilities.conversation_id 外键改为 ON DELETE SET NULL,删除对话时保留漏洞记录。 +func (db *DB) migrateVulnerabilitiesConversationFK() error { + ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB) + if err != nil { + return err + } + if ok { + return nil + } + + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("开启事务失败: %w", err) + } + defer func() { _ = tx.Rollback() }() + + const createNew = ` + CREATE TABLE vulnerabilities_new ( + id TEXT PRIMARY KEY, + conversation_id TEXT, + conversation_tag TEXT, + task_tag TEXT, + title TEXT NOT NULL, + description TEXT, + severity TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'open', + vulnerability_type TEXT, + target TEXT, + proof TEXT, + impact TEXT, + recommendation TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + project_id TEXT, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL + );` + if _, err := tx.Exec(createNew); err != nil { + return fmt.Errorf("创建 vulnerabilities_new 失败: %w", err) + } + + const copyRows = ` + INSERT INTO vulnerabilities_new ( + id, conversation_id, conversation_tag, task_tag, title, description, + severity, status, vulnerability_type, target, proof, impact, recommendation, + created_at, updated_at, project_id + ) + SELECT + id, conversation_id, conversation_tag, task_tag, title, description, + severity, status, vulnerability_type, target, proof, impact, recommendation, + created_at, updated_at, project_id + FROM vulnerabilities;` + if _, err := tx.Exec(copyRows); err != nil { + return fmt.Errorf("复制 vulnerabilities 数据失败: %w", err) + } + if _, err := tx.Exec(`DROP TABLE vulnerabilities`); err != nil { + return fmt.Errorf("删除旧 vulnerabilities 表失败: %w", err) + } + if _, err := tx.Exec(`ALTER TABLE vulnerabilities_new RENAME TO vulnerabilities`); err != nil { + return fmt.Errorf("重命名 vulnerabilities 表失败: %w", err) + } + + indexes := []string{ + `CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id)`, + `CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag)`, + `CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag)`, + `CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity)`, + `CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status)`, + `CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at)`, + `CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id)`, + } + for _, stmt := range indexes { + if _, err := tx.Exec(stmt); err != nil { + return fmt.Errorf("重建 vulnerabilities 索引失败: %w", err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("提交 vulnerabilities 外键迁移失败: %w", err) + } + db.logger.Info("vulnerabilities 表已迁移:删除对话时保留漏洞记录") + return nil +} + +func vulnerabilitiesConversationFKOnDeleteSetNull(db *sql.DB) (bool, error) { + rows, err := db.Query(`PRAGMA foreign_key_list(vulnerabilities)`) + if err != nil { + return false, err + } + defer rows.Close() + + found := false + for rows.Next() { + var id, seq int + var table, from, to, onUpdate, onDelete, match string + if err := rows.Scan(&id, &seq, &table, &from, &to, &onUpdate, &onDelete, &match); err != nil { + return false, err + } + if from == "conversation_id" { + found = true + if !strings.EqualFold(onDelete, "SET NULL") { + return false, nil + } + } + } + if err := rows.Err(); err != nil { + return false, err + } + return found, nil +} + // migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段 func (db *DB) migrateVulnerabilitiesTable() error { columns := []struct { diff --git a/internal/database/vulnerability.go b/internal/database/vulnerability.go index 8ca3352c..487f1d8b 100644 --- a/internal/database/vulnerability.go +++ b/internal/database/vulnerability.go @@ -138,7 +138,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { _, err := db.Exec( query, - vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, + vuln.ID, nullIfEmpty(vuln.ConversationID), nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.CreatedAt, vuln.UpdatedAt, @@ -154,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { var vuln Vulnerability query := ` - SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, + SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, @@ -183,7 +183,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { // ListVulnerabilities 列出漏洞 func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) { query := ` - SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag, + SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, @@ -403,7 +403,7 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) { if err != nil { return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err) } - conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`) + conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id IS NOT NULL AND conversation_id <> '' ORDER BY created_at DESC LIMIT 500`) if err != nil { return nil, fmt.Errorf("查询会话ID建议失败: %w", err) } diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index d0c4dc71..08edf59a 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -1344,7 +1344,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "delete": map[string]interface{}{ "tags": []string{"对话管理"}, "summary": "删除对话", - "description": "删除指定的对话及其所有相关数据(消息、漏洞等)。**此操作不可恢复**。", + "description": "删除指定的对话及其会话数据(消息、攻击链等)。**漏洞记录会保留**,仅解除与会话的关联。**此操作不可恢复**。", "operationId": "deleteConversation", "parameters": []map[string]interface{}{ {