From fad6b3c808a9e76539a809c5d34a48155b3aef83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 28 Apr 2026 01:05:58 +0800 Subject: [PATCH] Add files via upload --- internal/database/database.go | 39 ++++++++++ internal/database/vulnerability.go | 113 ++++++++++++++++++++++++++--- 2 files changed, 141 insertions(+), 11 deletions(-) diff --git a/internal/database/database.go b/internal/database/database.go index 61fc053f..34afbe20 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -197,6 +197,8 @@ func (db *DB) initTables() error { CREATE TABLE IF NOT EXISTS vulnerabilities ( id TEXT PRIMARY KEY, conversation_id TEXT NOT NULL, + conversation_tag TEXT, + task_tag TEXT, title TEXT NOT NULL, description TEXT, severity TEXT NOT NULL, @@ -289,6 +291,8 @@ func (db *DB) initTables() error { CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id); CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id); + CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag); + CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); @@ -383,6 +387,10 @@ func (db *DB) initTables() error { db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err)) // 不返回错误,允许继续运行 } + if err := db.migrateVulnerabilitiesTable(); err != nil { + db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err)) + // 不返回错误,允许继续运行 + } if _, err := db.Exec(createIndexes); err != nil { return fmt.Errorf("创建索引失败: %w", err) @@ -683,6 +691,37 @@ func (db *DB) migrateBatchTaskQueuesTable() error { return nil } +// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段 +func (db *DB) migrateVulnerabilitiesTable() error { + columns := []struct { + name string + stmt string + }{ + {name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"}, + {name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"}, + } + + for _, col := range columns { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('vulnerabilities') WHERE name=?", col.name).Scan(&count) + if err != nil { + if _, addErr := db.Exec(col.stmt); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr)) + } + } + continue + } + if count == 0 { + if _, addErr := db.Exec(col.stmt); addErr != nil { + db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr)) + } + } + } + return nil +} + // NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") diff --git a/internal/database/vulnerability.go b/internal/database/vulnerability.go index c4ec69b2..20b2d8da 100644 --- a/internal/database/vulnerability.go +++ b/internal/database/vulnerability.go @@ -13,6 +13,10 @@ import ( type Vulnerability struct { ID string `json:"id"` ConversationID string `json:"conversation_id"` + ConversationTag string `json:"conversation_tag,omitempty"` + TaskTag string `json:"task_tag,omitempty"` + TaskID string `json:"task_id,omitempty"` + TaskQueueID string `json:"task_queue_id,omitempty"` Title string `json:"title"` Description string `json:"description"` Severity string `json:"severity"` // critical, high, medium, low, info @@ -42,15 +46,15 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { query := ` INSERT INTO vulnerabilities ( - id, conversation_id, title, description, severity, status, + id, conversation_id, conversation_tag, task_tag, title, description, severity, status, vulnerability_type, target, proof, impact, recommendation, created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` _, err := db.Exec( query, - vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description, + vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.CreatedAt, vuln.UpdatedAt, @@ -67,7 +71,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { var vuln Vulnerability query := ` SELECT id, conversation_id, title, description, severity, status, - vulnerability_type, target, proof, impact, recommendation, + conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, + COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, + COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, created_at, updated_at FROM vulnerabilities WHERE id = ? @@ -75,8 +81,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { err := db.QueryRow(query, id).Scan( &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, - &vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, + &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Proof, &vuln.Impact, &vuln.Recommendation, + &vuln.TaskID, &vuln.TaskQueueID, &vuln.CreatedAt, &vuln.UpdatedAt, ) if err != nil { @@ -90,10 +97,12 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { } // ListVulnerabilities 列出漏洞 -func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) { +func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) { query := ` - SELECT id, conversation_id, title, description, severity, status, + SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, + COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, + COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, created_at, updated_at FROM vulnerabilities WHERE 1=1 @@ -108,6 +117,18 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit query += " AND conversation_id = ?" args = append(args, conversationID) } + if taskID != "" { + query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" + args = append(args, taskID, taskID) + } + if conversationTag != "" { + query += " AND conversation_tag = ?" + args = append(args, conversationTag) + } + if taskTag != "" { + query += " AND task_tag = ?" + args = append(args, taskTag) + } if severity != "" { query += " AND severity = ?" args = append(args, severity) @@ -131,8 +152,9 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit var vuln Vulnerability err := rows.Scan( &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, - &vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, + &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Proof, &vuln.Impact, &vuln.Recommendation, + &vuln.TaskID, &vuln.TaskQueueID, &vuln.CreatedAt, &vuln.UpdatedAt, ) if err != nil { @@ -146,7 +168,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit } // CountVulnerabilities 统计漏洞总数(支持筛选条件) -func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) { +func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) { query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" args := []interface{}{} @@ -158,6 +180,18 @@ func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) query += " AND conversation_id = ?" args = append(args, conversationID) } + if taskID != "" { + query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" + args = append(args, taskID, taskID) + } + if conversationTag != "" { + query += " AND conversation_tag = ?" + args = append(args, conversationTag) + } + if taskTag != "" { + query += " AND task_tag = ?" + args = append(args, taskTag) + } if severity != "" { query += " AND severity = ?" args = append(args, severity) @@ -182,7 +216,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error { query := ` UPDATE vulnerabilities - SET title = ?, description = ?, severity = ?, status = ?, + SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?, vulnerability_type = ?, target = ?, proof = ?, impact = ?, recommendation = ?, updated_at = ? WHERE id = ? @@ -190,7 +224,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error { _, err := db.Exec( query, - vuln.Title, vuln.Description, vuln.Severity, vuln.Status, + vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.UpdatedAt, id, ) @@ -279,3 +313,60 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface return stats, nil } +// GetVulnerabilityFilterOptions 获取漏洞筛选建议项 +func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) { + collect := func(query string, args ...interface{}) ([]string, error) { + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + items := make([]string, 0) + for rows.Next() { + var val string + if err := rows.Scan(&val); err != nil { + continue + } + if val == "" { + continue + } + items = append(items, val) + } + return items, nil + } + + vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`) + if err != nil { + return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err) + } + conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`) + if err != nil { + return nil, fmt.Errorf("查询会话ID建议失败: %w", err) + } + taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`) + if err != nil { + return nil, fmt.Errorf("查询任务ID建议失败: %w", err) + } + queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`) + if err != nil { + return nil, fmt.Errorf("查询队列ID建议失败: %w", err) + } + conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`) + if err != nil { + return nil, fmt.Errorf("查询对话标签建议失败: %w", err) + } + taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`) + if err != nil { + return nil, fmt.Errorf("查询任务标签建议失败: %w", err) + } + + return map[string][]string{ + "vulnerability_ids": vulnIDs, + "conversation_ids": conversationIDs, + "task_ids": taskIDs, + "queue_ids": queueIDs, + "conversation_tags": conversationTags, + "task_tags": taskTags, + }, nil +} +