Add files via upload

This commit is contained in:
公明
2026-04-28 01:05:58 +08:00
committed by GitHub
parent 0597838217
commit fad6b3c808
2 changed files with 141 additions and 11 deletions
+39
View File
@@ -197,6 +197,8 @@ func (db *DB) initTables() error {
CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
conversation_tag TEXT,
task_tag TEXT,
title TEXT NOT NULL,
description TEXT,
severity TEXT NOT NULL,
@@ -289,6 +291,8 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
@@ -383,6 +387,10 @@ func (db *DB) initTables() error {
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateVulnerabilitiesTable(); err != nil {
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
@@ -683,6 +691,37 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
return nil
}
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
func (db *DB) migrateVulnerabilitiesTable() error {
columns := []struct {
name string
stmt string
}{
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
}
for _, col := range columns {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('vulnerabilities') WHERE name=?", col.name).Scan(&count)
if err != nil {
if _, addErr := db.Exec(col.stmt); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
continue
}
if count == 0 {
if _, addErr := db.Exec(col.stmt); addErr != nil {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
+102 -11
View File
@@ -13,6 +13,10 @@ import (
type Vulnerability struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
ConversationTag string `json:"conversation_tag,omitempty"`
TaskTag string `json:"task_tag,omitempty"`
TaskID string `json:"task_id,omitempty"`
TaskQueueID string `json:"task_queue_id,omitempty"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info
@@ -42,15 +46,15 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
query := `
INSERT INTO vulnerabilities (
id, conversation_id, title, description, severity, status,
id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(
query,
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description,
vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt,
@@ -67,7 +71,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability
query := `
SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at
FROM vulnerabilities
WHERE id = ?
@@ -75,8 +81,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
@@ -90,10 +97,12 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
}
// ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) {
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
query := `
SELECT id, conversation_id, title, description, severity, status,
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at
FROM vulnerabilities
WHERE 1=1
@@ -108,6 +117,18 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
@@ -131,8 +152,9 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
var vuln Vulnerability
err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
@@ -146,7 +168,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
}
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) {
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{}
@@ -158,6 +180,18 @@ func (db *DB) CountVulnerabilities(id, conversationID, severity, status string)
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
@@ -182,7 +216,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
query := `
UPDATE vulnerabilities
SET title = ?, description = ?, severity = ?, status = ?,
SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ?
WHERE id = ?
@@ -190,7 +224,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
_, err := db.Exec(
query,
vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id,
)
@@ -279,3 +313,60 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
return stats, nil
}
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
collect := func(query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]string, 0)
for rows.Next() {
var val string
if err := rows.Scan(&val); err != nil {
continue
}
if val == "" {
continue
}
items = append(items, val)
}
return items, nil
}
vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
}
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
}
taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务ID建议失败: %w", err)
}
queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询队列ID建议失败: %w", err)
}
conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询对话标签建议失败: %w", err)
}
taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
}
return map[string][]string{
"vulnerability_ids": vulnIDs,
"conversation_ids": conversationIDs,
"task_ids": taskIDs,
"queue_ids": queueIDs,
"conversation_tags": conversationTags,
"task_tags": taskTags,
}, nil
}