mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 00:09:29 +02:00
282 lines
7.2 KiB
Go
282 lines
7.2 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Vulnerability is a single vulnerability finding persisted in the
// vulnerabilities table.
//
// Each field is stored in the column named after it in snake_case, except
// Type, which is stored in the vulnerability_type column. The JSON tags define
// the API representation.
type Vulnerability struct {
	ID             string    `json:"id"` // primary key; CreateVulnerability generates a UUID when empty
	ConversationID string    `json:"conversation_id"`
	Title          string    `json:"title"`
	Description    string    `json:"description"`
	Severity       string    `json:"severity"` // critical, high, medium, low, info
	Status         string    `json:"status"`   // open, confirmed, fixed, false_positive; CreateVulnerability defaults it to "open"
	Type           string    `json:"type"`     // stored in the vulnerability_type column
	Target         string    `json:"target"`
	Proof          string    `json:"proof"`
	Impact         string    `json:"impact"`
	Recommendation string    `json:"recommendation"`
	CreatedAt      time.Time `json:"created_at"`
	UpdatedAt      time.Time `json:"updated_at"` // set to the current time by CreateVulnerability and UpdateVulnerability
}
|
|
|
|
// CreateVulnerability 创建漏洞
|
|
func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
|
if vuln.ID == "" {
|
|
vuln.ID = uuid.New().String()
|
|
}
|
|
if vuln.Status == "" {
|
|
vuln.Status = "open"
|
|
}
|
|
now := time.Now()
|
|
if vuln.CreatedAt.IsZero() {
|
|
vuln.CreatedAt = now
|
|
}
|
|
vuln.UpdatedAt = now
|
|
|
|
query := `
|
|
INSERT INTO vulnerabilities (
|
|
id, conversation_id, title, description, severity, status,
|
|
vulnerability_type, target, proof, impact, recommendation,
|
|
created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
_, err := db.Exec(
|
|
query,
|
|
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description,
|
|
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
|
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
|
vuln.CreatedAt, vuln.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建漏洞失败: %w", err)
|
|
}
|
|
|
|
return vuln, nil
|
|
}
|
|
|
|
// GetVulnerability 获取漏洞
|
|
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
|
var vuln Vulnerability
|
|
query := `
|
|
SELECT id, conversation_id, title, description, severity, status,
|
|
vulnerability_type, target, proof, impact, recommendation,
|
|
created_at, updated_at
|
|
FROM vulnerabilities
|
|
WHERE id = ?
|
|
`
|
|
|
|
err := db.QueryRow(query, id).Scan(
|
|
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
|
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
|
|
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
|
&vuln.CreatedAt, &vuln.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("漏洞不存在")
|
|
}
|
|
return nil, fmt.Errorf("获取漏洞失败: %w", err)
|
|
}
|
|
|
|
return &vuln, nil
|
|
}
|
|
|
|
// ListVulnerabilities 列出漏洞
|
|
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) {
|
|
query := `
|
|
SELECT id, conversation_id, title, description, severity, status,
|
|
vulnerability_type, target, proof, impact, recommendation,
|
|
created_at, updated_at
|
|
FROM vulnerabilities
|
|
WHERE 1=1
|
|
`
|
|
args := []interface{}{}
|
|
|
|
if id != "" {
|
|
query += " AND id = ?"
|
|
args = append(args, id)
|
|
}
|
|
if conversationID != "" {
|
|
query += " AND conversation_id = ?"
|
|
args = append(args, conversationID)
|
|
}
|
|
if severity != "" {
|
|
query += " AND severity = ?"
|
|
args = append(args, severity)
|
|
}
|
|
if status != "" {
|
|
query += " AND status = ?"
|
|
args = append(args, status)
|
|
}
|
|
|
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
|
args = append(args, limit, offset)
|
|
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询漏洞列表失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var vulnerabilities []*Vulnerability
|
|
for rows.Next() {
|
|
var vuln Vulnerability
|
|
err := rows.Scan(
|
|
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
|
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
|
|
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
|
&vuln.CreatedAt, &vuln.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
db.logger.Warn("扫描漏洞记录失败", zap.Error(err))
|
|
continue
|
|
}
|
|
vulnerabilities = append(vulnerabilities, &vuln)
|
|
}
|
|
|
|
return vulnerabilities, nil
|
|
}
|
|
|
|
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
|
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) {
|
|
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
|
args := []interface{}{}
|
|
|
|
if id != "" {
|
|
query += " AND id = ?"
|
|
args = append(args, id)
|
|
}
|
|
if conversationID != "" {
|
|
query += " AND conversation_id = ?"
|
|
args = append(args, conversationID)
|
|
}
|
|
if severity != "" {
|
|
query += " AND severity = ?"
|
|
args = append(args, severity)
|
|
}
|
|
if status != "" {
|
|
query += " AND status = ?"
|
|
args = append(args, status)
|
|
}
|
|
|
|
var count int
|
|
err := db.QueryRow(query, args...).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("统计漏洞总数失败: %w", err)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// UpdateVulnerability 更新漏洞
|
|
func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
|
vuln.UpdatedAt = time.Now()
|
|
|
|
query := `
|
|
UPDATE vulnerabilities
|
|
SET title = ?, description = ?, severity = ?, status = ?,
|
|
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
|
recommendation = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := db.Exec(
|
|
query,
|
|
vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
|
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
|
vuln.Recommendation, vuln.UpdatedAt, id,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("更新漏洞失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteVulnerability 删除漏洞
|
|
func (db *DB) DeleteVulnerability(id string) error {
|
|
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
|
|
if err != nil {
|
|
return fmt.Errorf("删除漏洞失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetVulnerabilityStats 获取漏洞统计
|
|
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) {
|
|
stats := make(map[string]interface{})
|
|
|
|
// 总漏洞数
|
|
var totalCount int
|
|
query := "SELECT COUNT(*) FROM vulnerabilities"
|
|
args := []interface{}{}
|
|
if conversationID != "" {
|
|
query += " WHERE conversation_id = ?"
|
|
args = append(args, conversationID)
|
|
}
|
|
err := db.QueryRow(query, args...).Scan(&totalCount)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
|
|
}
|
|
stats["total"] = totalCount
|
|
|
|
// 按严重程度统计
|
|
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities"
|
|
if conversationID != "" {
|
|
severityQuery += " WHERE conversation_id = ?"
|
|
}
|
|
severityQuery += " GROUP BY severity"
|
|
|
|
rows, err := db.Query(severityQuery, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取严重程度统计失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
severityStats := make(map[string]int)
|
|
for rows.Next() {
|
|
var severity string
|
|
var count int
|
|
if err := rows.Scan(&severity, &count); err != nil {
|
|
continue
|
|
}
|
|
severityStats[severity] = count
|
|
}
|
|
stats["by_severity"] = severityStats
|
|
|
|
// 按状态统计
|
|
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities"
|
|
if conversationID != "" {
|
|
statusQuery += " WHERE conversation_id = ?"
|
|
}
|
|
statusQuery += " GROUP BY status"
|
|
|
|
rows, err = db.Query(statusQuery, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取状态统计失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
statusStats := make(map[string]int)
|
|
for rows.Next() {
|
|
var status string
|
|
var count int
|
|
if err := rows.Scan(&status, &count); err != nil {
|
|
continue
|
|
}
|
|
statusStats[status] = count
|
|
}
|
|
stats["by_status"] = statusStats
|
|
|
|
return stats, nil
|
|
}
|
|
|