package database

import (
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"
)

// Vulnerability is one row of the vulnerabilities table, linked to a
// conversation through ConversationID.
type Vulnerability struct {
	ID             string `json:"id"`
	ConversationID string `json:"conversation_id"`
	Title          string `json:"title"`
	Description    string `json:"description"`
	Severity       string `json:"severity"` // critical, high, medium, low, info
	Status         string `json:"status"`   // open, confirmed, fixed, false_positive
	// Type is persisted in the vulnerability_type column.
	Type           string    `json:"type"`
	Target         string    `json:"target"`
	Proof          string    `json:"proof"`
	Impact         string    `json:"impact"`
	Recommendation string    `json:"recommendation"`
	CreatedAt      time.Time `json:"created_at"`
	UpdatedAt      time.Time `json:"updated_at"`
}

// vulnerabilitySelectColumns is the column list shared by every query that
// reads full Vulnerability rows; its order must match scanVulnerability.
const vulnerabilitySelectColumns = `id, conversation_id, title, description, severity, status,
	vulnerability_type, target, proof, impact, recommendation, created_at, updated_at`

// vulnerabilityScanner is the Scan method shared by *sql.Row and *sql.Rows,
// letting single-row and multi-row reads use the same field mapping.
type vulnerabilityScanner interface {
	Scan(dest ...interface{}) error
}

// scanVulnerability reads one row selected with vulnerabilitySelectColumns.
// The scanner's error is returned unwrapped so callers can still detect
// sql.ErrNoRows.
func scanVulnerability(s vulnerabilityScanner) (*Vulnerability, error) {
	var v Vulnerability
	if err := s.Scan(
		&v.ID, &v.ConversationID, &v.Title, &v.Description,
		&v.Severity, &v.Status, &v.Type, &v.Target,
		&v.Proof, &v.Impact, &v.Recommendation,
		&v.CreatedAt, &v.UpdatedAt,
	); err != nil {
		return nil, err
	}
	return &v, nil
}

// vulnerabilityFilterClause builds the optional equality filters shared by
// ListVulnerabilities and CountVulnerabilities. An empty argument means "do
// not filter on that column". The returned clause starts with " AND " (or is
// empty) and is meant to follow "WHERE 1=1". Column names are fixed here, so
// only the values are passed as bind parameters.
func vulnerabilityFilterClause(id, conversationID, severity, status string) (string, []interface{}) {
	filters := []struct{ column, value string }{
		{"id", id},
		{"conversation_id", conversationID},
		{"severity", severity},
		{"status", status},
	}
	clause := ""
	args := []interface{}{}
	for _, f := range filters {
		if f.value == "" {
			continue
		}
		clause += " AND " + f.column + " = ?"
		args = append(args, f.value)
	}
	return clause, args
}

// CreateVulnerability inserts vuln and returns it.
//
// Before inserting it fills in defaults on the caller's struct: a new UUID
// when ID is empty, Status "open" when Status is empty, CreatedAt = now when
// it is zero, and UpdatedAt = now unconditionally. On failure it returns a
// wrapped error; vuln may already have been modified by then.
func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
	if vuln.ID == "" {
		vuln.ID = uuid.New().String()
	}
	if vuln.Status == "" {
		vuln.Status = "open"
	}

	now := time.Now()
	if vuln.CreatedAt.IsZero() {
		vuln.CreatedAt = now
	}
	vuln.UpdatedAt = now

	query := `
		INSERT INTO vulnerabilities (
			id, conversation_id, title, description, severity, status,
			vulnerability_type, target, proof, impact, recommendation,
			created_at, updated_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`

	_, err := db.Exec(
		query,
		vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description,
		vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
		vuln.Proof, vuln.Impact, vuln.Recommendation,
		vuln.CreatedAt, vuln.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("创建漏洞失败: %w", err)
	}

	return vuln, nil
}

// GetVulnerability returns the vulnerability with the given id.
//
// If no row matches, it returns a plain "漏洞不存在" error (not wrapped).
// Any other database error is wrapped.
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
	query := "SELECT " + vulnerabilitySelectColumns + " FROM vulnerabilities WHERE id = ?"

	vuln, err := scanVulnerability(db.QueryRow(query, id))
	if err != nil {
		// errors.Is also matches drivers/wrappers that wrap ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, errors.New("漏洞不存在")
		}
		return nil, fmt.Errorf("获取漏洞失败: %w", err)
	}

	return vuln, nil
}

