package database import ( "database/sql" "fmt" "time" "github.com/google/uuid" "go.uber.org/zap" ) // Vulnerability 漏洞 type Vulnerability struct { ID string `json:"id"` ConversationID string `json:"conversation_id"` ConversationTag string `json:"conversation_tag,omitempty"` TaskTag string `json:"task_tag,omitempty"` TaskID string `json:"task_id,omitempty"` TaskQueueID string `json:"task_queue_id,omitempty"` Title string `json:"title"` Description string `json:"description"` Severity string `json:"severity"` // critical, high, medium, low, info Status string `json:"status"` // open, confirmed, fixed, false_positive Type string `json:"type"` Target string `json:"target"` Proof string `json:"proof"` Impact string `json:"impact"` Recommendation string `json:"recommendation"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } // CreateVulnerability 创建漏洞 func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { if vuln.ID == "" { vuln.ID = uuid.New().String() } if vuln.Status == "" { vuln.Status = "open" } now := time.Now() if vuln.CreatedAt.IsZero() { vuln.CreatedAt = now } vuln.UpdatedAt = now query := ` INSERT INTO vulnerabilities ( id, conversation_id, conversation_tag, task_tag, title, description, severity, status, vulnerability_type, target, proof, impact, recommendation, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` _, err := db.Exec( query, vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.CreatedAt, vuln.UpdatedAt, ) if err != nil { return nil, fmt.Errorf("创建漏洞失败: %w", err) } return vuln, nil } // GetVulnerability 获取漏洞 func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { var vuln Vulnerability query := ` SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, created_at, updated_at FROM vulnerabilities WHERE id = ? ` err := db.QueryRow(query, id).Scan( &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.TaskID, &vuln.TaskQueueID, &vuln.CreatedAt, &vuln.UpdatedAt, ) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("漏洞不存在") } return nil, fmt.Errorf("获取漏洞失败: %w", err) } return &vuln, nil } // ListVulnerabilities 列出漏洞 func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) { query := ` SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, created_at, updated_at FROM vulnerabilities WHERE 1=1 ` args := []interface{}{} if id != "" { query += " AND id = ?" args = append(args, id) } if conversationID != "" { query += " AND conversation_id = ?" args = append(args, conversationID) } if taskID != "" { query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" args = append(args, taskID, taskID) } if conversationTag != "" { query += " AND conversation_tag = ?" args = append(args, conversationTag) } if taskTag != "" { query += " AND task_tag = ?" args = append(args, taskTag) } if severity != "" { query += " AND severity = ?" args = append(args, severity) } if status != "" { query += " AND status = ?" args = append(args, status) } query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" args = append(args, limit, offset) rows, err := db.Query(query, args...) if err != nil { return nil, fmt.Errorf("查询漏洞列表失败: %w", err) } defer rows.Close() var vulnerabilities []*Vulnerability for rows.Next() { var vuln Vulnerability err := rows.Scan( &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.TaskID, &vuln.TaskQueueID, &vuln.CreatedAt, &vuln.UpdatedAt, ) if err != nil { db.logger.Warn("扫描漏洞记录失败", zap.Error(err)) continue } vulnerabilities = append(vulnerabilities, &vuln) } return vulnerabilities, nil } // CountVulnerabilities 统计漏洞总数(支持筛选条件) func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) { query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" args := []interface{}{} if id != "" { query += " AND id = ?" args = append(args, id) } if conversationID != "" { query += " AND conversation_id = ?" args = append(args, conversationID) } if taskID != "" { query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" args = append(args, taskID, taskID) } if conversationTag != "" { query += " AND conversation_tag = ?" args = append(args, conversationTag) } if taskTag != "" { query += " AND task_tag = ?" args = append(args, taskTag) } if severity != "" { query += " AND severity = ?" args = append(args, severity) } if status != "" { query += " AND status = ?" args = append(args, status) } var count int err := db.QueryRow(query, args...).Scan(&count) if err != nil { return 0, fmt.Errorf("统计漏洞总数失败: %w", err) } return count, nil } // UpdateVulnerability 更新漏洞 func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error { vuln.UpdatedAt = time.Now() query := ` UPDATE vulnerabilities SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?, vulnerability_type = ?, target = ?, proof = ?, impact = ?, recommendation = ?, updated_at = ? WHERE id = ? ` _, err := db.Exec( query, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.UpdatedAt, id, ) if err != nil { return fmt.Errorf("更新漏洞失败: %w", err) } return nil } // DeleteVulnerability 删除漏洞 func (db *DB) DeleteVulnerability(id string) error { _, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id) if err != nil { return fmt.Errorf("删除漏洞失败: %w", err) } return nil } // GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致) func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) { stats := make(map[string]interface{}) where := "WHERE 1=1" args := []interface{}{} if conversationID != "" { where += " AND conversation_id = ?" args = append(args, conversationID) } if taskID != "" { where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" args = append(args, taskID, taskID) } // 总漏洞数 var totalCount int query := "SELECT COUNT(*) FROM vulnerabilities " + where err := db.QueryRow(query, args...).Scan(&totalCount) if err != nil { return nil, fmt.Errorf("获取总漏洞数失败: %w", err) } stats["total"] = totalCount // 按严重程度统计 severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities " + where + " GROUP BY severity" rows, err := db.Query(severityQuery, args...) if err != nil { return nil, fmt.Errorf("获取严重程度统计失败: %w", err) } defer rows.Close() severityStats := make(map[string]int) for rows.Next() { var severity string var count int if err := rows.Scan(&severity, &count); err != nil { continue } severityStats[severity] = count } stats["by_severity"] = severityStats // 按状态统计 statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities " + where + " GROUP BY status" rows, err = db.Query(statusQuery, args...) if err != nil { return nil, fmt.Errorf("获取状态统计失败: %w", err) } defer rows.Close() statusStats := make(map[string]int) for rows.Next() { var status string var count int if err := rows.Scan(&status, &count); err != nil { continue } statusStats[status] = count } stats["by_status"] = statusStats return stats, nil } // GetVulnerabilityFilterOptions 获取漏洞筛选建议项 func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) { collect := func(query string, args ...interface{}) ([]string, error) { rows, err := db.Query(query, args...) if err != nil { return nil, err } defer rows.Close() items := make([]string, 0) for rows.Next() { var val string if err := rows.Scan(&val); err != nil { continue } if val == "" { continue } items = append(items, val) } return items, nil } vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`) if err != nil { return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err) } conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`) if err != nil { return nil, fmt.Errorf("查询会话ID建议失败: %w", err) } taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`) if err != nil { return nil, fmt.Errorf("查询任务ID建议失败: %w", err) } queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`) if err != nil { return nil, fmt.Errorf("查询队列ID建议失败: %w", err) } conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`) if err != nil { return nil, fmt.Errorf("查询对话标签建议失败: %w", err) } taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`) if err != nil { return nil, fmt.Errorf("查询任务标签建议失败: %w", err) } return map[string][]string{ "vulnerability_ids": vulnIDs, "conversation_ids": conversationIDs, "task_ids": taskIDs, "queue_ids": queueIDs, "conversation_tags": conversationTags, "task_tags": taskTags, }, nil }