From 21677350222d97caf9bfa654740bae99ad4c4787 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Thu, 14 May 2026 19:24:58 +0800 Subject: [PATCH] Delete database directory --- database/attackchain.go | 167 ---- database/batch_task.go | 537 ------------ database/c2.go | 1259 ---------------------------- database/conversation.go | 812 ------------------ database/conversation_turn_test.go | 39 - database/database.go | 1108 ------------------------ database/group.go | 449 ---------- database/monitor.go | 537 ------------ database/robot_session.go | 84 -- database/skill_stats.go | 142 ---- database/vulnerability.go | 369 -------- database/webshell.go | 152 ---- 12 files changed, 5655 deletions(-) delete mode 100644 database/attackchain.go delete mode 100644 database/batch_task.go delete mode 100644 database/c2.go delete mode 100644 database/conversation.go delete mode 100644 database/conversation_turn_test.go delete mode 100644 database/database.go delete mode 100644 database/group.go delete mode 100644 database/monitor.go delete mode 100644 database/robot_session.go delete mode 100644 database/skill_stats.go delete mode 100644 database/vulnerability.go delete mode 100644 database/webshell.go diff --git a/database/attackchain.go b/database/attackchain.go deleted file mode 100644 index dc3b8362..00000000 --- a/database/attackchain.go +++ /dev/null @@ -1,167 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "fmt" - - "go.uber.org/zap" -) - -// AttackChainNode 攻击链节点 -type AttackChainNode struct { - ID string `json:"id"` - Type string `json:"type"` // tool, vulnerability, target, exploit - Label string `json:"label"` - ToolExecutionID string `json:"tool_execution_id,omitempty"` - Metadata map[string]interface{} `json:"metadata"` - RiskScore int `json:"risk_score"` -} - -// AttackChainEdge 攻击链边 -type AttackChainEdge struct { - ID string `json:"id"` - Source string `json:"source"` - Target string `json:"target"` - Type string `json:"type"` // leads_to, exploits, enables, depends_on - Weight int `json:"weight"` -} - -// SaveAttackChainNode 保存攻击链节点 -func (db *DB) SaveAttackChainNode(conversationID, nodeID, nodeType, nodeName, toolExecutionID, metadata string, riskScore int) error { - var toolExecID sql.NullString - if toolExecutionID != "" { - toolExecID = sql.NullString{String: toolExecutionID, Valid: true} - } - - var metadataJSON sql.NullString - if metadata != "" { - metadataJSON = sql.NullString{String: metadata, Valid: true} - } - - query := ` - INSERT OR REPLACE INTO attack_chain_nodes - (id, conversation_id, node_type, node_name, tool_execution_id, metadata, risk_score, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) - ` - - _, err := db.Exec(query, nodeID, conversationID, nodeType, nodeName, toolExecID, metadataJSON, riskScore) - if err != nil { - db.logger.Error("保存攻击链节点失败", zap.Error(err), zap.String("nodeId", nodeID)) - return err - } - - return nil -} - -// SaveAttackChainEdge 保存攻击链边 -func (db *DB) SaveAttackChainEdge(conversationID, edgeID, sourceNodeID, targetNodeID, edgeType string, weight int) error { - query := ` - INSERT OR REPLACE INTO attack_chain_edges - (id, conversation_id, source_node_id, target_node_id, edge_type, weight, created_at) - VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) - ` - - _, err := db.Exec(query, edgeID, conversationID, sourceNodeID, targetNodeID, edgeType, weight) - if err != nil { - db.logger.Error("保存攻击链边失败", zap.Error(err), zap.String("edgeId", edgeID)) - return err - } - - return nil -} - -// LoadAttackChainNodes 加载攻击链节点 -func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, error) { - query := ` - SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score - FROM attack_chain_nodes - WHERE conversation_id = ? - ORDER BY created_at ASC - ` - - rows, err := db.Query(query, conversationID) - if err != nil { - return nil, fmt.Errorf("查询攻击链节点失败: %w", err) - } - defer rows.Close() - - var nodes []AttackChainNode - for rows.Next() { - var node AttackChainNode - var toolExecID sql.NullString - var metadataJSON sql.NullString - - err := rows.Scan(&node.ID, &node.Type, &node.Label, &toolExecID, &metadataJSON, &node.RiskScore) - if err != nil { - db.logger.Warn("扫描攻击链节点失败", zap.Error(err)) - continue - } - - if toolExecID.Valid { - node.ToolExecutionID = toolExecID.String - } - - if metadataJSON.Valid && metadataJSON.String != "" { - if err := json.Unmarshal([]byte(metadataJSON.String), &node.Metadata); err != nil { - db.logger.Warn("解析节点元数据失败", zap.Error(err)) - node.Metadata = make(map[string]interface{}) - } - } else { - node.Metadata = make(map[string]interface{}) - } - - nodes = append(nodes, node) - } - - return nodes, nil -} - -// LoadAttackChainEdges 加载攻击链边 -func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, error) { - query := ` - SELECT id, source_node_id, target_node_id, edge_type, weight - FROM attack_chain_edges - WHERE conversation_id = ? - ORDER BY created_at ASC - ` - - rows, err := db.Query(query, conversationID) - if err != nil { - return nil, fmt.Errorf("查询攻击链边失败: %w", err) - } - defer rows.Close() - - var edges []AttackChainEdge - for rows.Next() { - var edge AttackChainEdge - - err := rows.Scan(&edge.ID, &edge.Source, &edge.Target, &edge.Type, &edge.Weight) - if err != nil { - db.logger.Warn("扫描攻击链边失败", zap.Error(err)) - continue - } - - edges = append(edges, edge) - } - - return edges, nil -} - -// DeleteAttackChain 删除对话的攻击链数据 -func (db *DB) DeleteAttackChain(conversationID string) error { - // 先删除边(因为有外键约束) - _, err := db.Exec("DELETE FROM attack_chain_edges WHERE conversation_id = ?", conversationID) - if err != nil { - db.logger.Warn("删除攻击链边失败", zap.Error(err)) - } - - // 再删除节点 - _, err = db.Exec("DELETE FROM attack_chain_nodes WHERE conversation_id = ?", conversationID) - if err != nil { - db.logger.Error("删除攻击链节点失败", zap.Error(err), zap.String("conversationId", conversationID)) - return err - } - - return nil -} diff --git a/database/batch_task.go b/database/batch_task.go deleted file mode 100644 index c774be65..00000000 --- a/database/batch_task.go +++ /dev/null @@ -1,537 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "strings" - "time" - - "go.uber.org/zap" -) - -// BatchTaskQueueRow 批量任务队列数据库行 -type BatchTaskQueueRow struct { - ID string - Title sql.NullString - Role sql.NullString - AgentMode sql.NullString - ScheduleMode sql.NullString - CronExpr sql.NullString - NextRunAt sql.NullTime - ScheduleEnabled sql.NullInt64 - LastScheduleTriggerAt sql.NullTime - LastScheduleError sql.NullString - LastRunError sql.NullString - Status string - CreatedAt time.Time - StartedAt sql.NullTime - CompletedAt sql.NullTime - CurrentIndex int -} - -// BatchTaskRow 批量任务数据库行 -type BatchTaskRow struct { - ID string - QueueID string - Message string - ConversationID sql.NullString - Status string - StartedAt sql.NullTime - CompletedAt sql.NullTime - Error sql.NullString - Result sql.NullString -} - -// CreateBatchQueue 创建批量任务队列 -func (db *DB) CreateBatchQueue( - queueID string, - title string, - role string, - agentMode string, - scheduleMode string, - cronExpr string, - nextRunAt *time.Time, - tasks []map[string]interface{}, -) error { - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("开始事务失败: %w", err) - } - defer tx.Rollback() - - now := time.Now() - var nextRunAtValue interface{} - if nextRunAt != nil { - nextRunAtValue = *nextRunAt - } - - _, err = tx.Exec( - "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0, - ) - if err != nil { - return fmt.Errorf("创建批量任务队列失败: %w", err) - } - - // 插入任务 - for _, task := range tasks { - taskID, ok := task["id"].(string) - if !ok { - continue - } - message, ok := task["message"].(string) - if !ok { - continue - } - - _, err = tx.Exec( - "INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)", - taskID, queueID, message, "pending", - ) - if err != nil { - return fmt.Errorf("创建批量任务失败: %w", err) - } - } - - return tx.Commit() -} - -// GetBatchQueue 获取批量任务队列 -func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { - var row BatchTaskQueueRow - var createdAt string - err := db.QueryRow( - "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", - queueID, - ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, fmt.Errorf("查询批量任务队列失败: %w", err) - } - - parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) - if parseErr != nil { - // 尝试其他时间格式 - parsedTime, parseErr = time.Parse(time.RFC3339, createdAt) - if parseErr != nil { - db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr)) - parsedTime = time.Now() - } - } - row.CreatedAt = parsedTime - return &row, nil -} - -// GetAllBatchQueues 获取所有批量任务队列 -func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { - rows, err := db.Query( - "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", - ) - if err != nil { - return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) - } - defer rows.Close() - - var queues []*BatchTaskQueueRow - for rows.Next() { - var row BatchTaskQueueRow - var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { - return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) - } - parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) - if parseErr != nil { - parsedTime, parseErr = time.Parse(time.RFC3339, createdAt) - if parseErr != nil { - db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr)) - parsedTime = time.Now() - } - } - row.CreatedAt = parsedTime - queues = append(queues, &row) - } - - return queues, nil -} - -// ListBatchQueues 列出批量任务队列(支持筛选和分页) -func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { - query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" - args := []interface{}{} - - // 状态筛选 - if status != "" && status != "all" { - query += " AND status = ?" - args = append(args, status) - } - - // 关键字搜索(搜索队列ID和标题) - if keyword != "" { - query += " AND (id LIKE ? OR title LIKE ?)" - args = append(args, "%"+keyword+"%", "%"+keyword+"%") - } - - query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" - args = append(args, limit, offset) - - rows, err := db.Query(query, args...) - if err != nil { - return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) - } - defer rows.Close() - - var queues []*BatchTaskQueueRow - for rows.Next() { - var row BatchTaskQueueRow - var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { - return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) - } - parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) - if parseErr != nil { - parsedTime, parseErr = time.Parse(time.RFC3339, createdAt) - if parseErr != nil { - db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr)) - parsedTime = time.Now() - } - } - row.CreatedAt = parsedTime - queues = append(queues, &row) - } - - return queues, nil -} - -// CountBatchQueues 统计批量任务队列总数(支持筛选条件) -func (db *DB) CountBatchQueues(status, keyword string) (int, error) { - query := "SELECT COUNT(*) FROM batch_task_queues WHERE 1=1" - args := []interface{}{} - - // 状态筛选 - if status != "" && status != "all" { - query += " AND status = ?" - args = append(args, status) - } - - // 关键字搜索(搜索队列ID和标题) - if keyword != "" { - query += " AND (id LIKE ? OR title LIKE ?)" - args = append(args, "%"+keyword+"%", "%"+keyword+"%") - } - - var count int - err := db.QueryRow(query, args...).Scan(&count) - if err != nil { - return 0, fmt.Errorf("统计批量任务队列总数失败: %w", err) - } - - return count, nil -} - -// GetBatchTasks 获取批量任务队列的所有任务 -func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) { - rows, err := db.Query( - "SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id", - queueID, - ) - if err != nil { - return nil, fmt.Errorf("查询批量任务失败: %w", err) - } - defer rows.Close() - - var tasks []*BatchTaskRow - for rows.Next() { - var task BatchTaskRow - if err := rows.Scan( - &task.ID, &task.QueueID, &task.Message, &task.ConversationID, - &task.Status, &task.StartedAt, &task.CompletedAt, &task.Error, &task.Result, - ); err != nil { - return nil, fmt.Errorf("扫描批量任务失败: %w", err) - } - tasks = append(tasks, &task) - } - - return tasks, nil -} - -// UpdateBatchQueueStatus 更新批量任务队列状态 -func (db *DB) UpdateBatchQueueStatus(queueID, status string) error { - var err error - now := time.Now() - - if status == "running" { - _, err = db.Exec( - "UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?", - status, now, queueID, - ) - } else if status == "completed" || status == "cancelled" { - _, err = db.Exec( - "UPDATE batch_task_queues SET status = ?, completed_at = COALESCE(completed_at, ?) WHERE id = ?", - status, now, queueID, - ) - } else { - _, err = db.Exec( - "UPDATE batch_task_queues SET status = ? WHERE id = ?", - status, queueID, - ) - } - - if err != nil { - return fmt.Errorf("更新批量任务队列状态失败: %w", err) - } - return nil -} - -// UpdateBatchTaskStatus 更新批量任务状态 -func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error { - var err error - now := time.Now() - - // 构建更新语句 - var updates []string - var args []interface{} - - updates = append(updates, "status = ?") - args = append(args, status) - - if conversationID != "" { - updates = append(updates, "conversation_id = ?") - args = append(args, conversationID) - } - - if result != "" { - updates = append(updates, "result = ?") - args = append(args, result) - } - - if errorMsg != "" { - updates = append(updates, "error = ?") - args = append(args, errorMsg) - } - - if status == "running" { - updates = append(updates, "started_at = COALESCE(started_at, ?)") - args = append(args, now) - } - - if status == "completed" || status == "failed" || status == "cancelled" { - updates = append(updates, "completed_at = COALESCE(completed_at, ?)") - args = append(args, now) - } - - args = append(args, queueID, taskID) - - // 构建SQL语句 - sql := "UPDATE batch_tasks SET " - for i, update := range updates { - if i > 0 { - sql += ", " - } - sql += update - } - sql += " WHERE queue_id = ? AND id = ?" - - _, err = db.Exec(sql, args...) - if err != nil { - return fmt.Errorf("更新批量任务状态失败: %w", err) - } - return nil -} - -// UpdateBatchQueueCurrentIndex 更新批量任务队列的当前索引 -func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) error { - _, err := db.Exec( - "UPDATE batch_task_queues SET current_index = ? WHERE id = ?", - currentIndex, queueID, - ) - if err != nil { - return fmt.Errorf("更新批量任务队列当前索引失败: %w", err) - } - return nil -} - -// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式 -func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error { - _, err := db.Exec( - "UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?", - title, role, agentMode, queueID, - ) - if err != nil { - return fmt.Errorf("更新批量任务队列元数据失败: %w", err) - } - return nil -} - -// UpdateBatchQueueSchedule 更新批量任务队列调度相关信息 -func (db *DB) UpdateBatchQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) error { - var nextRunAtValue interface{} - if nextRunAt != nil { - nextRunAtValue = *nextRunAt - } - _, err := db.Exec( - "UPDATE batch_task_queues SET schedule_mode = ?, cron_expr = ?, next_run_at = ? WHERE id = ?", - scheduleMode, cronExpr, nextRunAtValue, queueID, - ) - if err != nil { - return fmt.Errorf("更新批量任务调度配置失败: %w", err) - } - return nil -} - -// UpdateBatchQueueScheduleEnabled 是否允许 Cron 自动触发(手工「开始执行」不受影响) -func (db *DB) UpdateBatchQueueScheduleEnabled(queueID string, enabled bool) error { - v := 0 - if enabled { - v = 1 - } - _, err := db.Exec( - "UPDATE batch_task_queues SET schedule_enabled = ? WHERE id = ?", - v, queueID, - ) - if err != nil { - return fmt.Errorf("更新批量任务调度开关失败: %w", err) - } - return nil -} - -// RecordBatchQueueScheduledTriggerStart 记录一次由调度触发的开始时间并清空调度层错误 -func (db *DB) RecordBatchQueueScheduledTriggerStart(queueID string, at time.Time) error { - _, err := db.Exec( - "UPDATE batch_task_queues SET last_schedule_trigger_at = ?, last_schedule_error = NULL WHERE id = ?", - at, queueID, - ) - if err != nil { - return fmt.Errorf("记录调度触发时间失败: %w", err) - } - return nil -} - -// SetBatchQueueLastScheduleError 调度启动失败等原因(如状态不允许、重置失败) -func (db *DB) SetBatchQueueLastScheduleError(queueID, msg string) error { - _, err := db.Exec( - "UPDATE batch_task_queues SET last_schedule_error = ? WHERE id = ?", - msg, queueID, - ) - if err != nil { - return fmt.Errorf("写入调度错误信息失败: %w", err) - } - return nil -} - -// SetBatchQueueLastRunError 最近一轮执行中出现的子任务失败摘要(空串表示清空) -func (db *DB) SetBatchQueueLastRunError(queueID, msg string) error { - var v interface{} - if strings.TrimSpace(msg) == "" { - v = nil - } else { - v = msg - } - _, err := db.Exec( - "UPDATE batch_task_queues SET last_run_error = ? WHERE id = ?", - v, queueID, - ) - if err != nil { - return fmt.Errorf("写入最近运行错误失败: %w", err) - } - return nil -} - -// ResetBatchQueueForRerun 重置队列和任务状态用于下一轮调度执行 -func (db *DB) ResetBatchQueueForRerun(queueID string) error { - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("开始事务失败: %w", err) - } - defer tx.Rollback() - - _, err = tx.Exec( - "UPDATE batch_task_queues SET status = ?, current_index = 0, started_at = NULL, completed_at = NULL, last_run_error = NULL, last_schedule_error = NULL WHERE id = ?", - "pending", queueID, - ) - if err != nil { - return fmt.Errorf("重置批量任务队列状态失败: %w", err) - } - - _, err = tx.Exec( - "UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ?", - "pending", queueID, - ) - if err != nil { - return fmt.Errorf("重置批量任务状态失败: %w", err) - } - - return tx.Commit() -} - -// UpdateBatchTaskMessage 更新批量任务消息 -func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error { - _, err := db.Exec( - "UPDATE batch_tasks SET message = ? WHERE queue_id = ? AND id = ?", - message, queueID, taskID, - ) - if err != nil { - return fmt.Errorf("更新批量任务消息失败: %w", err) - } - return nil -} - -// AddBatchTask 添加任务到批量任务队列 -func (db *DB) AddBatchTask(queueID, taskID, message string) error { - _, err := db.Exec( - "INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)", - taskID, queueID, message, "pending", - ) - if err != nil { - return fmt.Errorf("添加批量任务失败: %w", err) - } - return nil -} - -// CancelPendingBatchTasks 批量取消队列中所有 pending 状态的任务(单条 SQL) -func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) error { - _, err := db.Exec( - "UPDATE batch_tasks SET status = ?, completed_at = ? WHERE queue_id = ? AND status = ?", - "cancelled", completedAt, queueID, "pending", - ) - if err != nil { - return fmt.Errorf("批量取消 pending 任务失败: %w", err) - } - return nil -} - -// DeleteBatchTask 删除批量任务 -func (db *DB) DeleteBatchTask(queueID, taskID string) error { - _, err := db.Exec( - "DELETE FROM batch_tasks WHERE queue_id = ? AND id = ?", - queueID, taskID, - ) - if err != nil { - return fmt.Errorf("删除批量任务失败: %w", err) - } - return nil -} - -// DeleteBatchQueue 删除批量任务队列 -func (db *DB) DeleteBatchQueue(queueID string) error { - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("开始事务失败: %w", err) - } - defer tx.Rollback() - - // 删除任务(外键会自动级联删除) - _, err = tx.Exec("DELETE FROM batch_tasks WHERE queue_id = ?", queueID) - if err != nil { - return fmt.Errorf("删除批量任务失败: %w", err) - } - - // 删除队列 - _, err = tx.Exec("DELETE FROM batch_task_queues WHERE id = ?", queueID) - if err != nil { - return fmt.Errorf("删除批量任务队列失败: %w", err) - } - - return tx.Commit() -} diff --git a/database/c2.go b/database/c2.go deleted file mode 100644 index 0965ba3d..00000000 --- a/database/c2.go +++ /dev/null @@ -1,1259 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "errors" - "fmt" - "strings" - "time" - - "go.uber.org/zap" -) - -// ErrNoValidC2EventIDs 批量删除事件时未提供任何合法 ID -var ErrNoValidC2EventIDs = errors.New("no valid event ids") - -// ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID -var ErrNoValidC2TaskIDs = errors.New("no valid task ids") - -// validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参 -func validC2TextIDForDelete(id string) bool { - if len(id) < 2 || len(id) > 80 { - return false - } - for _, c := range id { - if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' { - continue - } - return false - } - return true -} - -// ============================================================================ -// C2 模块数据模型 — 6 张表的领域类型 -// 设计要点: -// - 全部使用文本主键(l_/s_/t_/f_/e_/p_ 前缀),与项目现有 ws_/v_ 风格一致; -// - 时间字段统一 time.Time,由 SQLite 自动序列化为 ISO8601; -// - 大字段(profile 配置、心跳元数据、任务结果)走 JSON 文本,避免频繁加列; -// - 任意会话/任务/文件均可按 listener_id / session_id 级联删除(FOREIGN KEY ON DELETE CASCADE)。 -// ============================================================================ - -// C2Listener 监听器实体 -type C2Listener struct { - ID string `json:"id"` - Name string `json:"name"` - Type string `json:"type"` // tcp_reverse|http_beacon|https_beacon|websocket|dns - BindHost string `json:"bindHost"` // 默认 127.0.0.1 - BindPort int `json:"bindPort"` // 1-65535 - ProfileID string `json:"profileId"` // 可空:关联 c2_profiles.id - EncryptionKey string `json:"-"` // base64(AES-256),前端不返回 - ImplantToken string `json:"-"` // beacon 携带的鉴权 token,前端不返回 - Status string `json:"status"` // stopped|running|error - ConfigJSON string `json:"configJson"` // TLS 证书路径 / URI 模式 / 上限并发 等 - Remark string `json:"remark"` - CreatedAt time.Time `json:"createdAt"` - StartedAt *time.Time `json:"startedAt,omitempty"` - LastError string `json:"lastError,omitempty"` -} - -// C2Session 已上线会话 -type C2Session struct { - ID string `json:"id"` - ListenerID string `json:"listenerId"` - ImplantUUID string `json:"implantUuid"` - Hostname string `json:"hostname"` - Username string `json:"username"` - OS string `json:"os"` - Arch string `json:"arch"` - PID int `json:"pid"` - ProcessName string `json:"processName"` - IsAdmin bool `json:"isAdmin"` - InternalIP string `json:"internalIp"` - ExternalIP string `json:"externalIp"` - UserAgent string `json:"userAgent"` - SleepSeconds int `json:"sleepSeconds"` - JitterPercent int `json:"jitterPercent"` - Status string `json:"status"` // active|sleeping|dead|killed - FirstSeenAt time.Time `json:"firstSeenAt"` - LastCheckIn time.Time `json:"lastCheckIn"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Note string `json:"note"` -} - -// C2Task 下发任务 -type C2Task struct { - ID string `json:"id"` - SessionID string `json:"sessionId"` - TaskType string `json:"taskType"` - Payload map[string]interface{} `json:"payload,omitempty"` - Status string `json:"status"` // queued|sent|running|success|failed|cancelled - ResultText string `json:"resultText,omitempty"` - ResultBlobPath string `json:"resultBlobPath,omitempty"` - Error string `json:"error,omitempty"` - Source string `json:"source"` // manual|ai|batch|api - ConversationID string `json:"conversationId,omitempty"` - ApprovalStatus string `json:"approvalStatus,omitempty"` // pending|approved|rejected - CreatedAt time.Time `json:"createdAt"` - SentAt *time.Time `json:"sentAt,omitempty"` - StartedAt *time.Time `json:"startedAt,omitempty"` - CompletedAt *time.Time `json:"completedAt,omitempty"` - DurationMS int64 `json:"durationMs,omitempty"` -} - -// C2File 上传/下载凭证 -type C2File struct { - ID string `json:"id"` - SessionID string `json:"sessionId"` - TaskID string `json:"taskId"` - Direction string `json:"direction"` // upload|download - RemotePath string `json:"remotePath"` - LocalPath string `json:"localPath"` - SizeBytes int64 `json:"sizeBytes"` - SHA256 string `json:"sha256"` - CreatedAt time.Time `json:"createdAt"` -} - -// C2Event 事件审计 -type C2Event struct { - ID string `json:"id"` - Level string `json:"level"` // info|warn|critical - Category string `json:"category"` // listener|session|task|payload|opsec - SessionID string `json:"sessionId,omitempty"` - TaskID string `json:"taskId,omitempty"` - Message string `json:"message"` - Data map[string]interface{} `json:"data,omitempty"` - CreatedAt time.Time `json:"createdAt"` -} - -// C2Profile Malleable Profile -type C2Profile struct { - ID string `json:"id"` - Name string `json:"name"` - UserAgent string `json:"userAgent"` - URIs []string `json:"uris"` - RequestHeaders map[string]string `json:"requestHeaders,omitempty"` - ResponseHeaders map[string]string `json:"responseHeaders,omitempty"` - BodyTemplate string `json:"bodyTemplate"` - JitterMinMS int `json:"jitterMinMs"` - JitterMaxMS int `json:"jitterMaxMs"` - Extra map[string]interface{} `json:"extra,omitempty"` - CreatedAt time.Time `json:"createdAt"` -} - -// ---------------------------------------------------------------------------- -// CRUD:C2 监听器 -// ---------------------------------------------------------------------------- - -// CreateC2Listener 写入新监听器;ID/Name 由调用方生成校验 -func (db *DB) CreateC2Listener(l *C2Listener) error { - if l == nil || strings.TrimSpace(l.ID) == "" { - return errors.New("listener id is required") - } - if l.CreatedAt.IsZero() { - l.CreatedAt = time.Now() - } - if strings.TrimSpace(l.Status) == "" { - l.Status = "stopped" - } - if strings.TrimSpace(l.ConfigJSON) == "" { - l.ConfigJSON = "{}" - } - query := ` - INSERT INTO c2_listeners (id, name, type, bind_host, bind_port, profile_id, encryption_key, - implant_token, status, config_json, remark, created_at, started_at, last_error) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - _, err := db.Exec(query, - l.ID, l.Name, l.Type, l.BindHost, l.BindPort, l.ProfileID, l.EncryptionKey, - l.ImplantToken, l.Status, l.ConfigJSON, l.Remark, l.CreatedAt, l.StartedAt, l.LastError, - ) - if err != nil { - db.logger.Error("创建 C2 监听器失败", zap.Error(err), zap.String("id", l.ID)) - return err - } - return nil -} - -// UpdateC2Listener 更新监听器;空字段也会被覆盖(请先 GetC2Listener 拿到完整对象再改) -func (db *DB) UpdateC2Listener(l *C2Listener) error { - if l == nil || strings.TrimSpace(l.ID) == "" { - return errors.New("listener id is required") - } - if strings.TrimSpace(l.ConfigJSON) == "" { - l.ConfigJSON = "{}" - } - query := ` - UPDATE c2_listeners SET - name = ?, type = ?, bind_host = ?, bind_port = ?, profile_id = ?, encryption_key = ?, - implant_token = ?, status = ?, config_json = ?, remark = ?, started_at = ?, last_error = ? - WHERE id = ? - ` - res, err := db.Exec(query, - l.Name, l.Type, l.BindHost, l.BindPort, l.ProfileID, l.EncryptionKey, - l.ImplantToken, l.Status, l.ConfigJSON, l.Remark, l.StartedAt, l.LastError, l.ID, - ) - if err != nil { - db.logger.Error("更新 C2 监听器失败", zap.Error(err), zap.String("id", l.ID)) - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// SetC2ListenerStatus 仅更新状态/started_at/last_error 三个字段,避免与全量更新竞争 -func (db *DB) SetC2ListenerStatus(id, status, lastError string, startedAt *time.Time) error { - query := ` - UPDATE c2_listeners SET status = ?, last_error = ?, started_at = COALESCE(?, started_at) - WHERE id = ? - ` - res, err := db.Exec(query, status, lastError, startedAt, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// GetC2Listener 单条查询 -func (db *DB) GetC2Listener(id string) (*C2Listener, error) { - query := ` - SELECT id, name, type, bind_host, bind_port, COALESCE(profile_id, ''), - COALESCE(encryption_key, ''), COALESCE(implant_token, ''), status, - COALESCE(config_json, '{}'), COALESCE(remark, ''), - created_at, started_at, COALESCE(last_error, '') - FROM c2_listeners WHERE id = ? - ` - var l C2Listener - var startedAt sql.NullTime - err := db.QueryRow(query, id).Scan( - &l.ID, &l.Name, &l.Type, &l.BindHost, &l.BindPort, &l.ProfileID, - &l.EncryptionKey, &l.ImplantToken, &l.Status, - &l.ConfigJSON, &l.Remark, - &l.CreatedAt, &startedAt, &l.LastError, - ) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } - if startedAt.Valid { - t := startedAt.Time - l.StartedAt = &t - } - return &l, nil -} - -// ListC2Listeners 全量列表,按创建时间倒序 -func (db *DB) ListC2Listeners() ([]*C2Listener, error) { - query := ` - SELECT id, name, type, bind_host, bind_port, COALESCE(profile_id, ''), - COALESCE(encryption_key, ''), COALESCE(implant_token, ''), status, - COALESCE(config_json, '{}'), COALESCE(remark, ''), - created_at, started_at, COALESCE(last_error, '') - FROM c2_listeners ORDER BY created_at DESC - ` - rows, err := db.Query(query) - if err != nil { - return nil, err - } - defer rows.Close() - var list []*C2Listener - for rows.Next() { - var l C2Listener - var startedAt sql.NullTime - if err := rows.Scan( - &l.ID, &l.Name, &l.Type, &l.BindHost, &l.BindPort, &l.ProfileID, - &l.EncryptionKey, &l.ImplantToken, &l.Status, - &l.ConfigJSON, &l.Remark, - &l.CreatedAt, &startedAt, &l.LastError, - ); err != nil { - db.logger.Warn("扫描 c2_listeners 行失败", zap.Error(err)) - continue - } - if startedAt.Valid { - t := startedAt.Time - l.StartedAt = &t - } - list = append(list, &l) - } - return list, rows.Err() -} - -// DeleteC2Listener 级联删除(会话/任务/文件/事件随之消失) -func (db *DB) DeleteC2Listener(id string) error { - res, err := db.Exec(`DELETE FROM c2_listeners WHERE id = ?`, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// ---------------------------------------------------------------------------- -// CRUD:C2 会话 -// ---------------------------------------------------------------------------- - -// UpsertC2Session 按 implant_uuid 唯一约束:首次插入 / 已存在则更新心跳和状态 -func (db *DB) UpsertC2Session(s *C2Session) error { - if s == nil || strings.TrimSpace(s.ID) == "" || strings.TrimSpace(s.ImplantUUID) == "" { - return errors.New("session id and implant_uuid are required") - } - if s.FirstSeenAt.IsZero() { - s.FirstSeenAt = time.Now() - } - if s.LastCheckIn.IsZero() { - s.LastCheckIn = s.FirstSeenAt - } - if strings.TrimSpace(s.Status) == "" { - s.Status = "active" - } - metadataJSON := "{}" - if len(s.Metadata) > 0 { - if b, err := json.Marshal(s.Metadata); err == nil { - metadataJSON = string(b) - } - } - query := ` - INSERT INTO c2_sessions (id, listener_id, implant_uuid, hostname, username, os, arch, - pid, process_name, is_admin, internal_ip, external_ip, user_agent, - sleep_seconds, jitter_percent, status, first_seen_at, last_check_in, - metadata_json, note) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(implant_uuid) DO UPDATE SET - hostname = excluded.hostname, - username = excluded.username, - os = excluded.os, - arch = excluded.arch, - pid = excluded.pid, - process_name = excluded.process_name, - is_admin = excluded.is_admin, - internal_ip = excluded.internal_ip, - external_ip = excluded.external_ip, - user_agent = excluded.user_agent, - sleep_seconds = excluded.sleep_seconds, - jitter_percent = excluded.jitter_percent, - status = excluded.status, - last_check_in = excluded.last_check_in, - metadata_json = excluded.metadata_json - ` - isAdminInt := 0 - if s.IsAdmin { - isAdminInt = 1 - } - _, err := db.Exec(query, - s.ID, s.ListenerID, s.ImplantUUID, s.Hostname, s.Username, s.OS, s.Arch, - s.PID, s.ProcessName, isAdminInt, s.InternalIP, s.ExternalIP, s.UserAgent, - s.SleepSeconds, s.JitterPercent, s.Status, s.FirstSeenAt, s.LastCheckIn, - metadataJSON, s.Note, - ) - if err != nil { - db.logger.Error("upsert C2 会话失败", zap.Error(err), zap.String("implant_uuid", s.ImplantUUID)) - return err - } - return nil -} - -// TouchC2Session 仅更新 last_check_in / status,性能比 UpsertC2Session 高,给 beacon 高频心跳用 -func (db *DB) TouchC2Session(id, status string, t time.Time) error { - if t.IsZero() { - t = time.Now() - } - res, err := db.Exec(`UPDATE c2_sessions SET last_check_in = ?, status = ? WHERE id = ?`, t, status, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// SetC2SessionStatus 单独改状态 -func (db *DB) SetC2SessionStatus(id, status string) error { - res, err := db.Exec(`UPDATE c2_sessions SET status = ? WHERE id = ?`, status, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// SetC2SessionSleep 改 sleep / jitter(操作员或 AI 主动调整心跳节律) -func (db *DB) SetC2SessionSleep(id string, sleepSeconds, jitterPercent int) error { - if sleepSeconds < 0 { - sleepSeconds = 0 - } - if jitterPercent < 0 { - jitterPercent = 0 - } - if jitterPercent > 100 { - jitterPercent = 100 - } - res, err := db.Exec(`UPDATE c2_sessions SET sleep_seconds = ?, jitter_percent = ? WHERE id = ?`, - sleepSeconds, jitterPercent, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// SetC2SessionNote 改备注 -func (db *DB) SetC2SessionNote(id, note string) error { - _, err := db.Exec(`UPDATE c2_sessions SET note = ? WHERE id = ?`, note, id) - return err -} - -// GetC2Session 按内部 ID 查 -func (db *DB) GetC2Session(id string) (*C2Session, error) { - return db.queryC2SessionWhere(`id = ?`, id) -} - -// GetC2SessionByImplantUUID 按 implant 自报的 UUID 查(重连必需) -func (db *DB) GetC2SessionByImplantUUID(uuid string) (*C2Session, error) { - return db.queryC2SessionWhere(`implant_uuid = ?`, uuid) -} - -func (db *DB) queryC2SessionWhere(whereClause string, args ...interface{}) (*C2Session, error) { - query := ` - SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''), - COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''), - COALESCE(is_admin, 0), COALESCE(internal_ip,''), COALESCE(external_ip,''), - COALESCE(user_agent,''), COALESCE(sleep_seconds, 5), COALESCE(jitter_percent, 0), - status, first_seen_at, last_check_in, COALESCE(metadata_json, '{}'), - COALESCE(note, '') - FROM c2_sessions WHERE ` + whereClause - row := db.QueryRow(query, args...) - var s C2Session - var isAdminInt int - var metadataJSON string - err := row.Scan( - &s.ID, &s.ListenerID, &s.ImplantUUID, &s.Hostname, &s.Username, - &s.OS, &s.Arch, &s.PID, &s.ProcessName, - &isAdminInt, &s.InternalIP, &s.ExternalIP, - &s.UserAgent, &s.SleepSeconds, &s.JitterPercent, - &s.Status, &s.FirstSeenAt, &s.LastCheckIn, &metadataJSON, - &s.Note, - ) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } - s.IsAdmin = isAdminInt != 0 - if metadataJSON != "" && metadataJSON != "{}" { - _ = json.Unmarshal([]byte(metadataJSON), &s.Metadata) - } - return &s, nil -} - -// ListC2SessionsFilter 列表过滤参数 -type ListC2SessionsFilter struct { - ListenerID string - Status string // active|sleeping|dead|killed;空表示全部 - OS string - Search string // 模糊匹配 hostname/username/internal_ip - Limit int // 0 表示无限制 -} - -// ListC2Sessions 列表,按 last_check_in 倒序 -func (db *DB) ListC2Sessions(filter ListC2SessionsFilter) ([]*C2Session, error) { - conditions := []string{"1=1"} - args := []interface{}{} - if filter.ListenerID != "" { - conditions = append(conditions, "listener_id = ?") - args = append(args, filter.ListenerID) - } - if filter.Status != "" { - conditions = append(conditions, "status = ?") - args = append(args, filter.Status) - } - if filter.OS != "" { - conditions = append(conditions, "os = ?") - args = append(args, filter.OS) - } - if filter.Search != "" { - conditions = append(conditions, "(hostname LIKE ? OR username LIKE ? OR internal_ip LIKE ?)") - kw := "%" + filter.Search + "%" - args = append(args, kw, kw, kw) - } - query := ` - SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''), - COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''), - COALESCE(is_admin, 0), COALESCE(internal_ip,''), COALESCE(external_ip,''), - COALESCE(user_agent,''), COALESCE(sleep_seconds, 5), COALESCE(jitter_percent, 0), - status, first_seen_at, last_check_in, COALESCE(metadata_json, '{}'), - COALESCE(note, '') - FROM c2_sessions - WHERE ` + strings.Join(conditions, " AND ") + ` - ORDER BY last_check_in DESC - ` - if filter.Limit > 0 { - query += fmt.Sprintf(" LIMIT %d", filter.Limit) - } - rows, err := db.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - var list []*C2Session - for rows.Next() { - var s C2Session - var isAdminInt int - var metadataJSON string - if err := rows.Scan( - &s.ID, &s.ListenerID, &s.ImplantUUID, &s.Hostname, &s.Username, - &s.OS, &s.Arch, &s.PID, &s.ProcessName, - &isAdminInt, &s.InternalIP, &s.ExternalIP, - &s.UserAgent, &s.SleepSeconds, &s.JitterPercent, - &s.Status, &s.FirstSeenAt, &s.LastCheckIn, &metadataJSON, - &s.Note, - ); err != nil { - db.logger.Warn("扫描 c2_sessions 行失败", zap.Error(err)) - continue - } - s.IsAdmin = isAdminInt != 0 - if metadataJSON != "" && metadataJSON != "{}" { - _ = json.Unmarshal([]byte(metadataJSON), &s.Metadata) - } - list = append(list, &s) - } - return list, rows.Err() -} - -// DeleteC2Session 级联删除其 tasks/files -func (db *DB) DeleteC2Session(id string) error { - res, err := db.Exec(`DELETE FROM c2_sessions WHERE id = ?`, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// ---------------------------------------------------------------------------- -// CRUD:C2 任务 -// ---------------------------------------------------------------------------- - -// CreateC2Task 入队一个新任务 -func (db *DB) CreateC2Task(t *C2Task) error { - if t == nil || strings.TrimSpace(t.ID) == "" { - return errors.New("task id is required") - } - if t.CreatedAt.IsZero() { - t.CreatedAt = time.Now() - } - if strings.TrimSpace(t.Status) == "" { - t.Status = "queued" - } - if strings.TrimSpace(t.Source) == "" { - t.Source = "manual" - } - payloadJSON := "{}" - if len(t.Payload) > 0 { - if b, err := json.Marshal(t.Payload); err == nil { - payloadJSON = string(b) - } - } - query := ` - INSERT INTO c2_tasks (id, session_id, task_type, payload_json, status, - result_text, result_blob_path, error, source, conversation_id, approval_status, - created_at, sent_at, started_at, completed_at, duration_ms) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - _, err := db.Exec(query, - t.ID, t.SessionID, t.TaskType, payloadJSON, t.Status, - t.ResultText, t.ResultBlobPath, t.Error, t.Source, t.ConversationID, t.ApprovalStatus, - t.CreatedAt, t.SentAt, t.StartedAt, t.CompletedAt, t.DurationMS, - ) - if err != nil { - db.logger.Error("创建 C2 任务失败", zap.Error(err), zap.String("id", t.ID)) - return err - } - return nil -} - -// SetC2TaskStatus 更新任务的状态/结果/错误/时间戳 -type C2TaskUpdate struct { - Status *string - ResultText *string - ResultBlobPath *string - Error *string - ApprovalStatus *string - SentAt *time.Time - StartedAt *time.Time - CompletedAt *time.Time - DurationMS *int64 -} - -// UpdateC2Task 增量更新任务字段;nil 字段保持原值 -func (db *DB) UpdateC2Task(id string, u C2TaskUpdate) error { - sets := []string{} - args := []interface{}{} - if u.Status != nil { - sets = append(sets, "status = ?") - args = append(args, *u.Status) - } - if u.ResultText != nil { - sets = append(sets, "result_text = ?") - args = append(args, *u.ResultText) - } - if u.ResultBlobPath != nil { - sets = append(sets, "result_blob_path = ?") - args = append(args, *u.ResultBlobPath) - } - if u.Error != nil { - sets = append(sets, "error = ?") - args = append(args, *u.Error) - } - if u.ApprovalStatus != nil { - sets = append(sets, "approval_status = ?") - args = append(args, *u.ApprovalStatus) - } - if u.SentAt != nil { - sets = append(sets, "sent_at = ?") - args = append(args, *u.SentAt) - } - if u.StartedAt != nil { - sets = append(sets, "started_at = ?") - args = append(args, *u.StartedAt) - } - if u.CompletedAt != nil { - sets = append(sets, "completed_at = ?") - args = append(args, *u.CompletedAt) - } - if u.DurationMS != nil { - sets = append(sets, "duration_ms = ?") - args = append(args, *u.DurationMS) - } - if len(sets) == 0 { - return nil - } - query := "UPDATE c2_tasks SET " + strings.Join(sets, ", ") + " WHERE id = ?" - args = append(args, id) - res, err := db.Exec(query, args...) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// GetC2Task 单条 -func (db *DB) GetC2Task(id string) (*C2Task, error) { - query := ` - SELECT id, session_id, task_type, COALESCE(payload_json, '{}'), - status, COALESCE(result_text, ''), COALESCE(result_blob_path, ''), - COALESCE(error, ''), COALESCE(source, 'manual'), - COALESCE(conversation_id, ''), COALESCE(approval_status, ''), - created_at, sent_at, started_at, completed_at, COALESCE(duration_ms, 0) - FROM c2_tasks WHERE id = ? - ` - var t C2Task - var payloadJSON string - var sentAt, startedAt, completedAt sql.NullTime - err := db.QueryRow(query, id).Scan( - &t.ID, &t.SessionID, &t.TaskType, &payloadJSON, - &t.Status, &t.ResultText, &t.ResultBlobPath, - &t.Error, &t.Source, - &t.ConversationID, &t.ApprovalStatus, - &t.CreatedAt, &sentAt, &startedAt, &completedAt, &t.DurationMS, - ) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } - if payloadJSON != "" && payloadJSON != "{}" { - _ = json.Unmarshal([]byte(payloadJSON), &t.Payload) - } - if sentAt.Valid { - x := sentAt.Time - t.SentAt = &x - } - if startedAt.Valid { - x := startedAt.Time - t.StartedAt = &x - } - if completedAt.Valid { - x := completedAt.Time - t.CompletedAt = &x - } - return &t, nil -} - -// ListC2TasksFilter 任务过滤 -type ListC2TasksFilter struct { - SessionID string - Status string - Limit int - Offset int -} - -func buildC2TasksWhere(filter ListC2TasksFilter) (where string, args []interface{}) { - conditions := []string{"1=1"} - args = []interface{}{} - if filter.SessionID != "" { - conditions = append(conditions, "session_id = ?") - args = append(args, filter.SessionID) - } - if filter.Status != "" { - conditions = append(conditions, "status = ?") - args = append(args, filter.Status) - } - return strings.Join(conditions, " AND "), args -} - -// CountC2Tasks 与 ListC2Tasks 相同过滤条件下的记录总数 -func (db *DB) CountC2Tasks(filter ListC2TasksFilter) (int64, error) { - where, args := buildC2TasksWhere(filter) - query := `SELECT COUNT(*) FROM c2_tasks WHERE ` + where - var n int64 - err := db.QueryRow(query, args...).Scan(&n) - return n, err -} - -// CountC2TasksQueuedOrPending 统计 queued/pending 状态任务数(仪表盘「待审任务」) -func (db *DB) CountC2TasksQueuedOrPending(sessionID string) (int64, error) { - conditions := []string{"status IN ('queued', 'pending')"} - args := []interface{}{} - if sessionID != "" { - conditions = append(conditions, "session_id = ?") - args = append(args, sessionID) - } - query := `SELECT COUNT(*) FROM c2_tasks WHERE ` + strings.Join(conditions, " AND ") - var n int64 - err := db.QueryRow(query, args...).Scan(&n) - return n, err -} - -// ListC2Tasks 任务列表,按创建时间倒序 -func (db *DB) ListC2Tasks(filter ListC2TasksFilter) ([]*C2Task, error) { - where, args := buildC2TasksWhere(filter) - query := ` - SELECT id, session_id, task_type, COALESCE(payload_json, '{}'), - status, COALESCE(result_text, ''), COALESCE(result_blob_path, ''), - COALESCE(error, ''), COALESCE(source, 'manual'), - COALESCE(conversation_id, ''), COALESCE(approval_status, ''), - created_at, sent_at, started_at, completed_at, COALESCE(duration_ms, 0) - FROM c2_tasks - WHERE ` + where + ` - ORDER BY created_at DESC - ` - limit := filter.Limit - offset := filter.Offset - if offset < 0 { - offset = 0 - } - if limit > 0 { - if limit > 1000 { - limit = 1000 - } - query += ` LIMIT ? OFFSET ?` - args = append(args, limit, offset) - } - rows, err := db.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - var list []*C2Task - for rows.Next() { - var t C2Task - var payloadJSON string - var sentAt, startedAt, completedAt sql.NullTime - if err := rows.Scan( - &t.ID, &t.SessionID, &t.TaskType, &payloadJSON, - &t.Status, &t.ResultText, &t.ResultBlobPath, - &t.Error, &t.Source, - &t.ConversationID, &t.ApprovalStatus, - &t.CreatedAt, &sentAt, &startedAt, &completedAt, &t.DurationMS, - ); err != nil { - db.logger.Warn("扫描 c2_tasks 行失败", zap.Error(err)) - continue - } - if payloadJSON != "" && payloadJSON != "{}" { - _ = json.Unmarshal([]byte(payloadJSON), &t.Payload) - } - if sentAt.Valid { - x := sentAt.Time - t.SentAt = &x - } - if startedAt.Valid { - x := startedAt.Time - t.StartedAt = &x - } - if completedAt.Valid { - x := completedAt.Time - t.CompletedAt = &x - } - list = append(list, &t) - } - return list, rows.Err() -} - -// PopQueuedC2Tasks 取出某会话所有 queued/approved 任务(用于 beacon 拉取),原子置为 sent -func (db *DB) PopQueuedC2Tasks(sessionID string, limit int) ([]*C2Task, error) { - if limit <= 0 { - limit = 50 - } - tx, err := db.Begin() - if err != nil { - return nil, err - } - committed := false - defer func() { - if !committed { - _ = tx.Rollback() - } - }() - query := ` - SELECT id, session_id, task_type, COALESCE(payload_json, '{}'), - status, COALESCE(source, 'manual'), COALESCE(approval_status, ''), - created_at - FROM c2_tasks - WHERE session_id = ? AND (status = 'queued' AND (approval_status = '' OR approval_status = 'approved')) - ORDER BY created_at ASC - LIMIT ? - ` - rows, err := tx.Query(query, sessionID, limit) - if err != nil { - return nil, err - } - var list []*C2Task - for rows.Next() { - var t C2Task - var payloadJSON string - if err := rows.Scan(&t.ID, &t.SessionID, &t.TaskType, &payloadJSON, - &t.Status, &t.Source, &t.ApprovalStatus, &t.CreatedAt); err != nil { - rows.Close() - return nil, err - } - if payloadJSON != "" && payloadJSON != "{}" { - _ = json.Unmarshal([]byte(payloadJSON), &t.Payload) - } - list = append(list, &t) - } - rows.Close() - - now := time.Now() - for _, t := range list { - if _, err := tx.Exec( - `UPDATE c2_tasks SET status = 'sent', sent_at = ? WHERE id = ?`, now, t.ID, - ); err != nil { - return nil, err - } - t.Status = "sent" - t.SentAt = &now - } - if err := tx.Commit(); err != nil { - return nil, err - } - committed = true - return list, nil -} - -// DeleteC2Task 删除任务(一般用于 cancel queued) -func (db *DB) DeleteC2Task(id string) error { - res, err := db.Exec(`DELETE FROM c2_tasks WHERE id = ?`, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// DeleteC2TasksByIDs 按主键批量删除任务 -func (db *DB) DeleteC2TasksByIDs(ids []string) (int64, error) { - if len(ids) == 0 { - return 0, nil - } - const maxBatch = 500 - if len(ids) > maxBatch { - ids = ids[:maxBatch] - } - clean := make([]string, 0, len(ids)) - seen := make(map[string]struct{}, len(ids)) - for _, id := range ids { - id = strings.TrimSpace(id) - if !validC2TextIDForDelete(id) { - continue - } - if _, ok := seen[id]; ok { - continue - } - seen[id] = struct{}{} - clean = append(clean, id) - } - if len(clean) == 0 { - return 0, ErrNoValidC2TaskIDs - } - placeholders := strings.Repeat("?,", len(clean)-1) + "?" - args := make([]interface{}, len(clean)) - for i := range clean { - args[i] = clean[i] - } - query := `DELETE FROM c2_tasks WHERE id IN (` + placeholders + `)` - res, err := db.Exec(query, args...) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -// ---------------------------------------------------------------------------- -// CRUD:C2 文件 -// ---------------------------------------------------------------------------- - -// CreateC2File 记录上传/下载凭证(实际文件落盘由调用方处理) -func (db *DB) CreateC2File(f *C2File) error { - if f == nil || strings.TrimSpace(f.ID) == "" { - return errors.New("file id is required") - } - if f.CreatedAt.IsZero() { - f.CreatedAt = time.Now() - } - query := ` - INSERT INTO c2_files (id, session_id, task_id, direction, remote_path, - local_path, size_bytes, sha256, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - _, err := db.Exec(query, f.ID, f.SessionID, f.TaskID, f.Direction, - f.RemotePath, f.LocalPath, f.SizeBytes, f.SHA256, f.CreatedAt) - return err -} - -// ListC2FilesBySession 列出某会话下所有上传/下载凭证 -func (db *DB) ListC2FilesBySession(sessionID string) ([]*C2File, error) { - query := ` - SELECT id, session_id, COALESCE(task_id, ''), direction, remote_path, local_path, - COALESCE(size_bytes, 0), COALESCE(sha256, ''), created_at - FROM c2_files WHERE session_id = ? ORDER BY created_at DESC - ` - rows, err := db.Query(query, sessionID) - if err != nil { - return nil, err - } - defer rows.Close() - var list []*C2File - for rows.Next() { - var f C2File - if err := rows.Scan(&f.ID, &f.SessionID, &f.TaskID, &f.Direction, - &f.RemotePath, &f.LocalPath, &f.SizeBytes, &f.SHA256, &f.CreatedAt); err != nil { - continue - } - list = append(list, &f) - } - return list, rows.Err() -} - -// ---------------------------------------------------------------------------- -// CRUD:C2 事件审计 -// ---------------------------------------------------------------------------- - -// AppendC2Event 写一条审计事件 -func (db *DB) AppendC2Event(e *C2Event) error { - if e == nil { - return errors.New("event is nil") - } - if strings.TrimSpace(e.ID) == "" { - return errors.New("event id is required") - } - if e.CreatedAt.IsZero() { - e.CreatedAt = time.Now() - } - if strings.TrimSpace(e.Level) == "" { - e.Level = "info" - } - dataJSON := "" - if len(e.Data) > 0 { - if b, err := json.Marshal(e.Data); err == nil { - dataJSON = string(b) - } - } - query := ` - INSERT INTO c2_events (id, level, category, session_id, task_id, message, data_json, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ` - _, err := db.Exec(query, e.ID, e.Level, e.Category, e.SessionID, e.TaskID, e.Message, dataJSON, e.CreatedAt) - return err -} - -// ListC2EventsFilter 事件查询参数 -type ListC2EventsFilter struct { - Level string - Category string - SessionID string - TaskID string - Since *time.Time - Limit int - Offset int -} - -func buildC2EventsWhere(filter ListC2EventsFilter) (where string, args []interface{}) { - conditions := []string{"1=1"} - args = []interface{}{} - if filter.Level != "" { - conditions = append(conditions, "level = ?") - args = append(args, filter.Level) - } - if filter.Category != "" { - conditions = append(conditions, "category = ?") - args = append(args, filter.Category) - } - if filter.SessionID != "" { - conditions = append(conditions, "session_id = ?") - args = append(args, filter.SessionID) - } - if filter.TaskID != "" { - conditions = append(conditions, "task_id = ?") - args = append(args, filter.TaskID) - } - if filter.Since != nil { - conditions = append(conditions, "created_at >= ?") - args = append(args, *filter.Since) - } - return strings.Join(conditions, " AND "), args -} - -// CountC2Events 与 ListC2Events 相同过滤条件下的记录总数 -func (db *DB) CountC2Events(filter ListC2EventsFilter) (int64, error) { - where, args := buildC2EventsWhere(filter) - query := `SELECT COUNT(*) FROM c2_events WHERE ` + where - var n int64 - err := db.QueryRow(query, args...).Scan(&n) - return n, err -} - -// ListC2Events 事件查询,按创建时间倒序 -func (db *DB) ListC2Events(filter ListC2EventsFilter) ([]*C2Event, error) { - where, args := buildC2EventsWhere(filter) - limit := filter.Limit - if limit <= 0 || limit > 1000 { - limit = 200 - } - offset := filter.Offset - if offset < 0 { - offset = 0 - } - query := ` - SELECT id, level, category, COALESCE(session_id, ''), COALESCE(task_id, ''), - message, COALESCE(data_json, ''), created_at - FROM c2_events - WHERE ` + where + ` - ORDER BY created_at DESC - LIMIT ? OFFSET ? - ` - args = append(args, limit, offset) - rows, err := db.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - var list []*C2Event - for rows.Next() { - var e C2Event - var dataJSON string - if err := rows.Scan(&e.ID, &e.Level, &e.Category, &e.SessionID, &e.TaskID, - &e.Message, &dataJSON, &e.CreatedAt); err != nil { - continue - } - if dataJSON != "" { - _ = json.Unmarshal([]byte(dataJSON), &e.Data) - } - list = append(list, &e) - } - return list, rows.Err() -} - -// DeleteC2EventsByIDs 按主键批量删除事件,返回实际删除行数 -func (db *DB) DeleteC2EventsByIDs(ids []string) (int64, error) { - if len(ids) == 0 { - return 0, nil - } - const maxBatch = 500 - if len(ids) > maxBatch { - ids = ids[:maxBatch] - } - clean := make([]string, 0, len(ids)) - seen := make(map[string]struct{}, len(ids)) - for _, id := range ids { - id = strings.TrimSpace(id) - if !validC2TextIDForDelete(id) { - continue - } - if _, ok := seen[id]; ok { - continue - } - seen[id] = struct{}{} - clean = append(clean, id) - } - if len(clean) == 0 { - return 0, ErrNoValidC2EventIDs - } - placeholders := strings.Repeat("?,", len(clean)-1) + "?" - args := make([]interface{}, len(clean)) - for i := range clean { - args[i] = clean[i] - } - query := `DELETE FROM c2_events WHERE id IN (` + placeholders + `)` - res, err := db.Exec(query, args...) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -// ---------------------------------------------------------------------------- -// CRUD:C2 Malleable Profile -// ---------------------------------------------------------------------------- - -// CreateC2Profile 创建/覆盖 Profile(按 name 唯一) -func (db *DB) CreateC2Profile(p *C2Profile) error { - if p == nil || strings.TrimSpace(p.ID) == "" { - return errors.New("profile id is required") - } - if p.CreatedAt.IsZero() { - p.CreatedAt = time.Now() - } - urisJSON, _ := json.Marshal(p.URIs) - reqHdrJSON, _ := json.Marshal(p.RequestHeaders) - resHdrJSON, _ := json.Marshal(p.ResponseHeaders) - query := ` - INSERT INTO c2_profiles (id, name, user_agent, uris_json, request_headers_json, - response_headers_json, body_template, jitter_min_ms, jitter_max_ms, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - _, err := db.Exec(query, p.ID, p.Name, p.UserAgent, string(urisJSON), - string(reqHdrJSON), string(resHdrJSON), p.BodyTemplate, - p.JitterMinMS, p.JitterMaxMS, p.CreatedAt) - return err -} - -// UpdateC2Profile 全量更新 Profile -func (db *DB) UpdateC2Profile(p *C2Profile) error { - if p == nil || strings.TrimSpace(p.ID) == "" { - return errors.New("profile id is required") - } - urisJSON, _ := json.Marshal(p.URIs) - reqHdrJSON, _ := json.Marshal(p.RequestHeaders) - resHdrJSON, _ := json.Marshal(p.ResponseHeaders) - query := ` - UPDATE c2_profiles SET name = ?, user_agent = ?, uris_json = ?, - request_headers_json = ?, response_headers_json = ?, body_template = ?, - jitter_min_ms = ?, jitter_max_ms = ? - WHERE id = ? - ` - res, err := db.Exec(query, p.Name, p.UserAgent, string(urisJSON), - string(reqHdrJSON), string(resHdrJSON), p.BodyTemplate, - p.JitterMinMS, p.JitterMaxMS, p.ID) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// GetC2Profile 单条 -func (db *DB) GetC2Profile(id string) (*C2Profile, error) { - query := ` - SELECT id, name, COALESCE(user_agent, ''), COALESCE(uris_json, '[]'), - COALESCE(request_headers_json, '{}'), COALESCE(response_headers_json, '{}'), - COALESCE(body_template, ''), COALESCE(jitter_min_ms, 0), COALESCE(jitter_max_ms, 0), - created_at - FROM c2_profiles WHERE id = ? - ` - var p C2Profile - var urisJSON, reqHdrJSON, resHdrJSON string - err := db.QueryRow(query, id).Scan(&p.ID, &p.Name, &p.UserAgent, &urisJSON, - &reqHdrJSON, &resHdrJSON, &p.BodyTemplate, &p.JitterMinMS, &p.JitterMaxMS, &p.CreatedAt) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } - _ = json.Unmarshal([]byte(urisJSON), &p.URIs) - _ = json.Unmarshal([]byte(reqHdrJSON), &p.RequestHeaders) - _ = json.Unmarshal([]byte(resHdrJSON), &p.ResponseHeaders) - return &p, nil -} - -// ListC2Profiles 全量列表 -func (db *DB) ListC2Profiles() ([]*C2Profile, error) { - query := ` - SELECT id, name, COALESCE(user_agent, ''), COALESCE(uris_json, '[]'), - COALESCE(request_headers_json, '{}'), COALESCE(response_headers_json, '{}'), - COALESCE(body_template, ''), COALESCE(jitter_min_ms, 0), COALESCE(jitter_max_ms, 0), - created_at - FROM c2_profiles ORDER BY created_at DESC - ` - rows, err := db.Query(query) - if err != nil { - return nil, err - } - defer rows.Close() - var list []*C2Profile - for rows.Next() { - var p C2Profile - var urisJSON, reqHdrJSON, resHdrJSON string - if err := rows.Scan(&p.ID, &p.Name, &p.UserAgent, &urisJSON, - &reqHdrJSON, &resHdrJSON, &p.BodyTemplate, &p.JitterMinMS, &p.JitterMaxMS, &p.CreatedAt); err != nil { - continue - } - _ = json.Unmarshal([]byte(urisJSON), &p.URIs) - _ = json.Unmarshal([]byte(reqHdrJSON), &p.RequestHeaders) - _ = json.Unmarshal([]byte(resHdrJSON), &p.ResponseHeaders) - list = append(list, &p) - } - return list, rows.Err() -} - -// DeleteC2Profile 删除 Profile(不影响已用此 Profile 的 listener,仅断开关联) -func (db *DB) DeleteC2Profile(id string) error { - if _, err := db.Exec(`UPDATE c2_listeners SET profile_id = '' WHERE profile_id = ?`, id); err != nil { - return err - } - res, err := db.Exec(`DELETE FROM c2_profiles WHERE id = ?`, id) - if err != nil { - return err - } - affected, _ := res.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} diff --git a/database/conversation.go b/database/conversation.go deleted file mode 100644 index d23506a4..00000000 --- a/database/conversation.go +++ /dev/null @@ -1,812 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - "github.com/google/uuid" - "go.uber.org/zap" -) - -// Conversation 对话 -type Conversation struct { - ID string `json:"id"` - Title string `json:"title"` - Pinned bool `json:"pinned"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` - Messages []Message `json:"messages,omitempty"` -} - -// Message 消息 -type Message struct { - ID string `json:"id"` - ConversationID string `json:"conversationId"` - Role string `json:"role"` - Content string `json:"content"` - ReasoningContent string `json:"reasoningContent,omitempty"` - MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` - ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` -} - -// CreateConversation 创建新对话 -func (db *DB) CreateConversation(title string) (*Conversation, error) { - return db.CreateConversationWithWebshell("", title) -} - -// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话) -func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) { - id := uuid.New().String() - now := time.Now() - - var err error - if webshellConnectionID != "" { - _, err = db.Exec( - "INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)", - id, title, now, now, webshellConnectionID, - ) - } else { - _, err = db.Exec( - "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)", - id, title, now, now, - ) - } - if err != nil { - return nil, fmt.Errorf("创建对话失败: %w", err) - } - - return &Conversation{ - ID: id, - Title: title, - CreatedAt: now, - UpdatedAt: now, - }, nil -} - -// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化) -func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conversation, error) { - if connectionID == "" { - return nil, fmt.Errorf("connectionID is empty") - } - var conv Conversation - var createdAt, updatedAt string - var pinned int - err := db.QueryRow( - "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC LIMIT 1", - connectionID, - ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("查询对话失败: %w", err) - } - conv.Pinned = pinned != 0 - if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt); e == nil { - conv.CreatedAt = t - } else if t, e := time.Parse("2006-01-02 15:04:05", createdAt); e == nil { - conv.CreatedAt = t - } else { - conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil { - conv.UpdatedAt = t - } else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil { - conv.UpdatedAt = t - } else { - conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - messages, err := db.GetMessages(conv.ID) - if err != nil { - return nil, fmt.Errorf("加载消息失败: %w", err) - } - conv.Messages = messages - - // 加载过程详情并附加到对应消息(与 GetConversation 一致,便于刷新后仍可查看执行过程) - processDetailsMap, err := db.GetProcessDetailsByConversation(conv.ID) - if err != nil { - db.logger.Warn("加载过程详情失败", zap.Error(err)) - processDetailsMap = make(map[string][]ProcessDetail) - } - for i := range conv.Messages { - if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { - detailsJSON := make([]map[string]interface{}, len(details)) - for j, detail := range details { - var data interface{} - if detail.Data != "" { - if err := json.Unmarshal([]byte(detail.Data), &data); err != nil { - db.logger.Warn("解析过程详情数据失败", zap.Error(err)) - } - } - detailsJSON[j] = map[string]interface{}{ - "id": detail.ID, - "messageId": detail.MessageID, - "conversationId": detail.ConversationID, - "eventType": detail.EventType, - "message": detail.Message, - "data": data, - "createdAt": detail.CreatedAt, - } - } - conv.Messages[i].ProcessDetails = detailsJSON - } - } - - return &conv, nil -} - -// WebShellConversationItem 用于侧边栏列表,不含消息 -type WebShellConversationItem struct { - ID string `json:"id"` - Title string `json:"title"` - UpdatedAt time.Time `json:"updatedAt"` -} - -// ListConversationsByWebshellConnectionID 列出该 WebShell 连接下的所有对话(按更新时间倒序),供侧边栏展示 -func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]WebShellConversationItem, error) { - if connectionID == "" { - return nil, nil - } - rows, err := db.Query( - "SELECT id, title, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC", - connectionID, - ) - if err != nil { - return nil, fmt.Errorf("查询对话列表失败: %w", err) - } - defer rows.Close() - var list []WebShellConversationItem - for rows.Next() { - var item WebShellConversationItem - var updatedAt string - if err := rows.Scan(&item.ID, &item.Title, &updatedAt); err != nil { - continue - } - if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil { - item.UpdatedAt = t - } else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil { - item.UpdatedAt = t - } else { - item.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - list = append(list, item) - } - return list, rows.Err() -} - -// GetConversation 获取对话 -func (db *DB) GetConversation(id string) (*Conversation, error) { - var conv Conversation - var createdAt, updatedAt string - var pinned int - - err := db.QueryRow( - "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", - id, - ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("对话不存在") - } - return nil, fmt.Errorf("查询对话失败: %w", err) - } - - // 尝试多种时间格式解析 - var err1, err2 error - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err1 != nil { - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err1 != nil { - conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) - if err2 != nil { - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) - } - if err2 != nil { - conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - - conv.Pinned = pinned != 0 - - // 加载消息 - messages, err := db.GetMessages(id) - if err != nil { - return nil, fmt.Errorf("加载消息失败: %w", err) - } - conv.Messages = messages - - // 加载过程详情(按消息ID分组) - processDetailsMap, err := db.GetProcessDetailsByConversation(id) - if err != nil { - db.logger.Warn("加载过程详情失败", zap.Error(err)) - processDetailsMap = make(map[string][]ProcessDetail) - } - - // 将过程详情附加到对应的消息上 - for i := range conv.Messages { - if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { - // 将ProcessDetail转换为JSON格式,以便前端使用 - detailsJSON := make([]map[string]interface{}, len(details)) - for j, detail := range details { - var data interface{} - if detail.Data != "" { - if err := json.Unmarshal([]byte(detail.Data), &data); err != nil { - db.logger.Warn("解析过程详情数据失败", zap.Error(err)) - } - } - detailsJSON[j] = map[string]interface{}{ - "id": detail.ID, - "messageId": detail.MessageID, - "conversationId": detail.ConversationID, - "eventType": detail.EventType, - "message": detail.Message, - "data": data, - "createdAt": detail.CreatedAt, - } - } - conv.Messages[i].ProcessDetails = detailsJSON - } - } - - return &conv, nil -} - -// GetConversationLite 获取对话(轻量版):包含 messages,但不加载 process_details。 -// 用于历史会话快速切换,避免一次性把大体量过程详情灌到前端导致卡顿。 -func (db *DB) GetConversationLite(id string) (*Conversation, error) { - var conv Conversation - var createdAt, updatedAt string - var pinned int - - err := db.QueryRow( - "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", - id, - ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("对话不存在") - } - return nil, fmt.Errorf("查询对话失败: %w", err) - } - - // 尝试多种时间格式解析 - var err1, err2 error - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err1 != nil { - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err1 != nil { - conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) - if err2 != nil { - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) - } - if err2 != nil { - conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - - conv.Pinned = pinned != 0 - - // 加载消息(不加载 process_details) - messages, err := db.GetMessages(id) - if err != nil { - return nil, fmt.Errorf("加载消息失败: %w", err) - } - conv.Messages = messages - return &conv, nil -} - -// ListConversations 列出所有对话 -func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) { - var rows *sql.Rows - var err error - - if search != "" { - // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 - searchPattern := "%" + search + "%" - rows, err = db.Query( - `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at - FROM conversations c - WHERE c.title LIKE ? - OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?) - ORDER BY c.updated_at DESC - LIMIT ? OFFSET ?`, - searchPattern, searchPattern, limit, offset, - ) - } else { - rows, err = db.Query( - "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", - limit, offset, - ) - } - - if err != nil { - return nil, fmt.Errorf("查询对话列表失败: %w", err) - } - defer rows.Close() - - var conversations []*Conversation - for rows.Next() { - var conv Conversation - var createdAt, updatedAt string - var pinned int - - if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil { - return nil, fmt.Errorf("扫描对话失败: %w", err) - } - - // 尝试多种时间格式解析 - var err1, err2 error - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err1 != nil { - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err1 != nil { - conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) - if err2 != nil { - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) - } - if err2 != nil { - conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - - conv.Pinned = pinned != 0 - - conversations = append(conversations, &conv) - } - - return conversations, nil -} - -// UpdateConversationTitle 更新对话标题 -func (db *DB) UpdateConversationTitle(id, title string) error { - // 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间 - _, err := db.Exec( - "UPDATE conversations SET title = ? WHERE id = ?", - title, id, - ) - if err != nil { - return fmt.Errorf("更新对话标题失败: %w", err) - } - return nil -} - -// UpdateConversationTime 更新对话时间 -func (db *DB) UpdateConversationTime(id string) error { - _, err := db.Exec( - "UPDATE conversations SET updated_at = ? WHERE id = ?", - time.Now(), id, - ) - if err != nil { - return fmt.Errorf("更新对话时间失败: %w", err) - } - return nil -} - -// DeleteConversation 删除对话及其所有相关数据 -// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除: -// - messages(消息) -// - process_details(过程详情) -// - attack_chain_nodes(攻击链节点) -// - attack_chain_edges(攻击链边) -// - vulnerabilities(漏洞) -// - conversation_group_mappings(分组映射) -// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL -func (db *DB) DeleteConversation(id string) error { - // 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除) - _, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id) - if err != nil { - db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err)) - // 不返回错误,继续删除对话 - } - - // 删除对话(外键CASCADE会自动删除其他相关数据) - _, err = db.Exec("DELETE FROM conversations WHERE id = ?", id) - if err != nil { - return fmt.Errorf("删除对话失败: %w", err) - } - // Best-effort cleanup for conversation-scoped filesystem artifacts - // (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/). - if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" { - artDir := filepath.Join(base, id) - if rmErr := os.RemoveAll(artDir); rmErr != nil { - db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr)) - } - } - - db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) - return nil -} - -// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。 -// SQLite 列名仍为 last_react_input / last_react_output,与历史库表兼容;语义上为「全模式代理轨迹」,非仅 ReAct。 -func (db *DB) SaveAgentTrace(conversationID, traceInputJSON, assistantOutput string) error { - _, err := db.Exec( - "UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?", - traceInputJSON, assistantOutput, time.Now(), conversationID, - ) - if err != nil { - return fmt.Errorf("保存代理轨迹失败: %w", err) - } - return nil -} - -// GetAgentTrace 读取 conversations 中保存的代理轨迹(列名 last_react_*)。 -func (db *DB) GetAgentTrace(conversationID string) (traceInputJSON, assistantOutput string, err error) { - var input, output sql.NullString - err = db.QueryRow( - "SELECT last_react_input, last_react_output FROM conversations WHERE id = ?", - conversationID, - ).Scan(&input, &output) - if err != nil { - if err == sql.ErrNoRows { - return "", "", fmt.Errorf("对话不存在") - } - return "", "", fmt.Errorf("获取代理轨迹失败: %w", err) - } - - if input.Valid { - traceInputJSON = input.String - } - if output.Valid { - assistantOutput = output.String - } - - return traceInputJSON, assistantOutput, nil -} - -// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。 -func (db *DB) ConversationHasToolProcessDetails(conversationID string) (bool, error) { - var n int - err := db.QueryRow( - `SELECT COUNT(*) FROM process_details WHERE conversation_id = ? AND event_type IN ('tool_call', 'tool_result')`, - conversationID, - ).Scan(&n) - if err != nil { - return false, fmt.Errorf("查询过程详情失败: %w", err) - } - return n > 0, nil -} - -// AddMessage 添加消息 -func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) { - id := uuid.New().String() - now := time.Now() - - var mcpIDsJSON string - if len(mcpExecutionIDs) > 0 { - jsonData, err := json.Marshal(mcpExecutionIDs) - if err != nil { - db.logger.Warn("序列化MCP执行ID失败", zap.Error(err)) - } else { - mcpIDsJSON = string(jsonData) - } - } - - _, err := db.Exec( - "INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - id, conversationID, role, content, "", mcpIDsJSON, now, now, - ) - if err != nil { - return nil, fmt.Errorf("添加消息失败: %w", err) - } - - // 更新对话时间 - if err := db.UpdateConversationTime(conversationID); err != nil { - db.logger.Warn("更新对话时间失败", zap.Error(err)) - } - - message := &Message{ - ID: id, - ConversationID: conversationID, - Role: role, - Content: content, - MCPExecutionIDs: mcpExecutionIDs, - CreatedAt: now, - UpdatedAt: now, - } - - return message, nil -} - -// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。 -func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error { - var mcpIDsJSON string - if len(mcpExecutionIDs) > 0 { - jsonData, err := json.Marshal(mcpExecutionIDs) - if err != nil { - return fmt.Errorf("序列化MCP执行ID失败: %w", err) - } - mcpIDsJSON = string(jsonData) - } - _, err := db.Exec( - "UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? WHERE id = ?", - content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID, - ) - if err != nil { - return fmt.Errorf("更新助手消息失败: %w", err) - } - return nil -} - -// GetMessages 获取对话的所有消息 -func (db *DB) GetMessages(conversationID string) ([]Message, error) { - rows, err := db.Query( - "SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", - conversationID, - ) - if err != nil { - return nil, fmt.Errorf("查询消息失败: %w", err) - } - defer rows.Close() - - var messages []Message - for rows.Next() { - var msg Message - var reasoning sql.NullString - var mcpIDsJSON sql.NullString - var createdAt string - var updatedAt sql.NullString - - if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil { - return nil, fmt.Errorf("扫描消息失败: %w", err) - } - if reasoning.Valid { - msg.ReasoningContent = reasoning.String - } - - // 尝试多种时间格式解析 - var err error - msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err != nil { - msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err != nil { - msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - // updated_at 兼容老库:字段不存在/为空时回退为 created_at - if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" { - msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String) - if err != nil { - msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String) - } - if err != nil { - msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String) - } - } - if msg.UpdatedAt.IsZero() { - msg.UpdatedAt = msg.CreatedAt - } - - // 解析MCP执行ID - if mcpIDsJSON.Valid && mcpIDsJSON.String != "" { - if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil { - db.logger.Warn("解析MCP执行ID失败", zap.Error(err)) - } - } - - messages = append(messages, msg) - } - - return messages, nil -} - -// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。 -// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。 -func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) { - idx := -1 - for i := range msgs { - if msgs[i].ID == anchorID { - idx = i - break - } - } - if idx < 0 { - return 0, 0, fmt.Errorf("message not found") - } - start = idx - for start > 0 && msgs[start].Role != "user" { - start-- - } - if start < len(msgs) && msgs[start].Role != "user" { - start = 0 - } - end = len(msgs) - for i := start + 1; i < len(msgs); i++ { - if msgs[i].Role == "user" { - end = i - break - } - } - return start, end, nil -} - -// DeleteConversationTurn 删除锚点所在轮次的全部消息(用户提问 + 该轮助手回复等),并清空 last_react_*,避免与消息表不一致。 -func (db *DB) DeleteConversationTurn(conversationID, anchorMessageID string) (deletedIDs []string, err error) { - msgs, err := db.GetMessages(conversationID) - if err != nil { - return nil, err - } - start, end, err := turnSliceRange(msgs, anchorMessageID) - if err != nil { - return nil, err - } - if start >= end { - return nil, fmt.Errorf("empty turn range") - } - deletedIDs = make([]string, 0, end-start) - for i := start; i < end; i++ { - deletedIDs = append(deletedIDs, msgs[i].ID) - } - - tx, err := db.Begin() - if err != nil { - return nil, fmt.Errorf("begin tx: %w", err) - } - defer func() { _ = tx.Rollback() }() - - ph := strings.Repeat("?,", len(deletedIDs)) - ph = ph[:len(ph)-1] - args := make([]interface{}, 0, 1+len(deletedIDs)) - args = append(args, conversationID) - for _, id := range deletedIDs { - args = append(args, id) - } - res, err := tx.Exec( - "DELETE FROM messages WHERE conversation_id = ? AND id IN ("+ph+")", - args..., - ) - if err != nil { - return nil, fmt.Errorf("delete messages: %w", err) - } - n, err := res.RowsAffected() - if err != nil { - return nil, err - } - if int(n) != len(deletedIDs) { - return nil, fmt.Errorf("deleted count mismatch") - } - - _, err = tx.Exec( - `UPDATE conversations SET last_react_input = NULL, last_react_output = NULL, updated_at = ? WHERE id = ?`, - time.Now(), conversationID, - ) - if err != nil { - return nil, fmt.Errorf("clear react data: %w", err) - } - - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("commit: %w", err) - } - - db.logger.Info("conversation turn deleted", - zap.String("conversationId", conversationID), - zap.Strings("deletedMessageIds", deletedIDs), - zap.Int("count", len(deletedIDs)), - ) - return deletedIDs, nil -} - -// ProcessDetail 过程详情事件 -type ProcessDetail struct { - ID string `json:"id"` - MessageID string `json:"messageId"` - ConversationID string `json:"conversationId"` - EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error - Message string `json:"message"` - Data string `json:"data"` // JSON格式的数据 - CreatedAt time.Time `json:"createdAt"` -} - -// AddProcessDetail 添加过程详情事件 -func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error { - id := uuid.New().String() - - var dataJSON string - if data != nil { - jsonData, err := json.Marshal(data) - if err != nil { - db.logger.Warn("序列化过程详情数据失败", zap.Error(err)) - } else { - dataJSON = string(jsonData) - } - } - - _, err := db.Exec( - "INSERT INTO process_details (id, message_id, conversation_id, event_type, message, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - id, messageID, conversationID, eventType, message, dataJSON, time.Now(), - ) - if err != nil { - return fmt.Errorf("添加过程详情失败: %w", err) - } - - return nil -} - -// GetProcessDetails 获取消息的过程详情 -func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) { - rows, err := db.Query( - "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC", - messageID, - ) - if err != nil { - return nil, fmt.Errorf("查询过程详情失败: %w", err) - } - defer rows.Close() - - var details []ProcessDetail - for rows.Next() { - var detail ProcessDetail - var createdAt string - - if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil { - return nil, fmt.Errorf("扫描过程详情失败: %w", err) - } - - // 尝试多种时间格式解析 - var err error - detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err != nil { - detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err != nil { - detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - details = append(details, detail) - } - - return details, nil -} - -// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组) -func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) { - rows, err := db.Query( - "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC", - conversationID, - ) - if err != nil { - return nil, fmt.Errorf("查询过程详情失败: %w", err) - } - defer rows.Close() - - detailsMap := make(map[string][]ProcessDetail) - for rows.Next() { - var detail ProcessDetail - var createdAt string - - if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil { - return nil, fmt.Errorf("扫描过程详情失败: %w", err) - } - - // 尝试多种时间格式解析 - var err error - detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err != nil { - detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err != nil { - detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - detailsMap[detail.MessageID] = append(detailsMap[detail.MessageID], detail) - } - - return detailsMap, nil -} diff --git a/database/conversation_turn_test.go b/database/conversation_turn_test.go deleted file mode 100644 index 68743468..00000000 --- a/database/conversation_turn_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package database - -import ( - "testing" -) - -func TestTurnSliceRange(t *testing.T) { - mk := func(id, role string) Message { - return Message{ID: id, Role: role} - } - msgs := []Message{ - mk("u1", "user"), - mk("a1", "assistant"), - mk("u2", "user"), - mk("a2", "assistant"), - } - cases := []struct { - anchor string - start int - end int - }{ - {"u1", 0, 2}, - {"a1", 0, 2}, - {"u2", 2, 4}, - {"a2", 2, 4}, - } - for _, tc := range cases { - s, e, err := turnSliceRange(msgs, tc.anchor) - if err != nil { - t.Fatalf("anchor %s: %v", tc.anchor, err) - } - if s != tc.start || e != tc.end { - t.Fatalf("anchor %s: got [%d,%d) want [%d,%d)", tc.anchor, s, e, tc.start, tc.end) - } - } - if _, _, err := turnSliceRange(msgs, "nope"); err == nil { - t.Fatal("expected error for missing id") - } -} diff --git a/database/database.go b/database/database.go deleted file mode 100644 index 6321e1a5..00000000 --- a/database/database.go +++ /dev/null @@ -1,1108 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - _ "github.com/mattn/go-sqlite3" - "go.uber.org/zap" -) - -// configureDBPool 设置 SQLite 连接池参数,提升并发稳定性 -func configureDBPool(db *sql.DB) { - // SQLite 同一时间只允许一个写入者,限制连接数避免 "database is locked" 错误 - db.SetMaxOpenConns(25) - db.SetMaxIdleConns(5) - db.SetConnMaxLifetime(30 * time.Minute) -} - -// DB 数据库连接 -type DB struct { - *sql.DB - logger *zap.Logger - conversationArtifactsDir string -} - -// NewDB 创建数据库连接 -func NewDB(dbPath string, logger *zap.Logger) (*DB, error) { - db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") - if err != nil { - return nil, fmt.Errorf("打开数据库失败: %w", err) - } - - configureDBPool(db) - - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("连接数据库失败: %w", err) - } - - database := &DB{ - DB: db, - logger: logger, - } - // Keep conversation-scoped artifacts near database files, so cleanup can follow conversation lifecycle. - baseDir := filepath.Join(filepath.Dir(dbPath), "conversation_artifacts") - if mkErr := os.MkdirAll(baseDir, 0o755); mkErr == nil { - database.conversationArtifactsDir = baseDir - } else if logger != nil { - logger.Warn("创建 conversation artifacts 目录失败", zap.String("dir", baseDir), zap.Error(mkErr)) - } - - // 初始化表 - if err := database.initTables(); err != nil { - return nil, fmt.Errorf("初始化表失败: %w", err) - } - - return database, nil -} - -// initTables 初始化数据库表 -func (db *DB) initTables() error { - // 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库) - createConversationsTable := ` - CREATE TABLE IF NOT EXISTS conversations ( - id TEXT PRIMARY KEY, - title TEXT NOT NULL, - created_at DATETIME NOT NULL, - updated_at DATETIME NOT NULL, - last_react_input TEXT, - last_react_output TEXT - );` - - // 创建消息表 - createMessagesTable := ` - CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, - conversation_id TEXT NOT NULL, - role TEXT NOT NULL, - content TEXT NOT NULL, - mcp_execution_ids TEXT, - created_at DATETIME NOT NULL, - updated_at DATETIME NOT NULL, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - );` - - // 创建过程详情表 - createProcessDetailsTable := ` - CREATE TABLE IF NOT EXISTS process_details ( - id TEXT PRIMARY KEY, - message_id TEXT NOT NULL, - conversation_id TEXT NOT NULL, - event_type TEXT NOT NULL, - message TEXT, - data TEXT, - created_at DATETIME NOT NULL, - FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - );` - - // 创建工具执行记录表 - createToolExecutionsTable := ` - CREATE TABLE IF NOT EXISTS tool_executions ( - id TEXT PRIMARY KEY, - tool_name TEXT NOT NULL, - arguments TEXT NOT NULL, - status TEXT NOT NULL, - result TEXT, - error TEXT, - start_time DATETIME NOT NULL, - end_time DATETIME, - duration_ms INTEGER, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - );` - - // 创建工具统计表 - createToolStatsTable := ` - CREATE TABLE IF NOT EXISTS tool_stats ( - tool_name TEXT PRIMARY KEY, - total_calls INTEGER NOT NULL DEFAULT 0, - success_calls INTEGER NOT NULL DEFAULT 0, - failed_calls INTEGER NOT NULL DEFAULT 0, - last_call_time DATETIME, - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - );` - - // 创建Skills统计表 - createSkillStatsTable := ` - CREATE TABLE IF NOT EXISTS skill_stats ( - skill_name TEXT PRIMARY KEY, - total_calls INTEGER NOT NULL DEFAULT 0, - success_calls INTEGER NOT NULL DEFAULT 0, - failed_calls INTEGER NOT NULL DEFAULT 0, - last_call_time DATETIME, - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - );` - - // 创建攻击链节点表 - createAttackChainNodesTable := ` - CREATE TABLE IF NOT EXISTS attack_chain_nodes ( - id TEXT PRIMARY KEY, - conversation_id TEXT NOT NULL, - node_type TEXT NOT NULL, - node_name TEXT NOT NULL, - tool_execution_id TEXT, - metadata TEXT, - risk_score INTEGER DEFAULT 0, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, - FOREIGN KEY (tool_execution_id) REFERENCES tool_executions(id) ON DELETE SET NULL - );` - - // 创建攻击链边表 - createAttackChainEdgesTable := ` - CREATE TABLE IF NOT EXISTS attack_chain_edges ( - id TEXT PRIMARY KEY, - conversation_id TEXT NOT NULL, - source_node_id TEXT NOT NULL, - target_node_id TEXT NOT NULL, - edge_type TEXT NOT NULL, - weight INTEGER DEFAULT 1, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, - FOREIGN KEY (source_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE, - FOREIGN KEY (target_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE - );` - - // 创建知识检索日志表(保留在会话数据库中,因为有外键关联) - createKnowledgeRetrievalLogsTable := ` - CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - message_id TEXT, - query TEXT NOT NULL, - risk_type TEXT, - retrieved_items TEXT, - created_at DATETIME NOT NULL, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL, - FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL - );` - - // 创建对话分组表 - createConversationGroupsTable := ` - CREATE TABLE IF NOT EXISTS conversation_groups ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - icon TEXT, - created_at DATETIME NOT NULL, - updated_at DATETIME NOT NULL - );` - - // 创建对话分组映射表 - createConversationGroupMappingsTable := ` - CREATE TABLE IF NOT EXISTS conversation_group_mappings ( - id TEXT PRIMARY KEY, - conversation_id TEXT NOT NULL, - group_id TEXT NOT NULL, - created_at DATETIME NOT NULL, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, - FOREIGN KEY (group_id) REFERENCES conversation_groups(id) ON DELETE CASCADE, - UNIQUE(conversation_id, group_id) - );` - - // 机器人会话绑定表(用于跨重启保持「平台+租户+用户」到 conversation 的映射) - createRobotUserSessionsTable := ` - CREATE TABLE IF NOT EXISTS robot_user_sessions ( - session_key TEXT PRIMARY KEY, - conversation_id TEXT NOT NULL, - role_name TEXT NOT NULL DEFAULT '默认', - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - );` - - // 创建漏洞表 - createVulnerabilitiesTable := ` - CREATE TABLE IF NOT EXISTS vulnerabilities ( - id TEXT PRIMARY KEY, - conversation_id TEXT NOT NULL, - conversation_tag TEXT, - task_tag TEXT, - title TEXT NOT NULL, - description TEXT, - severity TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'open', - vulnerability_type TEXT, - target TEXT, - proof TEXT, - impact TEXT, - recommendation TEXT, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE - );` - - // 创建批量任务队列表 - createBatchTaskQueuesTable := ` - CREATE TABLE IF NOT EXISTS batch_task_queues ( - id TEXT PRIMARY KEY, - title TEXT, - role TEXT, - agent_mode TEXT NOT NULL DEFAULT 'single', - schedule_mode TEXT NOT NULL DEFAULT 'manual', - cron_expr TEXT, - next_run_at DATETIME, - schedule_enabled INTEGER NOT NULL DEFAULT 1, - last_schedule_trigger_at DATETIME, - last_schedule_error TEXT, - last_run_error TEXT, - status TEXT NOT NULL, - created_at DATETIME NOT NULL, - started_at DATETIME, - completed_at DATETIME, - current_index INTEGER NOT NULL DEFAULT 0 - );` - - // 创建批量任务表 - createBatchTasksTable := ` - CREATE TABLE IF NOT EXISTS batch_tasks ( - id TEXT PRIMARY KEY, - queue_id TEXT NOT NULL, - message TEXT NOT NULL, - conversation_id TEXT, - status TEXT NOT NULL, - started_at DATETIME, - completed_at DATETIME, - error TEXT, - result TEXT, - FOREIGN KEY (queue_id) REFERENCES batch_task_queues(id) ON DELETE CASCADE - );` - - // 创建 WebShell 连接表 - createWebshellConnectionsTable := ` - CREATE TABLE IF NOT EXISTS webshell_connections ( - id TEXT PRIMARY KEY, - url TEXT NOT NULL, - password TEXT NOT NULL DEFAULT '', - type TEXT NOT NULL DEFAULT 'php', - method TEXT NOT NULL DEFAULT 'post', - cmd_param TEXT NOT NULL DEFAULT '', - remark TEXT NOT NULL DEFAULT '', - encoding TEXT NOT NULL DEFAULT '', - os TEXT NOT NULL DEFAULT '', - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - );` - - // 创建 WebShell 连接扩展状态表(前端工作区/终端状态持久化) - createWebshellConnectionStatesTable := ` - CREATE TABLE IF NOT EXISTS webshell_connection_states ( - connection_id TEXT PRIMARY KEY, - state_json TEXT NOT NULL DEFAULT '{}', - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE - );` - - // ======================================================================== - // C2 模块(监听器 / 会话 / 任务 / 文件 / 事件 / Malleable Profile) - // ======================================================================== - createC2ListenersTable := ` - CREATE TABLE IF NOT EXISTS c2_listeners ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - type TEXT NOT NULL, - bind_host TEXT NOT NULL DEFAULT '127.0.0.1', - bind_port INTEGER NOT NULL, - profile_id TEXT, - encryption_key TEXT NOT NULL DEFAULT '', - implant_token TEXT NOT NULL DEFAULT '', - status TEXT NOT NULL DEFAULT 'stopped', - config_json TEXT NOT NULL DEFAULT '{}', - remark TEXT NOT NULL DEFAULT '', - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - started_at DATETIME, - last_error TEXT - );` - - createC2SessionsTable := ` - CREATE TABLE IF NOT EXISTS c2_sessions ( - id TEXT PRIMARY KEY, - listener_id TEXT NOT NULL, - implant_uuid TEXT NOT NULL UNIQUE, - hostname TEXT, - username TEXT, - os TEXT, - arch TEXT, - pid INTEGER DEFAULT 0, - process_name TEXT, - is_admin INTEGER DEFAULT 0, - internal_ip TEXT, - external_ip TEXT, - user_agent TEXT, - sleep_seconds INTEGER NOT NULL DEFAULT 5, - jitter_percent INTEGER NOT NULL DEFAULT 0, - status TEXT NOT NULL DEFAULT 'active', - first_seen_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - last_check_in DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - metadata_json TEXT DEFAULT '{}', - note TEXT NOT NULL DEFAULT '', - FOREIGN KEY (listener_id) REFERENCES c2_listeners(id) ON DELETE CASCADE - );` - - createC2TasksTable := ` - CREATE TABLE IF NOT EXISTS c2_tasks ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - task_type TEXT NOT NULL, - payload_json TEXT NOT NULL DEFAULT '{}', - status TEXT NOT NULL DEFAULT 'queued', - result_text TEXT, - result_blob_path TEXT, - error TEXT, - source TEXT NOT NULL DEFAULT 'manual', - conversation_id TEXT, - approval_status TEXT, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - sent_at DATETIME, - started_at DATETIME, - completed_at DATETIME, - duration_ms INTEGER DEFAULT 0, - FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE - );` - - createC2FilesTable := ` - CREATE TABLE IF NOT EXISTS c2_files ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - task_id TEXT, - direction TEXT NOT NULL, - remote_path TEXT NOT NULL, - local_path TEXT NOT NULL, - size_bytes INTEGER DEFAULT 0, - sha256 TEXT, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE - );` - - createC2EventsTable := ` - CREATE TABLE IF NOT EXISTS c2_events ( - id TEXT PRIMARY KEY, - level TEXT NOT NULL DEFAULT 'info', - category TEXT NOT NULL, - session_id TEXT, - task_id TEXT, - message TEXT NOT NULL, - data_json TEXT, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - );` - - createC2ProfilesTable := ` - CREATE TABLE IF NOT EXISTS c2_profiles ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - user_agent TEXT, - uris_json TEXT NOT NULL DEFAULT '[]', - request_headers_json TEXT, - response_headers_json TEXT, - body_template TEXT, - jitter_min_ms INTEGER DEFAULT 0, - jitter_max_ms INTEGER DEFAULT 0, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP - );` - - // 创建索引 - createIndexes := ` - CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); - CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at); - CREATE INDEX IF NOT EXISTS idx_process_details_message_id ON process_details(message_id); - CREATE INDEX IF NOT EXISTS idx_process_details_conversation_id ON process_details(conversation_id); - CREATE INDEX IF NOT EXISTS idx_tool_executions_tool_name ON tool_executions(tool_name); - CREATE INDEX IF NOT EXISTS idx_tool_executions_start_time ON tool_executions(start_time); - CREATE INDEX IF NOT EXISTS idx_tool_executions_status ON tool_executions(status); - CREATE INDEX IF NOT EXISTS idx_chain_nodes_conversation ON attack_chain_nodes(conversation_id); - CREATE INDEX IF NOT EXISTS idx_chain_edges_conversation ON attack_chain_edges(conversation_id); - CREATE INDEX IF NOT EXISTS idx_chain_edges_source ON attack_chain_edges(source_node_id); - CREATE INDEX IF NOT EXISTS idx_chain_edges_target ON attack_chain_edges(target_node_id); - CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id); - CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id); - CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at); - CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id); - CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id); - CREATE INDEX IF NOT EXISTS idx_robot_user_sessions_updated_at ON robot_user_sessions(updated_at); - CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned); - CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id); - CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag); - CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag); - CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); - CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); - CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); - CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id); - CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at); - CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title); - CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at); - CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at); - CREATE INDEX IF NOT EXISTS idx_c2_listeners_created_at ON c2_listeners(created_at); - CREATE INDEX IF NOT EXISTS idx_c2_listeners_status ON c2_listeners(status); - CREATE INDEX IF NOT EXISTS idx_c2_sessions_listener ON c2_sessions(listener_id); - CREATE INDEX IF NOT EXISTS idx_c2_sessions_status ON c2_sessions(status); - CREATE INDEX IF NOT EXISTS idx_c2_sessions_last_check_in ON c2_sessions(last_check_in); - CREATE INDEX IF NOT EXISTS idx_c2_tasks_session ON c2_tasks(session_id); - CREATE INDEX IF NOT EXISTS idx_c2_tasks_status ON c2_tasks(status); - CREATE INDEX IF NOT EXISTS idx_c2_tasks_created_at ON c2_tasks(created_at); - CREATE INDEX IF NOT EXISTS idx_c2_tasks_conversation ON c2_tasks(conversation_id); - CREATE INDEX IF NOT EXISTS idx_c2_files_session ON c2_files(session_id); - CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at); - CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category); - CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id); - ` - - if _, err := db.Exec(createConversationsTable); err != nil { - return fmt.Errorf("创建conversations表失败: %w", err) - } - - if _, err := db.Exec(createMessagesTable); err != nil { - return fmt.Errorf("创建messages表失败: %w", err) - } - - if _, err := db.Exec(createProcessDetailsTable); err != nil { - return fmt.Errorf("创建process_details表失败: %w", err) - } - - if _, err := db.Exec(createToolExecutionsTable); err != nil { - return fmt.Errorf("创建tool_executions表失败: %w", err) - } - - if _, err := db.Exec(createToolStatsTable); err != nil { - return fmt.Errorf("创建tool_stats表失败: %w", err) - } - - if _, err := db.Exec(createSkillStatsTable); err != nil { - return fmt.Errorf("创建skill_stats表失败: %w", err) - } - - if _, err := db.Exec(createAttackChainNodesTable); err != nil { - return fmt.Errorf("创建attack_chain_nodes表失败: %w", err) - } - - if _, err := db.Exec(createAttackChainEdgesTable); err != nil { - return fmt.Errorf("创建attack_chain_edges表失败: %w", err) - } - - if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil { - return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err) - } - - if _, err := db.Exec(createConversationGroupsTable); err != nil { - return fmt.Errorf("创建conversation_groups表失败: %w", err) - } - - if _, err := db.Exec(createConversationGroupMappingsTable); err != nil { - return fmt.Errorf("创建conversation_group_mappings表失败: %w", err) - } - if _, err := db.Exec(createRobotUserSessionsTable); err != nil { - return fmt.Errorf("创建robot_user_sessions表失败: %w", err) - } - - if _, err := db.Exec(createVulnerabilitiesTable); err != nil { - return fmt.Errorf("创建vulnerabilities表失败: %w", err) - } - - if _, err := db.Exec(createBatchTaskQueuesTable); err != nil { - return fmt.Errorf("创建batch_task_queues表失败: %w", err) - } - - if _, err := db.Exec(createBatchTasksTable); err != nil { - return fmt.Errorf("创建batch_tasks表失败: %w", err) - } - - if _, err := db.Exec(createWebshellConnectionsTable); err != nil { - return fmt.Errorf("创建webshell_connections表失败: %w", err) - } - - if _, err := db.Exec(createWebshellConnectionStatesTable); err != nil { - return fmt.Errorf("创建webshell_connection_states表失败: %w", err) - } - - for tableName, ddl := range map[string]string{ - "c2_listeners": createC2ListenersTable, - "c2_sessions": createC2SessionsTable, - "c2_tasks": createC2TasksTable, - "c2_files": createC2FilesTable, - "c2_events": createC2EventsTable, - "c2_profiles": createC2ProfilesTable, - } { - if _, err := db.Exec(ddl); err != nil { - return fmt.Errorf("创建%s表失败: %w", tableName, err) - } - } - - // 为已有表添加新字段(如果不存在)- 必须在创建索引之前 - if err := db.migrateConversationsTable(); err != nil { - db.logger.Warn("迁移conversations表失败", zap.Error(err)) - // 不返回错误,允许继续运行 - } - - if err := db.migrateMessagesTable(); err != nil { - db.logger.Warn("迁移messages表失败", zap.Error(err)) - // 不返回错误,允许继续运行 - } - - if err := db.migrateConversationGroupsTable(); err != nil { - db.logger.Warn("迁移conversation_groups表失败", zap.Error(err)) - // 不返回错误,允许继续运行 - } - - if err := db.migrateConversationGroupMappingsTable(); err != nil { - db.logger.Warn("迁移conversation_group_mappings表失败", zap.Error(err)) - // 不返回错误,允许继续运行 - } - - if err := db.migrateBatchTaskQueuesTable(); err != nil { - db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err)) - // 不返回错误,允许继续运行 - } - if err := db.migrateVulnerabilitiesTable(); err != nil { - db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err)) - // 不返回错误,允许继续运行 - } - - if err := db.migrateWebshellConnectionsTable(); err != nil { - db.logger.Warn("迁移webshell_connections表失败", zap.Error(err)) - // 不返回错误,允许继续运行 - } - - if _, err := db.Exec(createIndexes); err != nil { - return fmt.Errorf("创建索引失败: %w", err) - } - - db.logger.Info("数据库表初始化完成") - return nil -} - -// migrateMessagesTable 迁移 messages 表,补充 updated_at 字段。 -// 语义:updated_at 表示该条消息最后一次被写入/更新的时间(例如助手占位消息在任务结束时更新正文)。 -func (db *DB) migrateMessagesTable() error { - var count int - err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='updated_at'").Scan(&count) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN updated_at DATETIME"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - return fmt.Errorf("添加 messages.updated_at 字段失败: %w", addErr) - } - } - } else if count == 0 { - if _, err := db.Exec("ALTER TABLE messages ADD COLUMN updated_at DATETIME"); err != nil { - errMsg := strings.ToLower(err.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - return fmt.Errorf("添加 messages.updated_at 字段失败: %w", err) - } - } - } - - // 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。 - _, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''") - - // reasoning_content:DeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放 - var rcColCount int - errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount) - if errRC != nil { - if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr) - } - } - } else if rcColCount == 0 { - if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil { - errMsg := strings.ToLower(err.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err) - } - } - } - return nil -} - -// migrateConversationsTable 迁移conversations表,添加新字段 -func (db *DB) migrateConversationsTable() error { - // 检查last_react_input字段是否存在 - var count int - err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_input'").Scan(&count) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); addErr != nil { - // 如果字段已存在,忽略错误(SQLite错误信息可能不同) - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加last_react_input字段失败", zap.Error(addErr)) - } - } - } else if count == 0 { - // 字段不存在,添加它 - if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); err != nil { - db.logger.Warn("添加last_react_input字段失败", zap.Error(err)) - } - } - - // 检查last_react_output字段是否存在 - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_output'").Scan(&count) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); addErr != nil { - // 如果字段已存在,忽略错误 - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加last_react_output字段失败", zap.Error(addErr)) - } - } - } else if count == 0 { - // 字段不存在,添加它 - if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); err != nil { - db.logger.Warn("添加last_react_output字段失败", zap.Error(err)) - } - } - - // 检查pinned字段是否存在 - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='pinned'").Scan(&count) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil { - // 如果字段已存在,忽略错误 - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加pinned字段失败", zap.Error(addErr)) - } - } - } else if count == 0 { - // 字段不存在,添加它 - if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil { - db.logger.Warn("添加pinned字段失败", zap.Error(err)) - } - } - - // 检查 webshell_connection_id 字段是否存在(WebShell AI 助手对话关联) - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='webshell_connection_id'").Scan(&count) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(addErr)) - } - } - } else if count == 0 { - if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); err != nil { - db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(err)) - } - } - - return nil -} - -// migrateConversationGroupsTable 迁移conversation_groups表,添加新字段 -func (db *DB) migrateConversationGroupsTable() error { - // 检查pinned字段是否存在 - var count int - err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_groups') WHERE name='pinned'").Scan(&count) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil { - // 如果字段已存在,忽略错误 - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加pinned字段失败", zap.Error(addErr)) - } - } - } else if count == 0 { - // 字段不存在,添加它 - if _, err := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil { - db.logger.Warn("添加pinned字段失败", zap.Error(err)) - } - } - - return nil -} - -// migrateConversationGroupMappingsTable 迁移conversation_group_mappings表,添加新字段 -func (db *DB) migrateConversationGroupMappingsTable() error { - // 检查pinned字段是否存在 - var count int - err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_group_mappings') WHERE name='pinned'").Scan(&count) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil { - // 如果字段已存在,忽略错误 - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加pinned字段失败", zap.Error(addErr)) - } - } - } else if count == 0 { - // 字段不存在,添加它 - if _, err := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil { - db.logger.Warn("添加pinned字段失败", zap.Error(err)) - } - } - - return nil -} - -// migrateBatchTaskQueuesTable 迁移batch_task_queues表,补充新字段 -func (db *DB) migrateBatchTaskQueuesTable() error { - // 检查title字段是否存在 - var count int - err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='title'").Scan(&count) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); addErr != nil { - // 如果字段已存在,忽略错误 - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加title字段失败", zap.Error(addErr)) - } - } - } else if count == 0 { - // 字段不存在,添加它 - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); err != nil { - db.logger.Warn("添加title字段失败", zap.Error(err)) - } - } - - // 检查role字段是否存在 - var roleCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='role'").Scan(&roleCount) - if err != nil { - // 如果查询失败,尝试添加字段 - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); addErr != nil { - // 如果字段已存在,忽略错误 - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加role字段失败", zap.Error(addErr)) - } - } - } else if roleCount == 0 { - // 字段不存在,添加它 - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); err != nil { - db.logger.Warn("添加role字段失败", zap.Error(err)) - } - } - - // 检查agent_mode字段是否存在 - var agentModeCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr)) - } - } - } else if agentModeCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); err != nil { - db.logger.Warn("添加agent_mode字段失败", zap.Error(err)) - } - } - - // 检查schedule_mode字段是否存在 - var scheduleModeCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_mode'").Scan(&scheduleModeCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加schedule_mode字段失败", zap.Error(addErr)) - } - } - } else if scheduleModeCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); err != nil { - db.logger.Warn("添加schedule_mode字段失败", zap.Error(err)) - } - } - - // 检查cron_expr字段是否存在 - var cronExprCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='cron_expr'").Scan(&cronExprCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加cron_expr字段失败", zap.Error(addErr)) - } - } - } else if cronExprCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); err != nil { - db.logger.Warn("添加cron_expr字段失败", zap.Error(err)) - } - } - - // 检查next_run_at字段是否存在 - var nextRunAtCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='next_run_at'").Scan(&nextRunAtCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加next_run_at字段失败", zap.Error(addErr)) - } - } - } else if nextRunAtCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); err != nil { - db.logger.Warn("添加next_run_at字段失败", zap.Error(err)) - } - } - - // schedule_enabled:0=暂停 Cron 自动调度,1=允许(手工执行不受影响) - var scheduleEnCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_enabled'").Scan(&scheduleEnCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加schedule_enabled字段失败", zap.Error(addErr)) - } - } - } else if scheduleEnCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); err != nil { - db.logger.Warn("添加schedule_enabled字段失败", zap.Error(err)) - } - } - - var lastTrigCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_trigger_at'").Scan(&lastTrigCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(addErr)) - } - } - } else if lastTrigCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); err != nil { - db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(err)) - } - } - - var lastSchedErrCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_error'").Scan(&lastSchedErrCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加last_schedule_error字段失败", zap.Error(addErr)) - } - } - } else if lastSchedErrCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); err != nil { - db.logger.Warn("添加last_schedule_error字段失败", zap.Error(err)) - } - } - - var lastRunErrCount int - err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_run_error'").Scan(&lastRunErrCount) - if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加last_run_error字段失败", zap.Error(addErr)) - } - } - } else if lastRunErrCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); err != nil { - db.logger.Warn("添加last_run_error字段失败", zap.Error(err)) - } - } - - return nil -} - -// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段 -func (db *DB) migrateVulnerabilitiesTable() error { - columns := []struct { - name string - stmt string - }{ - {name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"}, - {name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"}, - } - - for _, col := range columns { - var count int - err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('vulnerabilities') WHERE name=?", col.name).Scan(&count) - if err != nil { - if _, addErr := db.Exec(col.stmt); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr)) - } - } - continue - } - if count == 0 { - if _, addErr := db.Exec(col.stmt); addErr != nil { - db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr)) - } - } - } - return nil -} - -// migrateWebshellConnectionsTable 迁移 webshell_connections 表,补充新字段 -func (db *DB) migrateWebshellConnectionsTable() error { - columns := []struct { - name string - stmt string - }{ - {name: "encoding", stmt: "ALTER TABLE webshell_connections ADD COLUMN encoding TEXT NOT NULL DEFAULT ''"}, - {name: "os", stmt: "ALTER TABLE webshell_connections ADD COLUMN os TEXT NOT NULL DEFAULT ''"}, - } - - for _, col := range columns { - var count int - err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('webshell_connections') WHERE name=?", col.name).Scan(&count) - if err != nil { - if _, addErr := db.Exec(col.stmt); addErr != nil { - errMsg := strings.ToLower(addErr.Error()) - if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { - db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr)) - } - } - continue - } - if count == 0 { - if _, addErr := db.Exec(col.stmt); addErr != nil { - db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr)) - } - } - } - return nil -} - -// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) -func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { - sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") - if err != nil { - return nil, fmt.Errorf("打开知识库数据库失败: %w", err) - } - - configureDBPool(sqlDB) - - if err := sqlDB.Ping(); err != nil { - return nil, fmt.Errorf("连接知识库数据库失败: %w", err) - } - - database := &DB{ - DB: sqlDB, - logger: logger, - } - - // 初始化知识库表 - if err := database.initKnowledgeTables(); err != nil { - return nil, fmt.Errorf("初始化知识库表失败: %w", err) - } - - return database, nil -} - -// initKnowledgeTables 初始化知识库数据库表(只包含知识库相关的表) -func (db *DB) initKnowledgeTables() error { - // 创建知识库项表 - createKnowledgeBaseItemsTable := ` - CREATE TABLE IF NOT EXISTS knowledge_base_items ( - id TEXT PRIMARY KEY, - category TEXT NOT NULL, - title TEXT NOT NULL, - file_path TEXT NOT NULL, - content TEXT, - created_at DATETIME NOT NULL, - updated_at DATETIME NOT NULL - );` - - // 创建知识库向量表 - createKnowledgeEmbeddingsTable := ` - CREATE TABLE IF NOT EXISTS knowledge_embeddings ( - id TEXT PRIMARY KEY, - item_id TEXT NOT NULL, - chunk_index INTEGER NOT NULL, - chunk_text TEXT NOT NULL, - embedding TEXT NOT NULL, - sub_indexes TEXT NOT NULL DEFAULT '', - embedding_model TEXT NOT NULL DEFAULT '', - embedding_dim INTEGER NOT NULL DEFAULT 0, - created_at DATETIME NOT NULL, - FOREIGN KEY (item_id) REFERENCES knowledge_base_items(id) ON DELETE CASCADE - );` - - // 创建知识检索日志表(在独立知识库数据库中,不使用外键约束,因为conversations和messages表可能不在这个数据库中) - createKnowledgeRetrievalLogsTable := ` - CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs ( - id TEXT PRIMARY KEY, - conversation_id TEXT, - message_id TEXT, - query TEXT NOT NULL, - risk_type TEXT, - retrieved_items TEXT, - created_at DATETIME NOT NULL - );` - - // 创建索引 - createIndexes := ` - CREATE INDEX IF NOT EXISTS idx_knowledge_items_category ON knowledge_base_items(category); - CREATE INDEX IF NOT EXISTS idx_knowledge_embeddings_item_id ON knowledge_embeddings(item_id); - CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id); - CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id); - CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at); - ` - - if _, err := db.Exec(createKnowledgeBaseItemsTable); err != nil { - return fmt.Errorf("创建knowledge_base_items表失败: %w", err) - } - - if _, err := db.Exec(createKnowledgeEmbeddingsTable); err != nil { - return fmt.Errorf("创建knowledge_embeddings表失败: %w", err) - } - - if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil { - return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err) - } - - if _, err := db.Exec(createIndexes); err != nil { - return fmt.Errorf("创建索引失败: %w", err) - } - - if err := db.migrateKnowledgeEmbeddingsColumns(); err != nil { - return fmt.Errorf("迁移 knowledge_embeddings 列失败: %w", err) - } - - db.logger.Info("知识库数据库表初始化完成") - return nil -} - -// migrateKnowledgeEmbeddingsColumns 为已有库补充 sub_indexes、embedding_model、embedding_dim。 -func (db *DB) migrateKnowledgeEmbeddingsColumns() error { - var n int - if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil { - return err - } - if n == 0 { - return nil - } - migrations := []struct { - col string - stmt string - }{ - {"sub_indexes", `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`}, - {"embedding_model", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`}, - {"embedding_dim", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`}, - } - for _, m := range migrations { - var colCount int - q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?` - if err := db.QueryRow(q, m.col).Scan(&colCount); err != nil { - return err - } - if colCount > 0 { - continue - } - if _, err := db.Exec(m.stmt); err != nil { - return err - } - } - return nil -} - -// Close 关闭数据库连接 -func (db *DB) Close() error { - return db.DB.Close() -} diff --git a/database/group.go b/database/group.go deleted file mode 100644 index a3d32106..00000000 --- a/database/group.go +++ /dev/null @@ -1,449 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "time" - - "github.com/google/uuid" -) - -// ConversationGroup 对话分组 -type ConversationGroup struct { - ID string `json:"id"` - Name string `json:"name"` - Icon string `json:"icon"` - Pinned bool `json:"pinned"` - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` -} - -// GroupExistsByName 检查分组名称是否已存在 -func (db *DB) GroupExistsByName(name string, excludeID string) (bool, error) { - var count int - var err error - - if excludeID != "" { - err = db.QueryRow( - "SELECT COUNT(*) FROM conversation_groups WHERE name = ? AND id != ?", - name, excludeID, - ).Scan(&count) - } else { - err = db.QueryRow( - "SELECT COUNT(*) FROM conversation_groups WHERE name = ?", - name, - ).Scan(&count) - } - - if err != nil { - return false, fmt.Errorf("检查分组名称失败: %w", err) - } - - return count > 0, nil -} - -// CreateGroup 创建分组 -func (db *DB) CreateGroup(name, icon string) (*ConversationGroup, error) { - // 检查名称是否已存在 - exists, err := db.GroupExistsByName(name, "") - if err != nil { - return nil, err - } - if exists { - return nil, fmt.Errorf("分组名称已存在") - } - - id := uuid.New().String() - now := time.Now() - - if icon == "" { - icon = "📁" - } - - _, err = db.Exec( - "INSERT INTO conversation_groups (id, name, icon, pinned, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)", - id, name, icon, 0, now, now, - ) - if err != nil { - return nil, fmt.Errorf("创建分组失败: %w", err) - } - - return &ConversationGroup{ - ID: id, - Name: name, - Icon: icon, - Pinned: false, - CreatedAt: now, - UpdatedAt: now, - }, nil -} - -// ListGroups 列出所有分组 -func (db *DB) ListGroups() ([]*ConversationGroup, error) { - rows, err := db.Query( - "SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups ORDER BY COALESCE(pinned, 0) DESC, created_at ASC", - ) - if err != nil { - return nil, fmt.Errorf("查询分组列表失败: %w", err) - } - defer rows.Close() - - var groups []*ConversationGroup - for rows.Next() { - var group ConversationGroup - var createdAt, updatedAt string - var pinned int - - if err := rows.Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt); err != nil { - return nil, fmt.Errorf("扫描分组失败: %w", err) - } - - group.Pinned = pinned != 0 - - // 尝试多种时间格式解析 - var err1, err2 error - group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err1 != nil { - group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err1 != nil { - group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) - if err2 != nil { - group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) - } - if err2 != nil { - group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - - groups = append(groups, &group) - } - - return groups, nil -} - -// GetGroup 获取分组 -func (db *DB) GetGroup(id string) (*ConversationGroup, error) { - var group ConversationGroup - var createdAt, updatedAt string - var pinned int - - err := db.QueryRow( - "SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups WHERE id = ?", - id, - ).Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("分组不存在") - } - return nil, fmt.Errorf("查询分组失败: %w", err) - } - - // 尝试多种时间格式解析 - var err1, err2 error - group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err1 != nil { - group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err1 != nil { - group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) - if err2 != nil { - group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) - } - if err2 != nil { - group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - - group.Pinned = pinned != 0 - - return &group, nil -} - -// UpdateGroup 更新分组 -func (db *DB) UpdateGroup(id, name, icon string) error { - // 检查名称是否已存在(排除当前分组) - exists, err := db.GroupExistsByName(name, id) - if err != nil { - return err - } - if exists { - return fmt.Errorf("分组名称已存在") - } - - _, err = db.Exec( - "UPDATE conversation_groups SET name = ?, icon = ?, updated_at = ? WHERE id = ?", - name, icon, time.Now(), id, - ) - if err != nil { - return fmt.Errorf("更新分组失败: %w", err) - } - return nil -} - -// DeleteGroup 删除分组 -func (db *DB) DeleteGroup(id string) error { - _, err := db.Exec("DELETE FROM conversation_groups WHERE id = ?", id) - if err != nil { - return fmt.Errorf("删除分组失败: %w", err) - } - return nil -} - -// AddConversationToGroup 将对话添加到分组 -// 注意:一个对话只能属于一个分组,所以在添加新分组之前,会先删除该对话的所有旧分组关联 -func (db *DB) AddConversationToGroup(conversationID, groupID string) error { - // 先删除该对话的所有旧分组关联,确保一个对话只属于一个分组 - _, err := db.Exec( - "DELETE FROM conversation_group_mappings WHERE conversation_id = ?", - conversationID, - ) - if err != nil { - return fmt.Errorf("删除对话旧分组关联失败: %w", err) - } - - // 然后插入新的分组关联 - id := uuid.New().String() - _, err = db.Exec( - "INSERT INTO conversation_group_mappings (id, conversation_id, group_id, created_at) VALUES (?, ?, ?, ?)", - id, conversationID, groupID, time.Now(), - ) - if err != nil { - return fmt.Errorf("添加对话到分组失败: %w", err) - } - return nil -} - -// RemoveConversationFromGroup 从分组中移除对话 -func (db *DB) RemoveConversationFromGroup(conversationID, groupID string) error { - _, err := db.Exec( - "DELETE FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?", - conversationID, groupID, - ) - if err != nil { - return fmt.Errorf("从分组中移除对话失败: %w", err) - } - return nil -} - -// GetConversationsByGroup 获取分组中的所有对话 -func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) { - rows, err := db.Query( - `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned - FROM conversations c - INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id - WHERE cgm.group_id = ? - ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC`, - groupID, - ) - if err != nil { - return nil, fmt.Errorf("查询分组对话失败: %w", err) - } - defer rows.Close() - - var conversations []*Conversation - for rows.Next() { - var conv Conversation - var createdAt, updatedAt string - var pinned int - var groupPinned int - - if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil { - return nil, fmt.Errorf("扫描对话失败: %w", err) - } - - // 尝试多种时间格式解析 - var err1, err2 error - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err1 != nil { - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err1 != nil { - conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) - if err2 != nil { - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) - } - if err2 != nil { - conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - - conv.Pinned = pinned != 0 - - conversations = append(conversations, &conv) - } - - return conversations, nil -} - -// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配) -func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) { - // 构建SQL查询,支持按标题和消息内容搜索 - // 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复 - query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned - FROM conversations c - INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id - WHERE cgm.group_id = ?` - - args := []interface{}{groupID} - - // 如果有搜索关键词,添加标题和消息内容搜索条件 - if searchQuery != "" { - searchPattern := "%" + searchQuery + "%" - // 搜索标题或消息内容 - // 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题) - query += ` AND ( - LOWER(c.title) LIKE LOWER(?) - OR EXISTS ( - SELECT 1 FROM messages m - WHERE m.conversation_id = c.id - AND LOWER(m.content) LIKE LOWER(?) - ) - )` - args = append(args, searchPattern, searchPattern) - } - - query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC" - - rows, err := db.Query(query, args...) - if err != nil { - return nil, fmt.Errorf("搜索分组对话失败: %w", err) - } - defer rows.Close() - - var conversations []*Conversation - for rows.Next() { - var conv Conversation - var createdAt, updatedAt string - var pinned int - var groupPinned int - - if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil { - return nil, fmt.Errorf("扫描对话失败: %w", err) - } - - // 尝试多种时间格式解析 - var err1, err2 error - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) - if err1 != nil { - conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) - } - if err1 != nil { - conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) - } - - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) - if err2 != nil { - conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) - } - if err2 != nil { - conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - - conv.Pinned = pinned != 0 - - conversations = append(conversations, &conv) - } - - return conversations, nil -} - -// GetGroupByConversation 获取对话所属的分组 -func (db *DB) GetGroupByConversation(conversationID string) (string, error) { - var groupID string - err := db.QueryRow( - "SELECT group_id FROM conversation_group_mappings WHERE conversation_id = ? LIMIT 1", - conversationID, - ).Scan(&groupID) - if err != nil { - if err == sql.ErrNoRows { - return "", nil // 没有分组 - } - return "", fmt.Errorf("查询对话分组失败: %w", err) - } - return groupID, nil -} - -// UpdateConversationPinned 更新对话置顶状态 -func (db *DB) UpdateConversationPinned(id string, pinned bool) error { - pinnedValue := 0 - if pinned { - pinnedValue = 1 - } - // 注意:不更新 updated_at,因为置顶操作不应该改变对话的更新时间 - _, err := db.Exec( - "UPDATE conversations SET pinned = ? WHERE id = ?", - pinnedValue, id, - ) - if err != nil { - return fmt.Errorf("更新对话置顶状态失败: %w", err) - } - return nil -} - -// UpdateGroupPinned 更新分组置顶状态 -func (db *DB) UpdateGroupPinned(id string, pinned bool) error { - pinnedValue := 0 - if pinned { - pinnedValue = 1 - } - _, err := db.Exec( - "UPDATE conversation_groups SET pinned = ?, updated_at = ? WHERE id = ?", - pinnedValue, time.Now(), id, - ) - if err != nil { - return fmt.Errorf("更新分组置顶状态失败: %w", err) - } - return nil -} - -// GroupMapping 分组映射关系 -type GroupMapping struct { - ConversationID string `json:"conversationId"` - GroupID string `json:"groupId"` -} - -// GetAllGroupMappings 批量获取所有分组映射(消除 N+1 查询) -func (db *DB) GetAllGroupMappings() ([]GroupMapping, error) { - rows, err := db.Query("SELECT conversation_id, group_id FROM conversation_group_mappings") - if err != nil { - return nil, fmt.Errorf("查询分组映射失败: %w", err) - } - defer rows.Close() - - var mappings []GroupMapping - for rows.Next() { - var m GroupMapping - if err := rows.Scan(&m.ConversationID, &m.GroupID); err != nil { - return nil, fmt.Errorf("扫描分组映射失败: %w", err) - } - mappings = append(mappings, m) - } - - if mappings == nil { - mappings = []GroupMapping{} - } - return mappings, nil -} - -// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态 -func (db *DB) UpdateConversationPinnedInGroup(conversationID, groupID string, pinned bool) error { - pinnedValue := 0 - if pinned { - pinnedValue = 1 - } - _, err := db.Exec( - "UPDATE conversation_group_mappings SET pinned = ? WHERE conversation_id = ? AND group_id = ?", - pinnedValue, conversationID, groupID, - ) - if err != nil { - return fmt.Errorf("更新分组对话置顶状态失败: %w", err) - } - return nil -} diff --git a/database/monitor.go b/database/monitor.go deleted file mode 100644 index bdfffb61..00000000 --- a/database/monitor.go +++ /dev/null @@ -1,537 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "strings" - "time" - - "cyberstrike-ai/internal/mcp" - - "go.uber.org/zap" -) - -// SaveToolExecution 保存工具执行记录 -func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error { - argsJSON, err := json.Marshal(exec.Arguments) - if err != nil { - db.logger.Warn("序列化执行参数失败", zap.Error(err)) - argsJSON = []byte("{}") - } - - var resultJSON sql.NullString - if exec.Result != nil { - resultBytes, err := json.Marshal(exec.Result) - if err != nil { - db.logger.Warn("序列化执行结果失败", zap.Error(err)) - } else { - resultJSON = sql.NullString{String: string(resultBytes), Valid: true} - } - } - - var errorText sql.NullString - if exec.Error != "" { - errorText = sql.NullString{String: exec.Error, Valid: true} - } - - var endTime sql.NullTime - if exec.EndTime != nil { - endTime = sql.NullTime{Time: *exec.EndTime, Valid: true} - } - - var durationMs sql.NullInt64 - if exec.Duration > 0 { - durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true} - } - - query := ` - INSERT OR REPLACE INTO tool_executions - (id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - - _, err = db.Exec(query, - exec.ID, - exec.ToolName, - string(argsJSON), - exec.Status, - resultJSON, - errorText, - exec.StartTime, - endTime, - durationMs, - time.Now(), - ) - - if err != nil { - db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID)) - return err - } - - return nil -} - -// CountToolExecutions 统计工具执行记录总数 -func (db *DB) CountToolExecutions(status, toolName string) (int, error) { - query := `SELECT COUNT(*) FROM tool_executions` - args := []interface{}{} - conditions := []string{} - if status != "" { - conditions = append(conditions, "status = ?") - args = append(args, status) - } - if toolName != "" { - // 支持部分匹配(模糊搜索),不区分大小写 - conditions = append(conditions, "LOWER(tool_name) LIKE ?") - args = append(args, "%"+strings.ToLower(toolName)+"%") - } - if len(conditions) > 0 { - query += ` WHERE ` + conditions[0] - for i := 1; i < len(conditions); i++ { - query += ` AND ` + conditions[i] - } - } - var count int - err := db.QueryRow(query, args...).Scan(&count) - if err != nil { - return 0, err - } - return count, nil -} - -// LoadToolExecutions 加载所有工具执行记录(支持分页) -func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) { - return db.LoadToolExecutionsWithPagination(0, 1000, "", "") -} - -// LoadToolExecutionsWithPagination 分页加载工具执行记录 -// limit: 最大返回记录数,0 表示使用默认值 1000 -// offset: 跳过的记录数,用于分页 -// status: 状态筛选,空字符串表示不过滤 -// toolName: 工具名称筛选,空字符串表示不过滤 -func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) { - if limit <= 0 { - limit = 1000 // 默认限制 - } - if limit > 10000 { - limit = 10000 // 最大限制,防止一次性加载过多数据 - } - - query := ` - SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms - FROM tool_executions - ` - args := []interface{}{} - conditions := []string{} - if status != "" { - conditions = append(conditions, "status = ?") - args = append(args, status) - } - if toolName != "" { - // 支持部分匹配(模糊搜索),不区分大小写 - conditions = append(conditions, "LOWER(tool_name) LIKE ?") - args = append(args, "%"+strings.ToLower(toolName)+"%") - } - if len(conditions) > 0 { - query += ` WHERE ` + conditions[0] - for i := 1; i < len(conditions); i++ { - query += ` AND ` + conditions[i] - } - } - query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?` - args = append(args, limit, offset) - - rows, err := db.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - var executions []*mcp.ToolExecution - for rows.Next() { - var exec mcp.ToolExecution - var argsJSON string - var resultJSON sql.NullString - var errorText sql.NullString - var endTime sql.NullTime - var durationMs sql.NullInt64 - - err := rows.Scan( - &exec.ID, - &exec.ToolName, - &argsJSON, - &exec.Status, - &resultJSON, - &errorText, - &exec.StartTime, - &endTime, - &durationMs, - ) - if err != nil { - db.logger.Warn("加载执行记录失败", zap.Error(err)) - continue - } - - // 解析参数 - if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { - db.logger.Warn("解析执行参数失败", zap.Error(err)) - exec.Arguments = make(map[string]interface{}) - } - - // 解析结果 - if resultJSON.Valid && resultJSON.String != "" { - var result mcp.ToolResult - if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { - db.logger.Warn("解析执行结果失败", zap.Error(err)) - } else { - exec.Result = &result - } - } - - // 设置错误 - if errorText.Valid { - exec.Error = errorText.String - } - - // 设置结束时间 - if endTime.Valid { - exec.EndTime = &endTime.Time - } - - // 设置持续时间 - if durationMs.Valid { - exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond - } - - executions = append(executions, &exec) - } - - return executions, nil -} - -// GetToolExecution 根据ID获取单条工具执行记录 -func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) { - query := ` - SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms - FROM tool_executions - WHERE id = ? - ` - - row := db.QueryRow(query, id) - - var exec mcp.ToolExecution - var argsJSON string - var resultJSON sql.NullString - var errorText sql.NullString - var endTime sql.NullTime - var durationMs sql.NullInt64 - - err := row.Scan( - &exec.ID, - &exec.ToolName, - &argsJSON, - &exec.Status, - &resultJSON, - &errorText, - &exec.StartTime, - &endTime, - &durationMs, - ) - if err != nil { - return nil, err - } - - if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { - db.logger.Warn("解析执行参数失败", zap.Error(err)) - exec.Arguments = make(map[string]interface{}) - } - - if resultJSON.Valid && resultJSON.String != "" { - var result mcp.ToolResult - if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { - db.logger.Warn("解析执行结果失败", zap.Error(err)) - } else { - exec.Result = &result - } - } - - if errorText.Valid { - exec.Error = errorText.String - } - - if endTime.Valid { - exec.EndTime = &endTime.Time - } - - if durationMs.Valid { - exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond - } - - return &exec, nil -} - -// DeleteToolExecution 删除工具执行记录 -func (db *DB) DeleteToolExecution(id string) error { - query := `DELETE FROM tool_executions WHERE id = ?` - _, err := db.Exec(query, id) - if err != nil { - db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id)) - return err - } - return nil -} - -// DeleteToolExecutions 批量删除工具执行记录 -func (db *DB) DeleteToolExecutions(ids []string) error { - if len(ids) == 0 { - return nil - } - - // 构建 IN 查询的占位符 - placeholders := make([]string, len(ids)) - args := make([]interface{}, len(ids)) - for i, id := range ids { - placeholders[i] = "?" - args[i] = id - } - - query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)` - _, err := db.Exec(query, args...) - if err != nil { - db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids))) - return err - } - return nil -} - -// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息) -func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) { - if len(ids) == 0 { - return []*mcp.ToolExecution{}, nil - } - - // 构建 IN 查询的占位符 - placeholders := make([]string, len(ids)) - args := make([]interface{}, len(ids)) - for i, id := range ids { - placeholders[i] = "?" - args[i] = id - } - - query := ` - SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms - FROM tool_executions - WHERE id IN (` + strings.Join(placeholders, ",") + `) - ` - - rows, err := db.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - var executions []*mcp.ToolExecution - for rows.Next() { - var exec mcp.ToolExecution - var argsJSON string - var resultJSON sql.NullString - var errorText sql.NullString - var endTime sql.NullTime - var durationMs sql.NullInt64 - - err := rows.Scan( - &exec.ID, - &exec.ToolName, - &argsJSON, - &exec.Status, - &resultJSON, - &errorText, - &exec.StartTime, - &endTime, - &durationMs, - ) - if err != nil { - db.logger.Warn("加载执行记录失败", zap.Error(err)) - continue - } - - // 解析参数 - if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { - db.logger.Warn("解析执行参数失败", zap.Error(err)) - exec.Arguments = make(map[string]interface{}) - } - - // 解析结果 - if resultJSON.Valid && resultJSON.String != "" { - var result mcp.ToolResult - if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { - db.logger.Warn("解析执行结果失败", zap.Error(err)) - } else { - exec.Result = &result - } - } - - // 设置错误 - if errorText.Valid { - exec.Error = errorText.String - } - - // 设置结束时间 - if endTime.Valid { - exec.EndTime = &endTime.Time - } - - // 设置持续时间 - if durationMs.Valid { - exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond - } - - executions = append(executions, &exec) - } - - return executions, nil -} - -// SaveToolStats 保存工具统计信息 -func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error { - var lastCallTime sql.NullTime - if stats.LastCallTime != nil { - lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true} - } - - query := ` - INSERT OR REPLACE INTO tool_stats - (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - ` - - _, err := db.Exec(query, - toolName, - stats.TotalCalls, - stats.SuccessCalls, - stats.FailedCalls, - lastCallTime, - time.Now(), - ) - - if err != nil { - db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) - return err - } - - return nil -} - -// LoadToolStats 加载所有工具统计信息 -func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) { - query := ` - SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time - FROM tool_stats - ` - - rows, err := db.Query(query) - if err != nil { - return nil, err - } - defer rows.Close() - - stats := make(map[string]*mcp.ToolStats) - for rows.Next() { - var stat mcp.ToolStats - var lastCallTime sql.NullTime - - err := rows.Scan( - &stat.ToolName, - &stat.TotalCalls, - &stat.SuccessCalls, - &stat.FailedCalls, - &lastCallTime, - ) - if err != nil { - db.logger.Warn("加载统计信息失败", zap.Error(err)) - continue - } - - if lastCallTime.Valid { - stat.LastCallTime = &lastCallTime.Time - } - - stats[stat.ToolName] = &stat - } - - return stats, nil -} - -// UpdateToolStats 更新工具统计信息(累加模式) -func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error { - var lastCallTimeSQL sql.NullTime - if lastCallTime != nil { - lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true} - } - - query := ` - INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(tool_name) DO UPDATE SET - total_calls = total_calls + ?, - success_calls = success_calls + ?, - failed_calls = failed_calls + ?, - last_call_time = COALESCE(?, last_call_time), - updated_at = ? - ` - - _, err := db.Exec(query, - toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), - totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), - ) - - if err != nil { - db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) - return err - } - - return nil -} - -// DecreaseToolStats 减少工具统计信息(用于删除执行记录时) -// 如果统计信息变为0,则删除该统计记录 -func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error { - // 先更新统计信息 - query := ` - UPDATE tool_stats SET - total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END, - success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END, - failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END, - updated_at = ? - WHERE tool_name = ? - ` - - _, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName) - if err != nil { - db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) - return err - } - - // 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录 - checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?` - var newTotalCalls int - err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls) - if err != nil { - // 如果查询失败(记录不存在),直接返回 - return nil - } - - // 如果 total_calls 为 0,删除该统计记录 - if newTotalCalls == 0 { - deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?` - _, err = db.Exec(deleteQuery, toolName) - if err != nil { - db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName)) - // 不返回错误,因为主要操作(更新统计)已成功 - } else { - db.logger.Info("已删除零统计记录", zap.String("toolName", toolName)) - } - } - - return nil -} diff --git a/database/robot_session.go b/database/robot_session.go deleted file mode 100644 index b7631260..00000000 --- a/database/robot_session.go +++ /dev/null @@ -1,84 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "strings" - "time" -) - -// RobotSessionBinding 机器人会话绑定信息。 -type RobotSessionBinding struct { - SessionKey string - ConversationID string - RoleName string - UpdatedAt time.Time -} - -// GetRobotSessionBinding 按 session_key 获取机器人会话绑定。 -func (db *DB) GetRobotSessionBinding(sessionKey string) (*RobotSessionBinding, error) { - sessionKey = strings.TrimSpace(sessionKey) - if sessionKey == "" { - return nil, nil - } - var b RobotSessionBinding - var updatedAt string - err := db.QueryRow( - "SELECT session_key, conversation_id, role_name, updated_at FROM robot_user_sessions WHERE session_key = ?", - sessionKey, - ).Scan(&b.SessionKey, &b.ConversationID, &b.RoleName, &updatedAt) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, fmt.Errorf("查询机器人会话绑定失败: %w", err) - } - if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil { - b.UpdatedAt = t - } else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil { - b.UpdatedAt = t - } else { - b.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) - } - if strings.TrimSpace(b.RoleName) == "" { - b.RoleName = "默认" - } - return &b, nil -} - -// UpsertRobotSessionBinding 写入或更新机器人会话绑定(包含角色)。 -func (db *DB) UpsertRobotSessionBinding(sessionKey, conversationID, roleName string) error { - sessionKey = strings.TrimSpace(sessionKey) - conversationID = strings.TrimSpace(conversationID) - roleName = strings.TrimSpace(roleName) - if sessionKey == "" || conversationID == "" { - return nil - } - if roleName == "" { - roleName = "默认" - } - _, err := db.Exec(` - INSERT INTO robot_user_sessions (session_key, conversation_id, role_name, updated_at) - VALUES (?, ?, ?, ?) - ON CONFLICT(session_key) DO UPDATE SET - conversation_id = excluded.conversation_id, - role_name = excluded.role_name, - updated_at = excluded.updated_at - `, sessionKey, conversationID, roleName, time.Now()) - if err != nil { - return fmt.Errorf("写入机器人会话绑定失败: %w", err) - } - return nil -} - -// DeleteRobotSessionBinding 删除机器人会话绑定。 -func (db *DB) DeleteRobotSessionBinding(sessionKey string) error { - sessionKey = strings.TrimSpace(sessionKey) - if sessionKey == "" { - return nil - } - if _, err := db.Exec("DELETE FROM robot_user_sessions WHERE session_key = ?", sessionKey); err != nil { - return fmt.Errorf("删除机器人会话绑定失败: %w", err) - } - return nil -} diff --git a/database/skill_stats.go b/database/skill_stats.go deleted file mode 100644 index 24e15585..00000000 --- a/database/skill_stats.go +++ /dev/null @@ -1,142 +0,0 @@ -package database - -import ( - "database/sql" - "time" - - "go.uber.org/zap" -) - -// SkillStats Skills统计信息 -type SkillStats struct { - SkillName string - TotalCalls int - SuccessCalls int - FailedCalls int - LastCallTime *time.Time -} - -// SaveSkillStats 保存Skills统计信息 -func (db *DB) SaveSkillStats(skillName string, stats *SkillStats) error { - var lastCallTime sql.NullTime - if stats.LastCallTime != nil { - lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true} - } - - query := ` - INSERT OR REPLACE INTO skill_stats - (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - ` - - _, err := db.Exec(query, - skillName, - stats.TotalCalls, - stats.SuccessCalls, - stats.FailedCalls, - lastCallTime, - time.Now(), - ) - - if err != nil { - db.logger.Error("保存Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName)) - return err - } - - return nil -} - -// LoadSkillStats 加载所有Skills统计信息 -func (db *DB) LoadSkillStats() (map[string]*SkillStats, error) { - query := ` - SELECT skill_name, total_calls, success_calls, failed_calls, last_call_time - FROM skill_stats - ` - - rows, err := db.Query(query) - if err != nil { - return nil, err - } - defer rows.Close() - - stats := make(map[string]*SkillStats) - for rows.Next() { - var stat SkillStats - var lastCallTime sql.NullTime - - err := rows.Scan( - &stat.SkillName, - &stat.TotalCalls, - &stat.SuccessCalls, - &stat.FailedCalls, - &lastCallTime, - ) - if err != nil { - db.logger.Warn("加载Skills统计信息失败", zap.Error(err)) - continue - } - - if lastCallTime.Valid { - stat.LastCallTime = &lastCallTime.Time - } - - stats[stat.SkillName] = &stat - } - - return stats, nil -} - -// UpdateSkillStats 更新Skills统计信息(累加模式) -func (db *DB) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error { - var lastCallTimeSQL sql.NullTime - if lastCallTime != nil { - lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true} - } - - query := ` - INSERT INTO skill_stats (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(skill_name) DO UPDATE SET - total_calls = total_calls + ?, - success_calls = success_calls + ?, - failed_calls = failed_calls + ?, - last_call_time = COALESCE(?, last_call_time), - updated_at = ? - ` - - _, err := db.Exec(query, - skillName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), - totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), - ) - - if err != nil { - db.logger.Error("更新Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName)) - return err - } - - return nil -} - -// ClearSkillStats 清空所有Skills统计信息 -func (db *DB) ClearSkillStats() error { - query := `DELETE FROM skill_stats` - _, err := db.Exec(query) - if err != nil { - db.logger.Error("清空Skills统计信息失败", zap.Error(err)) - return err - } - db.logger.Info("已清空所有Skills统计信息") - return nil -} - -// ClearSkillStatsByName 清空指定skill的统计信息 -func (db *DB) ClearSkillStatsByName(skillName string) error { - query := `DELETE FROM skill_stats WHERE skill_name = ?` - _, err := db.Exec(query, skillName) - if err != nil { - db.logger.Error("清空指定skill统计信息失败", zap.Error(err), zap.String("skillName", skillName)) - return err - } - db.logger.Info("已清空指定skill统计信息", zap.String("skillName", skillName)) - return nil -} diff --git a/database/vulnerability.go b/database/vulnerability.go deleted file mode 100644 index 1a584bf6..00000000 --- a/database/vulnerability.go +++ /dev/null @@ -1,369 +0,0 @@ -package database - -import ( - "database/sql" - "fmt" - "time" - - "github.com/google/uuid" - "go.uber.org/zap" -) - -// Vulnerability 漏洞 -type Vulnerability struct { - ID string `json:"id"` - ConversationID string `json:"conversation_id"` - ConversationTag string `json:"conversation_tag,omitempty"` - TaskTag string `json:"task_tag,omitempty"` - TaskID string `json:"task_id,omitempty"` - TaskQueueID string `json:"task_queue_id,omitempty"` - Title string `json:"title"` - Description string `json:"description"` - Severity string `json:"severity"` // critical, high, medium, low, info - Status string `json:"status"` // open, confirmed, fixed, false_positive - Type string `json:"type"` - Target string `json:"target"` - Proof string `json:"proof"` - Impact string `json:"impact"` - Recommendation string `json:"recommendation"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// CreateVulnerability 创建漏洞 -func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { - if vuln.ID == "" { - vuln.ID = uuid.New().String() - } - if vuln.Status == "" { - vuln.Status = "open" - } - now := time.Now() - if vuln.CreatedAt.IsZero() { - vuln.CreatedAt = now - } - vuln.UpdatedAt = now - - query := ` - INSERT INTO vulnerabilities ( - id, conversation_id, conversation_tag, task_tag, title, description, severity, status, - vulnerability_type, target, proof, impact, recommendation, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - - _, err := db.Exec( - query, - vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, - vuln.Severity, vuln.Status, vuln.Type, vuln.Target, - vuln.Proof, vuln.Impact, vuln.Recommendation, - vuln.CreatedAt, vuln.UpdatedAt, - ) - if err != nil { - return nil, fmt.Errorf("创建漏洞失败: %w", err) - } - - return vuln, nil -} - -// GetVulnerability 获取漏洞 -func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { - var vuln Vulnerability - query := ` - SELECT id, conversation_id, title, description, severity, status, - conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, - COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, - COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, - created_at, updated_at - FROM vulnerabilities - WHERE id = ? - ` - - err := db.QueryRow(query, id).Scan( - &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, - &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, - &vuln.Proof, &vuln.Impact, &vuln.Recommendation, - &vuln.TaskID, &vuln.TaskQueueID, - &vuln.CreatedAt, &vuln.UpdatedAt, - ) - if err != nil { - if err == sql.ErrNoRows { - return nil, fmt.Errorf("漏洞不存在") - } - return nil, fmt.Errorf("获取漏洞失败: %w", err) - } - - return &vuln, nil -} - -// ListVulnerabilities 列出漏洞 -func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) { - query := ` - SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag, - vulnerability_type, target, proof, impact, recommendation, - COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, - COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, - created_at, updated_at - FROM vulnerabilities - WHERE 1=1 - ` - args := []interface{}{} - - if id != "" { - query += " AND id = ?" - args = append(args, id) - } - if conversationID != "" { - query += " AND conversation_id = ?" - args = append(args, conversationID) - } - if taskID != "" { - query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" - args = append(args, taskID, taskID) - } - if conversationTag != "" { - query += " AND conversation_tag = ?" - args = append(args, conversationTag) - } - if taskTag != "" { - query += " AND task_tag = ?" - args = append(args, taskTag) - } - if severity != "" { - query += " AND severity = ?" - args = append(args, severity) - } - if status != "" { - query += " AND status = ?" - args = append(args, status) - } - - query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" - args = append(args, limit, offset) - - rows, err := db.Query(query, args...) - if err != nil { - return nil, fmt.Errorf("查询漏洞列表失败: %w", err) - } - defer rows.Close() - - var vulnerabilities []*Vulnerability - for rows.Next() { - var vuln Vulnerability - err := rows.Scan( - &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, - &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, - &vuln.Proof, &vuln.Impact, &vuln.Recommendation, - &vuln.TaskID, &vuln.TaskQueueID, - &vuln.CreatedAt, &vuln.UpdatedAt, - ) - if err != nil { - db.logger.Warn("扫描漏洞记录失败", zap.Error(err)) - continue - } - vulnerabilities = append(vulnerabilities, &vuln) - } - - return vulnerabilities, nil -} - -// CountVulnerabilities 统计漏洞总数(支持筛选条件) -func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) { - query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" - args := []interface{}{} - - if id != "" { - query += " AND id = ?" - args = append(args, id) - } - if conversationID != "" { - query += " AND conversation_id = ?" - args = append(args, conversationID) - } - if taskID != "" { - query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" - args = append(args, taskID, taskID) - } - if conversationTag != "" { - query += " AND conversation_tag = ?" - args = append(args, conversationTag) - } - if taskTag != "" { - query += " AND task_tag = ?" - args = append(args, taskTag) - } - if severity != "" { - query += " AND severity = ?" - args = append(args, severity) - } - if status != "" { - query += " AND status = ?" - args = append(args, status) - } - - var count int - err := db.QueryRow(query, args...).Scan(&count) - if err != nil { - return 0, fmt.Errorf("统计漏洞总数失败: %w", err) - } - - return count, nil -} - -// UpdateVulnerability 更新漏洞 -func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error { - vuln.UpdatedAt = time.Now() - - query := ` - UPDATE vulnerabilities - SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?, - vulnerability_type = ?, target = ?, proof = ?, impact = ?, - recommendation = ?, updated_at = ? - WHERE id = ? - ` - - _, err := db.Exec( - query, - vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, - vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, - vuln.Recommendation, vuln.UpdatedAt, id, - ) - if err != nil { - return fmt.Errorf("更新漏洞失败: %w", err) - } - - return nil -} - -// DeleteVulnerability 删除漏洞 -func (db *DB) DeleteVulnerability(id string) error { - _, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id) - if err != nil { - return fmt.Errorf("删除漏洞失败: %w", err) - } - return nil -} - -// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致) -func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) { - stats := make(map[string]interface{}) - - where := "WHERE 1=1" - args := []interface{}{} - if conversationID != "" { - where += " AND conversation_id = ?" - args = append(args, conversationID) - } - if taskID != "" { - where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" - args = append(args, taskID, taskID) - } - - // 总漏洞数 - var totalCount int - query := "SELECT COUNT(*) FROM vulnerabilities " + where - err := db.QueryRow(query, args...).Scan(&totalCount) - if err != nil { - return nil, fmt.Errorf("获取总漏洞数失败: %w", err) - } - stats["total"] = totalCount - - // 按严重程度统计 - severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities " + where + " GROUP BY severity" - - rows, err := db.Query(severityQuery, args...) - if err != nil { - return nil, fmt.Errorf("获取严重程度统计失败: %w", err) - } - defer rows.Close() - - severityStats := make(map[string]int) - for rows.Next() { - var severity string - var count int - if err := rows.Scan(&severity, &count); err != nil { - continue - } - severityStats[severity] = count - } - stats["by_severity"] = severityStats - - // 按状态统计 - statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities " + where + " GROUP BY status" - - rows, err = db.Query(statusQuery, args...) - if err != nil { - return nil, fmt.Errorf("获取状态统计失败: %w", err) - } - defer rows.Close() - - statusStats := make(map[string]int) - for rows.Next() { - var status string - var count int - if err := rows.Scan(&status, &count); err != nil { - continue - } - statusStats[status] = count - } - stats["by_status"] = statusStats - - return stats, nil -} - -// GetVulnerabilityFilterOptions 获取漏洞筛选建议项 -func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) { - collect := func(query string, args ...interface{}) ([]string, error) { - rows, err := db.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - items := make([]string, 0) - for rows.Next() { - var val string - if err := rows.Scan(&val); err != nil { - continue - } - if val == "" { - continue - } - items = append(items, val) - } - return items, nil - } - - vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`) - if err != nil { - return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err) - } - conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`) - if err != nil { - return nil, fmt.Errorf("查询会话ID建议失败: %w", err) - } - taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`) - if err != nil { - return nil, fmt.Errorf("查询任务ID建议失败: %w", err) - } - queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`) - if err != nil { - return nil, fmt.Errorf("查询队列ID建议失败: %w", err) - } - conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`) - if err != nil { - return nil, fmt.Errorf("查询对话标签建议失败: %w", err) - } - taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`) - if err != nil { - return nil, fmt.Errorf("查询任务标签建议失败: %w", err) - } - - return map[string][]string{ - "vulnerability_ids": vulnIDs, - "conversation_ids": conversationIDs, - "task_ids": taskIDs, - "queue_ids": queueIDs, - "conversation_tags": conversationTags, - "task_tags": taskTags, - }, nil -} diff --git a/database/webshell.go b/database/webshell.go deleted file mode 100644 index db4e912f..00000000 --- a/database/webshell.go +++ /dev/null @@ -1,152 +0,0 @@ -package database - -import ( - "database/sql" - "time" - - "go.uber.org/zap" -) - -// WebShellConnection WebShell 连接配置 -type WebShellConnection struct { - ID string `json:"id"` - URL string `json:"url"` - Password string `json:"password"` - Type string `json:"type"` - Method string `json:"method"` - CmdParam string `json:"cmdParam"` - Remark string `json:"remark"` - Encoding string `json:"encoding"` // 目标响应编码:auto / utf-8 / gbk / gb18030,空值视为 auto - OS string `json:"os"` // 目标操作系统:auto / linux / windows,空值/未知视为 auto - CreatedAt time.Time `json:"createdAt"` -} - -// GetWebshellConnectionState 获取连接关联的持久化状态 JSON,不存在时返回 "{}" -func (db *DB) GetWebshellConnectionState(connectionID string) (string, error) { - var stateJSON string - err := db.QueryRow(`SELECT state_json FROM webshell_connection_states WHERE connection_id = ?`, connectionID).Scan(&stateJSON) - if err == sql.ErrNoRows { - return "{}", nil - } - if err != nil { - db.logger.Error("查询 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID)) - return "", err - } - if stateJSON == "" { - stateJSON = "{}" - } - return stateJSON, nil -} - -// UpsertWebshellConnectionState 保存连接关联的持久化状态 JSON -func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) error { - if stateJSON == "" { - stateJSON = "{}" - } - query := ` - INSERT INTO webshell_connection_states (connection_id, state_json, updated_at) - VALUES (?, ?, ?) - ON CONFLICT(connection_id) DO UPDATE SET - state_json = excluded.state_json, - updated_at = excluded.updated_at - ` - if _, err := db.Exec(query, connectionID, stateJSON, time.Now()); err != nil { - db.logger.Error("保存 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID)) - return err - } - return nil -} - -// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序 -func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) { - query := ` - SELECT id, url, password, type, method, cmd_param, remark, - COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at - FROM webshell_connections - ORDER BY created_at DESC - ` - rows, err := db.Query(query) - if err != nil { - db.logger.Error("查询 WebShell 连接列表失败", zap.Error(err)) - return nil, err - } - defer rows.Close() - - var list []WebShellConnection - for rows.Next() { - var c WebShellConnection - err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt) - if err != nil { - db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err)) - continue - } - list = append(list, c) - } - return list, rows.Err() -} - -// GetWebshellConnection 根据 ID 获取一条连接 -func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) { - query := ` - SELECT id, url, password, type, method, cmd_param, remark, - COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at - FROM webshell_connections WHERE id = ? - ` - var c WebShellConnection - err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - db.logger.Error("查询 WebShell 连接失败", zap.Error(err), zap.String("id", id)) - return nil, err - } - return &c, nil -} - -// CreateWebshellConnection 创建 WebShell 连接 -func (db *DB) CreateWebshellConnection(c *WebShellConnection) error { - query := ` - INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, encoding, os, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - _, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.CreatedAt) - if err != nil { - db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) - return err - } - return nil -} - -// UpdateWebshellConnection 更新 WebShell 连接 -func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error { - query := ` - UPDATE webshell_connections - SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?, encoding = ?, os = ? - WHERE id = ? - ` - result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.ID) - if err != nil { - db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) - return err - } - affected, _ := result.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -} - -// DeleteWebshellConnection 删除 WebShell 连接 -func (db *DB) DeleteWebshellConnection(id string) error { - result, err := db.Exec(`DELETE FROM webshell_connections WHERE id = ?`, id) - if err != nil { - db.logger.Error("删除 WebShell 连接失败", zap.Error(err), zap.String("id", id)) - return err - } - affected, _ := result.RowsAffected() - if affected == 0 { - return sql.ErrNoRows - } - return nil -}