// ListVulnerabilities returns one page of vulnerabilities, newest first
// (ordered by created_at DESC).
//
// id, conversationID, severity and status are optional equality filters;
// pass "" to skip one. limit and offset are passed directly to LIMIT/OFFSET.
//
// A row that fails to scan is logged and skipped, so one bad row does not
// fail the whole page. An error raised while iterating the result set is
// returned. When nothing matches, the result is a nil slice.
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) {
	filter, args := vulnerabilityFilterClause(id, conversationID, severity, status)
	query := "SELECT " + vulnerabilitySelectColumns +
		" FROM vulnerabilities WHERE 1=1" + filter +
		" ORDER BY created_at DESC LIMIT ? OFFSET ?"
	args = append(args, limit, offset)

	rows, err := db.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("查询漏洞列表失败: %w", err)
	}
	defer rows.Close()

	var vulnerabilities []*Vulnerability
	for rows.Next() {
		vuln, err := scanVulnerability(rows)
		if err != nil {
			db.logger.Warn("扫描漏洞记录失败", zap.Error(err))
			continue
		}
		vulnerabilities = append(vulnerabilities, vuln)
	}
	// rows.Next returns false on both end-of-data and failure; only Err tells
	// them apart, so without this check a truncated page would look complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历漏洞列表失败: %w", err)
	}

	return vulnerabilities, nil
}

// CountVulnerabilities returns the number of vulnerabilities matching the
// same optional filters as ListVulnerabilities (pass "" to skip a filter).
// It is intended as the total for paginating ListVulnerabilities.
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) {
	filter, args := vulnerabilityFilterClause(id, conversationID, severity, status)
	query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" + filter

	var count int
	if err := db.QueryRow(query, args...).Scan(&count); err != nil {
		return 0, fmt.Errorf("统计漏洞总数失败: %w", err)
	}
	return count, nil
}

// UpdateVulnerability overwrites the editable fields of the row identified by
// id with the values in vuln.
//
// The updated fields are title, description, severity, status, type, target,
// proof, impact, recommendation and updated_at. conversation_id and
// created_at are left untouched.
//
// vuln.UpdatedAt is set to the current time on the caller's struct. No error
// is returned when no row matches id, because the affected-row count is not
// inspected.
func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
	vuln.UpdatedAt = time.Now()

	query := `
		UPDATE vulnerabilities
		SET title = ?, description = ?, severity = ?, status = ?,
		    vulnerability_type = ?, target = ?, proof = ?, impact = ?,
		    recommendation = ?, updated_at = ?
		WHERE id = ?
	`

	_, err := db.Exec(
		query,
		vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
		vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
		vuln.Recommendation, vuln.UpdatedAt, id,
	)
	if err != nil {
		return fmt.Errorf("更新漏洞失败: %w", err)
	}
	return nil
}

// DeleteVulnerability deletes the row with the given id. Deleting an id that
// does not exist is not reported as an error.
func (db *DB) DeleteVulnerability(id string) error {
	if _, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id); err != nil {
		return fmt.Errorf("删除漏洞失败: %w", err)
	}
	return nil
}

// GetVulnerabilityStats summarises vulnerabilities, optionally restricted to
// one conversation (pass "" for all).
//
// The returned map has these keys:
//
//	"total"       int            total number of rows
//	"by_severity" map[string]int row count per severity value
//	"by_status"   map[string]int row count per status value
//
// The three counts come from separate queries and are not taken atomically.
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) {
	stats := make(map[string]interface{})

	query := "SELECT COUNT(*) FROM vulnerabilities"
	args := []interface{}{}
	if conversationID != "" {
		query += " WHERE conversation_id = ?"
		args = append(args, conversationID)
	}

	var totalCount int
	if err := db.QueryRow(query, args...).Scan(&totalCount); err != nil {
		return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
	}
	stats["total"] = totalCount

	severityStats, err := db.countVulnerabilitiesGroupedBy("severity", conversationID, "获取严重程度统计失败")
	if err != nil {
		return nil, err
	}
	stats["by_severity"] = severityStats

	statusStats, err := db.countVulnerabilitiesGroupedBy("status", conversationID, "获取状态统计失败")
	if err != nil {
		return nil, err
	}
	stats["by_status"] = statusStats

	return stats, nil
}

// countVulnerabilitiesGroupedBy returns COUNT(*) per distinct value of column,
// optionally restricted to one conversation.
//
// column must be a trusted, hard-coded column name, because it is spliced into
// the SQL. Query and iteration errors are wrapped as "<errPrefix>: <err>". A
// row that fails to scan is logged and skipped. The rows are closed when this
// helper returns, not when the caller does.
func (db *DB) countVulnerabilitiesGroupedBy(column, conversationID, errPrefix string) (map[string]int, error) {
	query := "SELECT " + column + ", COUNT(*) FROM vulnerabilities"
	var args []interface{}
	if conversationID != "" {
		query += " WHERE conversation_id = ?"
		args = append(args, conversationID)
	}
	query += " GROUP BY " + column

	rows, err := db.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", errPrefix, err)
	}
	defer rows.Close()

	counts := make(map[string]int)
	for rows.Next() {
		var key string
		var count int
		if err := rows.Scan(&key, &count); err != nil {
			db.logger.Warn("扫描漏洞统计记录失败", zap.String("column", column), zap.Error(err))
			continue
		}
		counts[key] = count
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("%s: %w", errPrefix, err)
	}

	return counts, nil
}