diff --git a/internal/database/vulnerability.go b/internal/database/vulnerability.go index 1a584bf6..06d360f2 100644 --- a/internal/database/vulnerability.go +++ b/internal/database/vulnerability.go @@ -3,12 +3,84 @@ package database import ( "database/sql" "fmt" + "strings" "time" "github.com/google/uuid" "go.uber.org/zap" ) +// VulnerabilityListFilter 列表/统计/导出共用的筛选条件 +type VulnerabilityListFilter struct { + ID string + Search string // 关键词模糊匹配(标题、描述、类型、目标等) + ConversationID string + Severity string + Status string + TaskID string + ConversationTag string + TaskTag string +} + +func escapeVulnerabilityLikePattern(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `%`, `\%`) + s = strings.ReplaceAll(s, `_`, `\_`) + return "%" + s + "%" +} + +func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) { + if f.ID != "" { + query += " AND id = ?" + args = append(args, f.ID) + } + if f.ConversationID != "" { + query += " AND conversation_id = ?" + args = append(args, f.ConversationID) + } + if f.TaskID != "" { + query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" + args = append(args, f.TaskID, f.TaskID) + } + if f.ConversationTag != "" { + query += " AND conversation_tag = ?" + args = append(args, f.ConversationTag) + } + if f.TaskTag != "" { + query += " AND task_tag = ?" + args = append(args, f.TaskTag) + } + if f.Severity != "" { + query += " AND severity = ?" + args = append(args, f.Severity) + } + if f.Status != "" { + query += " AND status = ?" + args = append(args, f.Status) + } + search := strings.TrimSpace(f.Search) + if search != "" { + pattern := escapeVulnerabilityLikePattern(search) + query += ` AND ( + LOWER(id) LIKE LOWER(?) OR + LOWER(title) LIKE LOWER(?) OR + LOWER(COALESCE(description, '')) LIKE LOWER(?) OR + LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR + LOWER(COALESCE(target, '')) LIKE LOWER(?) OR + LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR + LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR + LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR + LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR + LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR + LOWER(COALESCE(task_tag, '')) LIKE LOWER(?) + )` + for i := 0; i < 11; i++ { + args = append(args, pattern) + } + } + return query, args +} + // Vulnerability 漏洞 type Vulnerability struct { ID string `json:"id"` @@ -97,7 +169,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { } // ListVulnerabilities 列出漏洞 -func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) { +func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) { query := ` SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, @@ -108,35 +180,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit WHERE 1=1 ` args := []interface{}{} - - if id != "" { - query += " AND id = ?" - args = append(args, id) - } - if conversationID != "" { - query += " AND conversation_id = ?" - args = append(args, conversationID) - } - if taskID != "" { - query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" - args = append(args, taskID, taskID) - } - if conversationTag != "" { - query += " AND conversation_tag = ?" - args = append(args, conversationTag) - } - if taskTag != "" { - query += " AND task_tag = ?" - args = append(args, taskTag) - } - if severity != "" { - query += " AND severity = ?" - args = append(args, severity) - } - if status != "" { - query += " AND status = ?" - args = append(args, status) - } + query, args = filter.appendWhere(query, args) query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" args = append(args, limit, offset) @@ -168,38 +212,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit } // CountVulnerabilities 统计漏洞总数(支持筛选条件) -func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) { +func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) { query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" args := []interface{}{} - - if id != "" { - query += " AND id = ?" - args = append(args, id) - } - if conversationID != "" { - query += " AND conversation_id = ?" - args = append(args, conversationID) - } - if taskID != "" { - query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" - args = append(args, taskID, taskID) - } - if conversationTag != "" { - query += " AND conversation_tag = ?" - args = append(args, conversationTag) - } - if taskTag != "" { - query += " AND task_tag = ?" - args = append(args, taskTag) - } - if severity != "" { - query += " AND severity = ?" - args = append(args, severity) - } - if status != "" { - query += " AND status = ?" - args = append(args, status) - } + query, args = filter.appendWhere(query, args) var count int err := db.QueryRow(query, args...).Scan(&count) @@ -245,19 +261,12 @@ func (db *DB) DeleteVulnerability(id string) error { } // GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致) -func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) { +func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) { stats := make(map[string]interface{}) where := "WHERE 1=1" args := []interface{}{} - if conversationID != "" { - where += " AND conversation_id = ?" - args = append(args, conversationID) - } - if taskID != "" { - where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" - args = append(args, taskID, taskID) - } + where, args = filter.appendWhere(where, args) // 总漏洞数 var totalCount int