mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-02 07:45:24 +02:00
373 lines
11 KiB
Go
373 lines
11 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Vulnerability is one security finding as stored in the vulnerabilities
// table, along with the JSON field names used when it is serialized.
type Vulnerability struct {
	// ID is the primary key; CreateVulnerability generates a UUID when it is empty.
	ID string `json:"id"`
	// ConversationID links the finding to the conversation that produced it.
	ConversationID string `json:"conversation_id"`

	// Optional labels, left out of the JSON output when empty.
	ConversationTag string `json:"conversation_tag,omitempty"`
	TaskTag         string `json:"task_tag,omitempty"`

	// TaskID and TaskQueueID are filled in on reads from a batch_tasks lookup
	// keyed by conversation_id; the insert and update statements in this file
	// do not write them.
	TaskID      string `json:"task_id,omitempty"`
	TaskQueueID string `json:"task_queue_id,omitempty"`

	Title       string `json:"title"`
	Description string `json:"description"`
	// Severity is one of: critical, high, medium, low, info.
	Severity string `json:"severity"`
	// Status is one of: open, confirmed, fixed, false_positive.
	Status string `json:"status"`
	// Type is persisted in the vulnerability_type column.
	Type           string `json:"type"`
	Target         string `json:"target"`
	Proof          string `json:"proof"`
	Impact         string `json:"impact"`
	Recommendation string `json:"recommendation"`

	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
|
|
|
// CreateVulnerability 创建漏洞
|
|
func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
|
if vuln.ID == "" {
|
|
vuln.ID = uuid.New().String()
|
|
}
|
|
if vuln.Status == "" {
|
|
vuln.Status = "open"
|
|
}
|
|
now := time.Now()
|
|
if vuln.CreatedAt.IsZero() {
|
|
vuln.CreatedAt = now
|
|
}
|
|
vuln.UpdatedAt = now
|
|
|
|
query := `
|
|
INSERT INTO vulnerabilities (
|
|
id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
|
|
vulnerability_type, target, proof, impact, recommendation,
|
|
created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
_, err := db.Exec(
|
|
query,
|
|
vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
|
|
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
|
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
|
vuln.CreatedAt, vuln.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建漏洞失败: %w", err)
|
|
}
|
|
|
|
return vuln, nil
|
|
}
|
|
|
|
// GetVulnerability 获取漏洞
|
|
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
|
var vuln Vulnerability
|
|
query := `
|
|
SELECT id, conversation_id, title, description, severity, status,
|
|
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
|
|
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
|
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
|
created_at, updated_at
|
|
FROM vulnerabilities
|
|
WHERE id = ?
|
|
`
|
|
|
|
err := db.QueryRow(query, id).Scan(
|
|
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
|
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
|
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
|
&vuln.TaskID, &vuln.TaskQueueID,
|
|
&vuln.CreatedAt, &vuln.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("漏洞不存在")
|
|
}
|
|
return nil, fmt.Errorf("获取漏洞失败: %w", err)
|
|
}
|
|
|
|
return &vuln, nil
|
|
}
|
|
|
|
// ListVulnerabilities 列出漏洞
|
|
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
|
|
query := `
|
|
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
|
|
vulnerability_type, target, proof, impact, recommendation,
|
|
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
|
|
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
|
|
created_at, updated_at
|
|
FROM vulnerabilities
|
|
WHERE 1=1
|
|
`
|
|
args := []interface{}{}
|
|
|
|
if id != "" {
|
|
query += " AND id = ?"
|
|
args = append(args, id)
|
|
}
|
|
if conversationID != "" {
|
|
query += " AND conversation_id = ?"
|
|
args = append(args, conversationID)
|
|
}
|
|
if taskID != "" {
|
|
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
args = append(args, taskID, taskID)
|
|
}
|
|
if conversationTag != "" {
|
|
query += " AND conversation_tag = ?"
|
|
args = append(args, conversationTag)
|
|
}
|
|
if taskTag != "" {
|
|
query += " AND task_tag = ?"
|
|
args = append(args, taskTag)
|
|
}
|
|
if severity != "" {
|
|
query += " AND severity = ?"
|
|
args = append(args, severity)
|
|
}
|
|
if status != "" {
|
|
query += " AND status = ?"
|
|
args = append(args, status)
|
|
}
|
|
|
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
|
args = append(args, limit, offset)
|
|
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询漏洞列表失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var vulnerabilities []*Vulnerability
|
|
for rows.Next() {
|
|
var vuln Vulnerability
|
|
err := rows.Scan(
|
|
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
|
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
|
|
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
|
&vuln.TaskID, &vuln.TaskQueueID,
|
|
&vuln.CreatedAt, &vuln.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
db.logger.Warn("扫描漏洞记录失败", zap.Error(err))
|
|
continue
|
|
}
|
|
vulnerabilities = append(vulnerabilities, &vuln)
|
|
}
|
|
|
|
return vulnerabilities, nil
|
|
}
|
|
|
|
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
|
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
|
|
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
|
args := []interface{}{}
|
|
|
|
if id != "" {
|
|
query += " AND id = ?"
|
|
args = append(args, id)
|
|
}
|
|
if conversationID != "" {
|
|
query += " AND conversation_id = ?"
|
|
args = append(args, conversationID)
|
|
}
|
|
if taskID != "" {
|
|
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
args = append(args, taskID, taskID)
|
|
}
|
|
if conversationTag != "" {
|
|
query += " AND conversation_tag = ?"
|
|
args = append(args, conversationTag)
|
|
}
|
|
if taskTag != "" {
|
|
query += " AND task_tag = ?"
|
|
args = append(args, taskTag)
|
|
}
|
|
if severity != "" {
|
|
query += " AND severity = ?"
|
|
args = append(args, severity)
|
|
}
|
|
if status != "" {
|
|
query += " AND status = ?"
|
|
args = append(args, status)
|
|
}
|
|
|
|
var count int
|
|
err := db.QueryRow(query, args...).Scan(&count)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("统计漏洞总数失败: %w", err)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// UpdateVulnerability 更新漏洞
|
|
func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
|
vuln.UpdatedAt = time.Now()
|
|
|
|
query := `
|
|
UPDATE vulnerabilities
|
|
SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
|
|
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
|
recommendation = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`
|
|
|
|
_, err := db.Exec(
|
|
query,
|
|
vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
|
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
|
vuln.Recommendation, vuln.UpdatedAt, id,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("更新漏洞失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteVulnerability 删除漏洞
|
|
func (db *DB) DeleteVulnerability(id string) error {
|
|
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
|
|
if err != nil {
|
|
return fmt.Errorf("删除漏洞失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetVulnerabilityStats 获取漏洞统计
|
|
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) {
|
|
stats := make(map[string]interface{})
|
|
|
|
// 总漏洞数
|
|
var totalCount int
|
|
query := "SELECT COUNT(*) FROM vulnerabilities"
|
|
args := []interface{}{}
|
|
if conversationID != "" {
|
|
query += " WHERE conversation_id = ?"
|
|
args = append(args, conversationID)
|
|
}
|
|
err := db.QueryRow(query, args...).Scan(&totalCount)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
|
|
}
|
|
stats["total"] = totalCount
|
|
|
|
// 按严重程度统计
|
|
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities"
|
|
if conversationID != "" {
|
|
severityQuery += " WHERE conversation_id = ?"
|
|
}
|
|
severityQuery += " GROUP BY severity"
|
|
|
|
rows, err := db.Query(severityQuery, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取严重程度统计失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
severityStats := make(map[string]int)
|
|
for rows.Next() {
|
|
var severity string
|
|
var count int
|
|
if err := rows.Scan(&severity, &count); err != nil {
|
|
continue
|
|
}
|
|
severityStats[severity] = count
|
|
}
|
|
stats["by_severity"] = severityStats
|
|
|
|
// 按状态统计
|
|
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities"
|
|
if conversationID != "" {
|
|
statusQuery += " WHERE conversation_id = ?"
|
|
}
|
|
statusQuery += " GROUP BY status"
|
|
|
|
rows, err = db.Query(statusQuery, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取状态统计失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
statusStats := make(map[string]int)
|
|
for rows.Next() {
|
|
var status string
|
|
var count int
|
|
if err := rows.Scan(&status, &count); err != nil {
|
|
continue
|
|
}
|
|
statusStats[status] = count
|
|
}
|
|
stats["by_status"] = statusStats
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
|
|
func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
|
|
collect := func(query string, args ...interface{}) ([]string, error) {
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
items := make([]string, 0)
|
|
for rows.Next() {
|
|
var val string
|
|
if err := rows.Scan(&val); err != nil {
|
|
continue
|
|
}
|
|
if val == "" {
|
|
continue
|
|
}
|
|
items = append(items, val)
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
|
|
}
|
|
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
|
|
}
|
|
taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询任务ID建议失败: %w", err)
|
|
}
|
|
queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询队列ID建议失败: %w", err)
|
|
}
|
|
conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询对话标签建议失败: %w", err)
|
|
}
|
|
taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
|
|
}
|
|
|
|
return map[string][]string{
|
|
"vulnerability_ids": vulnIDs,
|
|
"conversation_ids": conversationIDs,
|
|
"task_ids": taskIDs,
|
|
"queue_ids": queueIDs,
|
|
"conversation_tags": conversationTags,
|
|
"task_tags": taskTags,
|
|
}, nil
|
|
}
|
|
|