Add files via upload

This commit is contained in:
公明
2026-06-18 12:40:54 +08:00
committed by GitHub
parent 56faefaaf9
commit d5a0f93c6c
94 changed files with 24645 additions and 0 deletions
+167
View File
@@ -0,0 +1,167 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"go.uber.org/zap"
)
// AttackChainNode 攻击链节点
type AttackChainNode struct {
ID string `json:"id"`
Type string `json:"type"` // tool, vulnerability, target, exploit
Label string `json:"label"`
ToolExecutionID string `json:"tool_execution_id,omitempty"`
Metadata map[string]interface{} `json:"metadata"`
RiskScore int `json:"risk_score"`
}
// AttackChainEdge 攻击链边
type AttackChainEdge struct {
ID string `json:"id"`
Source string `json:"source"`
Target string `json:"target"`
Type string `json:"type"` // leads_to, exploits, enables, depends_on
Weight int `json:"weight"`
}
// SaveAttackChainNode 保存攻击链节点
func (db *DB) SaveAttackChainNode(conversationID, nodeID, nodeType, nodeName, toolExecutionID, metadata string, riskScore int) error {
var toolExecID sql.NullString
if toolExecutionID != "" {
toolExecID = sql.NullString{String: toolExecutionID, Valid: true}
}
var metadataJSON sql.NullString
if metadata != "" {
metadataJSON = sql.NullString{String: metadata, Valid: true}
}
query := `
INSERT OR REPLACE INTO attack_chain_nodes
(id, conversation_id, node_type, node_name, tool_execution_id, metadata, risk_score, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
`
_, err := db.Exec(query, nodeID, conversationID, nodeType, nodeName, toolExecID, metadataJSON, riskScore)
if err != nil {
db.logger.Error("保存攻击链节点失败", zap.Error(err), zap.String("nodeId", nodeID))
return err
}
return nil
}
// SaveAttackChainEdge 保存攻击链边
func (db *DB) SaveAttackChainEdge(conversationID, edgeID, sourceNodeID, targetNodeID, edgeType string, weight int) error {
query := `
INSERT OR REPLACE INTO attack_chain_edges
(id, conversation_id, source_node_id, target_node_id, edge_type, weight, created_at)
VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
`
_, err := db.Exec(query, edgeID, conversationID, sourceNodeID, targetNodeID, edgeType, weight)
if err != nil {
db.logger.Error("保存攻击链边失败", zap.Error(err), zap.String("edgeId", edgeID))
return err
}
return nil
}
// LoadAttackChainNodes 加载攻击链节点
func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, error) {
query := `
SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score
FROM attack_chain_nodes
WHERE conversation_id = ?
ORDER BY created_at ASC, rowid ASC
`
rows, err := db.Query(query, conversationID)
if err != nil {
return nil, fmt.Errorf("查询攻击链节点失败: %w", err)
}
defer rows.Close()
var nodes []AttackChainNode
for rows.Next() {
var node AttackChainNode
var toolExecID sql.NullString
var metadataJSON sql.NullString
err := rows.Scan(&node.ID, &node.Type, &node.Label, &toolExecID, &metadataJSON, &node.RiskScore)
if err != nil {
db.logger.Warn("扫描攻击链节点失败", zap.Error(err))
continue
}
if toolExecID.Valid {
node.ToolExecutionID = toolExecID.String
}
if metadataJSON.Valid && metadataJSON.String != "" {
if err := json.Unmarshal([]byte(metadataJSON.String), &node.Metadata); err != nil {
db.logger.Warn("解析节点元数据失败", zap.Error(err))
node.Metadata = make(map[string]interface{})
}
} else {
node.Metadata = make(map[string]interface{})
}
nodes = append(nodes, node)
}
return nodes, nil
}
// LoadAttackChainEdges 加载攻击链边
func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, error) {
query := `
SELECT id, source_node_id, target_node_id, edge_type, weight
FROM attack_chain_edges
WHERE conversation_id = ?
ORDER BY created_at ASC, rowid ASC
`
rows, err := db.Query(query, conversationID)
if err != nil {
return nil, fmt.Errorf("查询攻击链边失败: %w", err)
}
defer rows.Close()
var edges []AttackChainEdge
for rows.Next() {
var edge AttackChainEdge
err := rows.Scan(&edge.ID, &edge.Source, &edge.Target, &edge.Type, &edge.Weight)
if err != nil {
db.logger.Warn("扫描攻击链边失败", zap.Error(err))
continue
}
edges = append(edges, edge)
}
return edges, nil
}
// DeleteAttackChain 删除对话的攻击链数据
func (db *DB) DeleteAttackChain(conversationID string) error {
// 先删除边(因为有外键约束)
_, err := db.Exec("DELETE FROM attack_chain_edges WHERE conversation_id = ?", conversationID)
if err != nil {
db.logger.Warn("删除攻击链边失败", zap.Error(err))
}
// 再删除节点
_, err = db.Exec("DELETE FROM attack_chain_nodes WHERE conversation_id = ?", conversationID)
if err != nil {
db.logger.Error("删除攻击链节点失败", zap.Error(err), zap.String("conversationId", conversationID))
return err
}
return nil
}
+212
View File
@@ -0,0 +1,212 @@
package database
import (
"encoding/json"
"errors"
"strings"
"time"
)
// AuditLog platform operation audit record.
type AuditLog struct {
ID string `json:"id"`
CreatedAt time.Time `json:"createdAt"`
Level string `json:"level"`
Category string `json:"category"`
Action string `json:"action"`
Result string `json:"result"`
Actor string `json:"actor"`
SessionHint string `json:"sessionHint,omitempty"`
ClientIP string `json:"clientIp,omitempty"`
UserAgent string `json:"userAgent,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
ResourceID string `json:"resourceId,omitempty"`
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
Message string `json:"message"`
Detail map[string]interface{} `json:"detail,omitempty"`
}
// ListAuditLogsFilter query parameters.
type ListAuditLogsFilter struct {
Level string
Category string
Action string
Result string
Query string
ResourceType string
ResourceID string
Since *time.Time
Until *time.Time
Limit int
Offset int
}
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
conditions := []string{"1=1"}
args := []interface{}{}
if filter.Level != "" {
conditions = append(conditions, "level = ?")
args = append(args, filter.Level)
}
if filter.Category != "" {
conditions = append(conditions, "category = ?")
args = append(args, filter.Category)
}
if filter.Action != "" {
conditions = append(conditions, "action = ?")
args = append(args, filter.Action)
}
if filter.Result != "" {
conditions = append(conditions, "result = ?")
args = append(args, filter.Result)
}
if filter.ResourceType != "" {
conditions = append(conditions, "resource_type = ?")
args = append(args, filter.ResourceType)
}
if filter.ResourceID != "" {
conditions = append(conditions, "resource_id = ?")
args = append(args, filter.ResourceID)
}
if filter.Since != nil {
conditions = append(conditions, sqliteEpochGE("created_at", ">="))
args = append(args, formatSQLiteUTC(*filter.Since))
}
if filter.Until != nil {
conditions = append(conditions, sqliteEpochGE("created_at", "<="))
args = append(args, formatSQLiteUTC(*filter.Until))
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
args = append(args, like, like, like, like)
}
return strings.Join(conditions, " AND "), args
}
// AppendAuditLog inserts one audit row.
func (db *DB) AppendAuditLog(row *AuditLog) error {
if row == nil {
return errors.New("audit log is nil")
}
if strings.TrimSpace(row.ID) == "" {
return errors.New("audit id is required")
}
if row.CreatedAt.IsZero() {
row.CreatedAt = time.Now().UTC()
} else {
row.CreatedAt = row.CreatedAt.UTC()
}
if strings.TrimSpace(row.Level) == "" {
row.Level = "info"
}
detailJSON := ""
if len(row.Detail) > 0 {
if b, err := json.Marshal(row.Detail); err == nil {
detailJSON = string(b)
}
}
query := `
INSERT INTO audit_logs (
id, created_at, level, category, action, result, actor, session_hint,
client_ip, user_agent, resource_type, resource_id, message, detail_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
row.ID, formatSQLiteUTC(row.CreatedAt), row.Level, row.Category, row.Action, row.Result,
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
row.ResourceType, row.ResourceID, row.Message, detailJSON,
)
return err
}
// GetAuditLogByID returns one row.
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
id = strings.TrimSpace(id)
if id == "" {
return nil, errors.New("id is required")
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs WHERE id = ?
`
var row AuditLog
var detailJSON string
err := db.QueryRow(query, id).Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
)
if err != nil {
return nil, err
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
return &row, nil
}
// CountAuditLogs counts rows matching filter.
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
where, args := buildAuditLogsWhere(filter)
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
var n int64
err := db.QueryRow(query, args...).Scan(&n)
return n, err
}
// ListAuditLogs lists audit rows newest first.
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
where, args := buildAuditLogsWhere(filter)
limit := filter.Limit
if limit <= 0 || limit > 500 {
limit = 50
}
offset := filter.Offset
if offset < 0 {
offset = 0
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs
WHERE ` + where + `
ORDER BY created_at DESC
LIMIT ? OFFSET ?
`
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*AuditLog
for rows.Next() {
var row AuditLog
var detailJSON string
if err := rows.Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
); err != nil {
continue
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
list = append(list, &row)
}
return list, rows.Err()
}
// DeleteAuditLogsBefore removes rows older than cutoff.
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
res, err := db.Exec(`DELETE FROM audit_logs WHERE `+sqliteEpochGE("created_at", "<"), formatSQLiteUTC(cutoff))
if err != nil {
return 0, err
}
return res.RowsAffected()
}
+62
View File
@@ -0,0 +1,62 @@
package database
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"go.uber.org/zap"
)
func TestBuildAuditLogsWhere_timeFilterSQL(t *testing.T) {
since := time.Date(2026, 6, 16, 17, 2, 0, 0, time.UTC)
until := time.Date(2026, 6, 17, 3, 3, 0, 0, time.UTC)
where, args := buildAuditLogsWhere(ListAuditLogsFilter{Since: &since, Until: &until})
if !strings.Contains(where, "strftime('%s', created_at) >=") {
t.Fatalf("expected epoch comparison for since, got %q", where)
}
if !strings.Contains(where, "strftime('%s', created_at) <=") {
t.Fatalf("expected epoch comparison for until, got %q", where)
}
if len(args) != 2 {
t.Fatalf("expected 2 time args, got %d", len(args))
}
for i, arg := range args {
s, ok := arg.(string)
if !ok || s == "" {
t.Fatalf("arg %d: want non-empty UTC RFC3339 string, got %v", i, arg)
}
}
}
func TestListAuditLogs_timeFilterMixedStorageFormats(t *testing.T) {
root, err := os.Getwd()
if err != nil {
t.Skip(err)
}
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
if _, err := os.Stat(dbPath); err != nil {
t.Skip("conversations.db not found")
}
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
since, _ := ParseRFC3339Time("2026-06-16T17:02:00Z")
until, _ := ParseRFC3339Time("2026-06-17T03:03:00Z")
filter := ListAuditLogsFilter{Since: &since, Until: &until, Limit: 50}
logs, err := db.ListAuditLogs(filter)
if err != nil {
t.Fatal(err)
}
for _, row := range logs {
at := row.CreatedAt.UTC()
if at.Before(since) || at.After(until) {
t.Fatalf("log %s at %s outside [%s, %s]", row.ID, at, since, until)
}
}
}
+543
View File
@@ -0,0 +1,543 @@
package database
import (
"database/sql"
"fmt"
"strings"
"time"
"go.uber.org/zap"
)
// BatchTaskQueueRow 批量任务队列数据库行
type BatchTaskQueueRow struct {
ID string
Title sql.NullString
Role sql.NullString
AgentMode sql.NullString
ScheduleMode sql.NullString
CronExpr sql.NullString
NextRunAt sql.NullTime
ScheduleEnabled sql.NullInt64
LastScheduleTriggerAt sql.NullTime
LastScheduleError sql.NullString
LastRunError sql.NullString
ProjectID sql.NullString
Status string
CreatedAt time.Time
StartedAt sql.NullTime
CompletedAt sql.NullTime
CurrentIndex int
}
// BatchTaskRow 批量任务数据库行
type BatchTaskRow struct {
ID string
QueueID string
Message string
ConversationID sql.NullString
Status string
StartedAt sql.NullTime
CompletedAt sql.NullTime
Error sql.NullString
Result sql.NullString
}
// CreateBatchQueue 创建批量任务队列
func (db *DB) CreateBatchQueue(
queueID string,
title string,
role string,
agentMode string,
scheduleMode string,
cronExpr string,
nextRunAt *time.Time,
projectID string,
tasks []map[string]interface{},
) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
now := time.Now()
var nextRunAtValue interface{}
if nextRunAt != nil {
nextRunAtValue = *nextRunAt
}
var projectIDVal interface{}
if strings.TrimSpace(projectID) != "" {
projectIDVal = strings.TrimSpace(projectID)
}
_, err = tx.Exec(
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
)
if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err)
}
// 插入任务
for _, task := range tasks {
taskID, ok := task["id"].(string)
if !ok {
continue
}
message, ok := task["message"].(string)
if !ok {
continue
}
_, err = tx.Exec(
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
taskID, queueID, message, "pending",
)
if err != nil {
return fmt.Errorf("创建批量任务失败: %w", err)
}
}
return tx.Commit()
}
// GetBatchQueue 获取批量任务队列
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow
var createdAt string
err := db.QueryRow(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
queueID,
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
// 尝试其他时间格式
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
return &row, nil
}
// GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
)
if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
}
defer rows.Close()
var queues []*BatchTaskQueueRow
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
queues = append(queues, &row)
}
return queues, nil
}
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
args := []interface{}{}
// 状态筛选
if status != "" && status != "all" {
query += " AND status = ?"
args = append(args, status)
}
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
}
defer rows.Close()
var queues []*BatchTaskQueueRow
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
queues = append(queues, &row)
}
return queues, nil
}
// CountBatchQueues 统计批量任务队列总数(支持筛选条件)
func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
query := "SELECT COUNT(*) FROM batch_task_queues WHERE 1=1"
args := []interface{}{}
// 状态筛选
if status != "" && status != "all" {
query += " AND status = ?"
args = append(args, status)
}
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
var count int
err := db.QueryRow(query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("统计批量任务队列总数失败: %w", err)
}
return count, nil
}
// GetBatchTasks 获取批量任务队列的所有任务
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
rows, err := db.Query(
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY rowid ASC",
queueID,
)
if err != nil {
return nil, fmt.Errorf("查询批量任务失败: %w", err)
}
defer rows.Close()
var tasks []*BatchTaskRow
for rows.Next() {
var task BatchTaskRow
if err := rows.Scan(
&task.ID, &task.QueueID, &task.Message, &task.ConversationID,
&task.Status, &task.StartedAt, &task.CompletedAt, &task.Error, &task.Result,
); err != nil {
return nil, fmt.Errorf("扫描批量任务失败: %w", err)
}
tasks = append(tasks, &task)
}
return tasks, nil
}
// UpdateBatchQueueStatus 更新批量任务队列状态
func (db *DB) UpdateBatchQueueStatus(queueID, status string) error {
var err error
now := time.Now()
if status == "running" {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?",
status, now, queueID,
)
} else if status == "completed" || status == "cancelled" {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ?, completed_at = COALESCE(completed_at, ?) WHERE id = ?",
status, now, queueID,
)
} else {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ? WHERE id = ?",
status, queueID,
)
}
if err != nil {
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
}
return nil
}
// UpdateBatchTaskStatus 更新批量任务状态
func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error {
var err error
now := time.Now()
// 构建更新语句
var updates []string
var args []interface{}
updates = append(updates, "status = ?")
args = append(args, status)
if conversationID != "" {
updates = append(updates, "conversation_id = ?")
args = append(args, conversationID)
}
if result != "" {
updates = append(updates, "result = ?")
args = append(args, result)
}
if errorMsg != "" {
updates = append(updates, "error = ?")
args = append(args, errorMsg)
}
if status == "running" {
updates = append(updates, "started_at = COALESCE(started_at, ?)")
args = append(args, now)
}
if status == "completed" || status == "failed" || status == "cancelled" {
updates = append(updates, "completed_at = COALESCE(completed_at, ?)")
args = append(args, now)
}
args = append(args, queueID, taskID)
// 构建SQL语句
sql := "UPDATE batch_tasks SET "
for i, update := range updates {
if i > 0 {
sql += ", "
}
sql += update
}
sql += " WHERE queue_id = ? AND id = ?"
_, err = db.Exec(sql, args...)
if err != nil {
return fmt.Errorf("更新批量任务状态失败: %w", err)
}
return nil
}
// UpdateBatchQueueCurrentIndex 更新批量任务队列的当前索引
func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET current_index = ? WHERE id = ?",
currentIndex, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务队列当前索引失败: %w", err)
}
return nil
}
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
title, role, agentMode, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
}
return nil
}
// UpdateBatchQueueSchedule 更新批量任务队列调度相关信息
func (db *DB) UpdateBatchQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) error {
var nextRunAtValue interface{}
if nextRunAt != nil {
nextRunAtValue = *nextRunAt
}
_, err := db.Exec(
"UPDATE batch_task_queues SET schedule_mode = ?, cron_expr = ?, next_run_at = ? WHERE id = ?",
scheduleMode, cronExpr, nextRunAtValue, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务调度配置失败: %w", err)
}
return nil
}
// UpdateBatchQueueScheduleEnabled 是否允许 Cron 自动触发(手工「开始执行」不受影响)
func (db *DB) UpdateBatchQueueScheduleEnabled(queueID string, enabled bool) error {
v := 0
if enabled {
v = 1
}
_, err := db.Exec(
"UPDATE batch_task_queues SET schedule_enabled = ? WHERE id = ?",
v, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务调度开关失败: %w", err)
}
return nil
}
// RecordBatchQueueScheduledTriggerStart 记录一次由调度触发的开始时间并清空调度层错误
func (db *DB) RecordBatchQueueScheduledTriggerStart(queueID string, at time.Time) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET last_schedule_trigger_at = ?, last_schedule_error = NULL WHERE id = ?",
at, queueID,
)
if err != nil {
return fmt.Errorf("记录调度触发时间失败: %w", err)
}
return nil
}
// SetBatchQueueLastScheduleError 调度启动失败等原因(如状态不允许、重置失败)
func (db *DB) SetBatchQueueLastScheduleError(queueID, msg string) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET last_schedule_error = ? WHERE id = ?",
msg, queueID,
)
if err != nil {
return fmt.Errorf("写入调度错误信息失败: %w", err)
}
return nil
}
// SetBatchQueueLastRunError 最近一轮执行中出现的子任务失败摘要(空串表示清空)
func (db *DB) SetBatchQueueLastRunError(queueID, msg string) error {
var v interface{}
if strings.TrimSpace(msg) == "" {
v = nil
} else {
v = msg
}
_, err := db.Exec(
"UPDATE batch_task_queues SET last_run_error = ? WHERE id = ?",
v, queueID,
)
if err != nil {
return fmt.Errorf("写入最近运行错误失败: %w", err)
}
return nil
}
// ResetBatchQueueForRerun 重置队列和任务状态用于下一轮调度执行
func (db *DB) ResetBatchQueueForRerun(queueID string) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
_, err = tx.Exec(
"UPDATE batch_task_queues SET status = ?, current_index = 0, started_at = NULL, completed_at = NULL, last_run_error = NULL, last_schedule_error = NULL WHERE id = ?",
"pending", queueID,
)
if err != nil {
return fmt.Errorf("重置批量任务队列状态失败: %w", err)
}
_, err = tx.Exec(
"UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ?",
"pending", queueID,
)
if err != nil {
return fmt.Errorf("重置批量任务状态失败: %w", err)
}
return tx.Commit()
}
// UpdateBatchTaskMessage 更新批量任务消息
func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error {
_, err := db.Exec(
"UPDATE batch_tasks SET message = ? WHERE queue_id = ? AND id = ?",
message, queueID, taskID,
)
if err != nil {
return fmt.Errorf("更新批量任务消息失败: %w", err)
}
return nil
}
// AddBatchTask 添加任务到批量任务队列
func (db *DB) AddBatchTask(queueID, taskID, message string) error {
_, err := db.Exec(
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
taskID, queueID, message, "pending",
)
if err != nil {
return fmt.Errorf("添加批量任务失败: %w", err)
}
return nil
}
// CancelPendingBatchTasks 批量取消队列中所有 pending 状态的任务(单条 SQL
func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) error {
_, err := db.Exec(
"UPDATE batch_tasks SET status = ?, completed_at = ? WHERE queue_id = ? AND status = ?",
"cancelled", completedAt, queueID, "pending",
)
if err != nil {
return fmt.Errorf("批量取消 pending 任务失败: %w", err)
}
return nil
}
// DeleteBatchTask 删除批量任务
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
_, err := db.Exec(
"DELETE FROM batch_tasks WHERE queue_id = ? AND id = ?",
queueID, taskID,
)
if err != nil {
return fmt.Errorf("删除批量任务失败: %w", err)
}
return nil
}
// DeleteBatchQueue 删除批量任务队列
func (db *DB) DeleteBatchQueue(queueID string) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
// 删除任务(外键会自动级联删除)
_, err = tx.Exec("DELETE FROM batch_tasks WHERE queue_id = ?", queueID)
if err != nil {
return fmt.Errorf("删除批量任务失败: %w", err)
}
// 删除队列
_, err = tx.Exec("DELETE FROM batch_task_queues WHERE id = ?", queueID)
if err != nil {
return fmt.Errorf("删除批量任务队列失败: %w", err)
}
return tx.Commit()
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,57 @@
package database
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "conversations.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
convID := conv.ID
seg := sanitizeConversationPathSegment(convID)
for _, base := range []struct {
root string
file string
}{
{db.conversationArtifactsDir, "transcript.txt"},
{plantaskBase, "task-1.json"},
{checkpointBase, "runner-deep.ckpt"},
} {
dir := filepath.Join(base.root, seg)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatalf("mkdir %s: %v", dir, err)
}
if err := os.WriteFile(filepath.Join(dir, base.file), []byte("x"), 0o644); err != nil {
t.Fatalf("write %s: %v", base.file, err)
}
}
if err := db.DeleteConversation(convID); err != nil {
t.Fatalf("DeleteConversation: %v", err)
}
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase} {
dir := filepath.Join(base, seg)
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
}
}
}
@@ -0,0 +1,30 @@
package database
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
type ConversationCreateMeta struct {
Source string
WebShellConnectionID string
ProjectID string
ClientIP string
SessionHint string
}
// ConversationCreateHook is invoked after a conversation row is inserted.
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
var conversationCreateHook ConversationCreateHook
// SetConversationCreateHook registers a global hook (e.g. platform audit).
func SetConversationCreateHook(h ConversationCreateHook) {
conversationCreateHook = h
}
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
if conversationCreateHook == nil || conv == nil {
return
}
if meta.Source == "" {
meta.Source = "unknown"
}
conversationCreateHook(conv, meta)
}
@@ -0,0 +1,39 @@
package database
import (
"testing"
)
func TestTurnSliceRange(t *testing.T) {
mk := func(id, role string) Message {
return Message{ID: id, Role: role}
}
msgs := []Message{
mk("u1", "user"),
mk("a1", "assistant"),
mk("u2", "user"),
mk("a2", "assistant"),
}
cases := []struct {
anchor string
start int
end int
}{
{"u1", 0, 2},
{"a1", 0, 2},
{"u2", 2, 4},
{"a2", 2, 4},
}
for _, tc := range cases {
s, e, err := turnSliceRange(msgs, tc.anchor)
if err != nil {
t.Fatalf("anchor %s: %v", tc.anchor, err)
}
if s != tc.start || e != tc.end {
t.Fatalf("anchor %s: got [%d,%d) want [%d,%d)", tc.anchor, s, e, tc.start, tc.end)
}
}
if _, _, err := turnSliceRange(msgs, "nope"); err == nil {
t.Fatal("expected error for missing id")
}
}
@@ -0,0 +1,69 @@
package database
import (
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestDeleteConversationPreservesVulnerabilities(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "vuln-preserve.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
conv, err := db.CreateConversation("vuln source chat", ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
vuln, err := db.CreateVulnerability(&Vulnerability{
ConversationID: conv.ID,
Title: "SQL Injection",
Severity: "high",
Status: "open",
})
if err != nil {
t.Fatalf("CreateVulnerability: %v", err)
}
if err := db.DeleteConversation(conv.ID); err != nil {
t.Fatalf("DeleteConversation: %v", err)
}
got, err := db.GetVulnerability(vuln.ID)
if err != nil {
t.Fatalf("GetVulnerability after delete: %v", err)
}
if got.Title != "SQL Injection" {
t.Fatalf("title = %q, want SQL Injection", got.Title)
}
if got.ConversationID != "" {
t.Fatalf("conversation_id = %q, want empty after conversation delete", got.ConversationID)
}
if got.ConversationTag != "vuln source chat" {
t.Fatalf("conversation_tag = %q, want vuln source chat", got.ConversationTag)
}
}
func TestMigrateVulnerabilitiesConversationFK(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "vuln-fk-migrate.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
ok, err := vulnerabilitiesConversationFKOnDeleteSetNull(db.DB)
if err != nil {
t.Fatalf("vulnerabilitiesConversationFKOnDeleteSetNull: %v", err)
}
if !ok {
t.Fatal("expected vulnerabilities.conversation_id FK to use ON DELETE SET NULL")
}
}
File diff suppressed because it is too large Load Diff
+449
View File
@@ -0,0 +1,449 @@
package database
import (
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
)
// ConversationGroup 对话分组
type ConversationGroup struct {
ID string `json:"id"`
Name string `json:"name"`
Icon string `json:"icon"`
Pinned bool `json:"pinned"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// GroupExistsByName 检查分组名称是否已存在
func (db *DB) GroupExistsByName(name string, excludeID string) (bool, error) {
var count int
var err error
if excludeID != "" {
err = db.QueryRow(
"SELECT COUNT(*) FROM conversation_groups WHERE name = ? AND id != ?",
name, excludeID,
).Scan(&count)
} else {
err = db.QueryRow(
"SELECT COUNT(*) FROM conversation_groups WHERE name = ?",
name,
).Scan(&count)
}
if err != nil {
return false, fmt.Errorf("检查分组名称失败: %w", err)
}
return count > 0, nil
}
// CreateGroup 创建分组
func (db *DB) CreateGroup(name, icon string) (*ConversationGroup, error) {
// 检查名称是否已存在
exists, err := db.GroupExistsByName(name, "")
if err != nil {
return nil, err
}
if exists {
return nil, fmt.Errorf("分组名称已存在")
}
id := uuid.New().String()
now := time.Now()
if icon == "" {
icon = "📁"
}
_, err = db.Exec(
"INSERT INTO conversation_groups (id, name, icon, pinned, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
id, name, icon, 0, now, now,
)
if err != nil {
return nil, fmt.Errorf("创建分组失败: %w", err)
}
return &ConversationGroup{
ID: id,
Name: name,
Icon: icon,
Pinned: false,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// ListGroups 列出所有分组
func (db *DB) ListGroups() ([]*ConversationGroup, error) {
rows, err := db.Query(
"SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups ORDER BY COALESCE(pinned, 0) DESC, created_at ASC",
)
if err != nil {
return nil, fmt.Errorf("查询分组列表失败: %w", err)
}
defer rows.Close()
var groups []*ConversationGroup
for rows.Next() {
var group ConversationGroup
var createdAt, updatedAt string
var pinned int
if err := rows.Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描分组失败: %w", err)
}
group.Pinned = pinned != 0
// 尝试多种时间格式解析
var err1, err2 error
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
groups = append(groups, &group)
}
return groups, nil
}
// GetGroup 获取分组
func (db *DB) GetGroup(id string) (*ConversationGroup, error) {
var group ConversationGroup
var createdAt, updatedAt string
var pinned int
err := db.QueryRow(
"SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups WHERE id = ?",
id,
).Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("分组不存在")
}
return nil, fmt.Errorf("查询分组失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
group.Pinned = pinned != 0
return &group, nil
}
// UpdateGroup 更新分组
func (db *DB) UpdateGroup(id, name, icon string) error {
// 检查名称是否已存在(排除当前分组)
exists, err := db.GroupExistsByName(name, id)
if err != nil {
return err
}
if exists {
return fmt.Errorf("分组名称已存在")
}
_, err = db.Exec(
"UPDATE conversation_groups SET name = ?, icon = ?, updated_at = ? WHERE id = ?",
name, icon, time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新分组失败: %w", err)
}
return nil
}
// DeleteGroup 删除分组
func (db *DB) DeleteGroup(id string) error {
_, err := db.Exec("DELETE FROM conversation_groups WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除分组失败: %w", err)
}
return nil
}
// AddConversationToGroup 将对话添加到分组
// 注意:一个对话只能属于一个分组,所以在添加新分组之前,会先删除该对话的所有旧分组关联
func (db *DB) AddConversationToGroup(conversationID, groupID string) error {
// 先删除该对话的所有旧分组关联,确保一个对话只属于一个分组
_, err := db.Exec(
"DELETE FROM conversation_group_mappings WHERE conversation_id = ?",
conversationID,
)
if err != nil {
return fmt.Errorf("删除对话旧分组关联失败: %w", err)
}
// 然后插入新的分组关联
id := uuid.New().String()
_, err = db.Exec(
"INSERT INTO conversation_group_mappings (id, conversation_id, group_id, created_at) VALUES (?, ?, ?, ?)",
id, conversationID, groupID, time.Now(),
)
if err != nil {
return fmt.Errorf("添加对话到分组失败: %w", err)
}
return nil
}
// RemoveConversationFromGroup 从分组中移除对话
func (db *DB) RemoveConversationFromGroup(conversationID, groupID string) error {
_, err := db.Exec(
"DELETE FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?",
conversationID, groupID,
)
if err != nil {
return fmt.Errorf("从分组中移除对话失败: %w", err)
}
return nil
}
// GetConversationsByGroup 获取分组中的所有对话
func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) {
rows, err := db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
FROM conversations c
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
WHERE cgm.group_id = ?
ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC`,
groupID,
)
if err != nil {
return nil, fmt.Errorf("查询分组对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var groupPinned int
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, nil
}
// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配)
func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) {
// 构建SQL查询,支持按标题和消息内容搜索
// 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复
query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
FROM conversations c
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
WHERE cgm.group_id = ?`
args := []interface{}{groupID}
// 如果有搜索关键词,添加标题和消息内容搜索条件
if searchQuery != "" {
searchPattern := "%" + searchQuery + "%"
// 搜索标题或消息内容
// 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题)
query += ` AND (
LOWER(c.title) LIKE LOWER(?)
OR EXISTS (
SELECT 1 FROM messages m
WHERE m.conversation_id = c.id
AND LOWER(m.content) LIKE LOWER(?)
)
)`
args = append(args, searchPattern, searchPattern)
}
query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC"
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("搜索分组对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var groupPinned int
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, nil
}
// GetGroupByConversation 获取对话所属的分组
func (db *DB) GetGroupByConversation(conversationID string) (string, error) {
var groupID string
err := db.QueryRow(
"SELECT group_id FROM conversation_group_mappings WHERE conversation_id = ? LIMIT 1",
conversationID,
).Scan(&groupID)
if err != nil {
if err == sql.ErrNoRows {
return "", nil // 没有分组
}
return "", fmt.Errorf("查询对话分组失败: %w", err)
}
return groupID, nil
}
// UpdateConversationPinned 更新对话置顶状态
func (db *DB) UpdateConversationPinned(id string, pinned bool) error {
pinnedValue := 0
if pinned {
pinnedValue = 1
}
// 注意:不更新 updated_at,因为置顶操作不应该改变对话的更新时间
_, err := db.Exec(
"UPDATE conversations SET pinned = ? WHERE id = ?",
pinnedValue, id,
)
if err != nil {
return fmt.Errorf("更新对话置顶状态失败: %w", err)
}
return nil
}
// UpdateGroupPinned 更新分组置顶状态
func (db *DB) UpdateGroupPinned(id string, pinned bool) error {
pinnedValue := 0
if pinned {
pinnedValue = 1
}
_, err := db.Exec(
"UPDATE conversation_groups SET pinned = ?, updated_at = ? WHERE id = ?",
pinnedValue, time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新分组置顶状态失败: %w", err)
}
return nil
}
// GroupMapping 分组映射关系
type GroupMapping struct {
ConversationID string `json:"conversationId"`
GroupID string `json:"groupId"`
}
// GetAllGroupMappings 批量获取所有分组映射(消除 N+1 查询)
func (db *DB) GetAllGroupMappings() ([]GroupMapping, error) {
rows, err := db.Query("SELECT conversation_id, group_id FROM conversation_group_mappings")
if err != nil {
return nil, fmt.Errorf("查询分组映射失败: %w", err)
}
defer rows.Close()
var mappings []GroupMapping
for rows.Next() {
var m GroupMapping
if err := rows.Scan(&m.ConversationID, &m.GroupID); err != nil {
return nil, fmt.Errorf("扫描分组映射失败: %w", err)
}
mappings = append(mappings, m)
}
if mappings == nil {
mappings = []GroupMapping{}
}
return mappings, nil
}
// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态
func (db *DB) UpdateConversationPinnedInGroup(conversationID, groupID string, pinned bool) error {
pinnedValue := 0
if pinned {
pinnedValue = 1
}
_, err := db.Exec(
"UPDATE conversation_group_mappings SET pinned = ? WHERE conversation_id = ? AND group_id = ?",
pinnedValue, conversationID, groupID,
)
if err != nil {
return fmt.Errorf("更新分组对话置顶状态失败: %w", err)
}
return nil
}
+617
View File
@@ -0,0 +1,617 @@
package database
import (
"database/sql"
"encoding/json"
"sort"
"strings"
"time"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// SaveToolExecution 保存工具执行记录
func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
argsJSON, err := json.Marshal(exec.Arguments)
if err != nil {
db.logger.Warn("序列化执行参数失败", zap.Error(err))
argsJSON = []byte("{}")
}
var resultJSON sql.NullString
if exec.Result != nil {
resultBytes, err := json.Marshal(exec.Result)
if err != nil {
db.logger.Warn("序列化执行结果失败", zap.Error(err))
} else {
resultJSON = sql.NullString{String: string(resultBytes), Valid: true}
}
}
var errorText sql.NullString
if exec.Error != "" {
errorText = sql.NullString{String: exec.Error, Valid: true}
}
var endTime sql.NullTime
if exec.EndTime != nil {
endTime = sql.NullTime{Time: *exec.EndTime, Valid: true}
}
var durationMs sql.NullInt64
if exec.Duration > 0 {
durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true}
}
query := `
INSERT OR REPLACE INTO tool_executions
(id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err = db.Exec(query,
exec.ID,
exec.ToolName,
string(argsJSON),
exec.Status,
resultJSON,
errorText,
exec.StartTime,
endTime,
durationMs,
time.Now(),
)
if err != nil {
db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID))
return err
}
return nil
}
// UpdateToolExecutionResult 仅更新结果字段(用于 reduction 后将监控展示与模型上下文对齐)。
func (db *DB) UpdateToolExecutionResult(id string, result *mcp.ToolResult) error {
id = strings.TrimSpace(id)
if id == "" || result == nil {
return nil
}
resultBytes, err := json.Marshal(result)
if err != nil {
return err
}
_, err = db.Exec(`UPDATE tool_executions SET result = ? WHERE id = ?`, string(resultBytes), id)
if err != nil {
db.logger.Warn("更新工具执行结果失败", zap.Error(err), zap.String("executionId", id))
}
return err
}
// CountToolExecutions 统计工具执行记录总数
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
query := `SELECT COUNT(*) FROM tool_executions`
args := []interface{}{}
conditions := []string{}
if status != "" {
conditions = append(conditions, "status = ?")
args = append(args, status)
}
if toolName != "" {
// 支持部分匹配(模糊搜索),不区分大小写
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
args = append(args, "%"+strings.ToLower(toolName)+"%")
}
if len(conditions) > 0 {
query += ` WHERE ` + conditions[0]
for i := 1; i < len(conditions); i++ {
query += ` AND ` + conditions[i]
}
}
var count int
err := db.QueryRow(query, args...).Scan(&count)
if err != nil {
return 0, err
}
return count, nil
}
// LoadToolExecutions 加载所有工具执行记录(支持分页)
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
return db.LoadToolExecutionsWithPagination(0, 1000, "", "")
}
// LoadToolExecutionsWithPagination 分页加载工具执行记录
// limit: 最大返回记录数,0 表示使用默认值 1000
// offset: 跳过的记录数,用于分页
// status: 状态筛选,空字符串表示不过滤
// toolName: 工具名称筛选,空字符串表示不过滤
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
if limit <= 0 {
limit = 1000 // 默认限制
}
if limit > 10000 {
limit = 10000 // 最大限制,防止一次性加载过多数据
}
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
`
args := []interface{}{}
conditions := []string{}
if status != "" {
conditions = append(conditions, "status = ?")
args = append(args, status)
}
if toolName != "" {
// 支持部分匹配(模糊搜索),不区分大小写
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
args = append(args, "%"+strings.ToLower(toolName)+"%")
}
if len(conditions) > 0 {
query += ` WHERE ` + conditions[0]
for i := 1; i < len(conditions); i++ {
query += ` AND ` + conditions[i]
}
}
query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var executions []*mcp.ToolExecution
for rows.Next() {
var exec mcp.ToolExecution
var argsJSON string
var resultJSON sql.NullString
var errorText sql.NullString
var endTime sql.NullTime
var durationMs sql.NullInt64
err := rows.Scan(
&exec.ID,
&exec.ToolName,
&argsJSON,
&exec.Status,
&resultJSON,
&errorText,
&exec.StartTime,
&endTime,
&durationMs,
)
if err != nil {
db.logger.Warn("加载执行记录失败", zap.Error(err))
continue
}
// 解析参数
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
db.logger.Warn("解析执行参数失败", zap.Error(err))
exec.Arguments = make(map[string]interface{})
}
// 解析结果
if resultJSON.Valid && resultJSON.String != "" {
var result mcp.ToolResult
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
db.logger.Warn("解析执行结果失败", zap.Error(err))
} else {
exec.Result = &result
}
}
// 设置错误
if errorText.Valid {
exec.Error = errorText.String
}
// 设置结束时间
if endTime.Valid {
exec.EndTime = &endTime.Time
}
// 设置持续时间
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
executions = append(executions, &exec)
}
return executions, nil
}
// GetToolExecution 根据ID获取单条工具执行记录
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
WHERE id = ?
`
row := db.QueryRow(query, id)
var exec mcp.ToolExecution
var argsJSON string
var resultJSON sql.NullString
var errorText sql.NullString
var endTime sql.NullTime
var durationMs sql.NullInt64
err := row.Scan(
&exec.ID,
&exec.ToolName,
&argsJSON,
&exec.Status,
&resultJSON,
&errorText,
&exec.StartTime,
&endTime,
&durationMs,
)
if err != nil {
return nil, err
}
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
db.logger.Warn("解析执行参数失败", zap.Error(err))
exec.Arguments = make(map[string]interface{})
}
if resultJSON.Valid && resultJSON.String != "" {
var result mcp.ToolResult
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
db.logger.Warn("解析执行结果失败", zap.Error(err))
} else {
exec.Result = &result
}
}
if errorText.Valid {
exec.Error = errorText.String
}
if endTime.Valid {
exec.EndTime = &endTime.Time
}
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
return &exec, nil
}
// DeleteToolExecution 删除工具执行记录
func (db *DB) DeleteToolExecution(id string) error {
query := `DELETE FROM tool_executions WHERE id = ?`
_, err := db.Exec(query, id)
if err != nil {
db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id))
return err
}
return nil
}
// DeleteToolExecutions 批量删除工具执行记录
func (db *DB) DeleteToolExecutions(ids []string) error {
if len(ids) == 0 {
return nil
}
// 构建 IN 查询的占位符
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
}
query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)`
_, err := db.Exec(query, args...)
if err != nil {
db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids)))
return err
}
return nil
}
// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息)
func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) {
if len(ids) == 0 {
return []*mcp.ToolExecution{}, nil
}
// 构建 IN 查询的占位符
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
}
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
WHERE id IN (` + strings.Join(placeholders, ",") + `)
`
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var executions []*mcp.ToolExecution
for rows.Next() {
var exec mcp.ToolExecution
var argsJSON string
var resultJSON sql.NullString
var errorText sql.NullString
var endTime sql.NullTime
var durationMs sql.NullInt64
err := rows.Scan(
&exec.ID,
&exec.ToolName,
&argsJSON,
&exec.Status,
&resultJSON,
&errorText,
&exec.StartTime,
&endTime,
&durationMs,
)
if err != nil {
db.logger.Warn("加载执行记录失败", zap.Error(err))
continue
}
// 解析参数
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
db.logger.Warn("解析执行参数失败", zap.Error(err))
exec.Arguments = make(map[string]interface{})
}
// 解析结果
if resultJSON.Valid && resultJSON.String != "" {
var result mcp.ToolResult
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
db.logger.Warn("解析执行结果失败", zap.Error(err))
} else {
exec.Result = &result
}
}
// 设置错误
if errorText.Valid {
exec.Error = errorText.String
}
// 设置结束时间
if endTime.Valid {
exec.EndTime = &endTime.Time
}
// 设置持续时间
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
executions = append(executions, &exec)
}
return executions, nil
}
// SaveToolStats 保存工具统计信息
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
var lastCallTime sql.NullTime
if stats.LastCallTime != nil {
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
}
query := `
INSERT OR REPLACE INTO tool_stats
(tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
toolName,
stats.TotalCalls,
stats.SuccessCalls,
stats.FailedCalls,
lastCallTime,
time.Now(),
)
if err != nil {
db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
return err
}
return nil
}
// LoadToolStats 加载所有工具统计信息
func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) {
query := `
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
FROM tool_stats
`
rows, err := db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
stats := make(map[string]*mcp.ToolStats)
for rows.Next() {
var stat mcp.ToolStats
var lastCallTime sql.NullTime
err := rows.Scan(
&stat.ToolName,
&stat.TotalCalls,
&stat.SuccessCalls,
&stat.FailedCalls,
&lastCallTime,
)
if err != nil {
db.logger.Warn("加载统计信息失败", zap.Error(err))
continue
}
if lastCallTime.Valid {
stat.LastCallTime = &lastCallTime.Time
}
stats[stat.ToolName] = &stat
}
return stats, nil
}
// UpdateToolStats 更新工具统计信息(累加模式)
func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
var lastCallTimeSQL sql.NullTime
if lastCallTime != nil {
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
}
query := `
INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(tool_name) DO UPDATE SET
total_calls = total_calls + ?,
success_calls = success_calls + ?,
failed_calls = failed_calls + ?,
last_call_time = COALESCE(?, last_call_time),
updated_at = ?
`
_, err := db.Exec(query,
toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
)
if err != nil {
db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
return err
}
return nil
}
// CallsTimelineBucket 调用趋势时间桶
type CallsTimelineBucket struct {
BucketTime time.Time
Total int
Failed int
}
// truncateCallsTimelineBucket 将时间截断到趋势图桶边界(本地时区,与 handler 侧 truncateToBucket 一致)
func truncateCallsTimelineBucket(t time.Time, dailyBuckets bool) time.Time {
t = t.In(time.Local)
if dailyBuckets {
y, m, d := t.Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.Local)
}
return t.Truncate(time.Hour)
}
// LoadCallsTimeline 按时间范围加载调用趋势(since 起至今,含边界)
func (db *DB) LoadCallsTimeline(since time.Time, dailyBuckets bool) ([]CallsTimelineBucket, error) {
// 在 Go 侧按本地时区分桶,避免 SQLite strftime 对 UTC 存储时间分桶后再误当本地时间解析(差 8h 等问题)
query := `
SELECT start_time,
CASE WHEN status IN ('failed', 'cancelled') THEN 1 ELSE 0 END AS failed
FROM tool_executions
WHERE start_time >= ?
`
rows, err := db.Query(query, since)
if err != nil {
return nil, err
}
defer rows.Close()
bucketMap := make(map[time.Time]struct{ total, failed int })
for rows.Next() {
var startTime time.Time
var failed int
if err := rows.Scan(&startTime, &failed); err != nil {
db.logger.Warn("加载调用趋势失败", zap.Error(err))
continue
}
key := truncateCallsTimelineBucket(startTime, dailyBuckets)
entry := bucketMap[key]
entry.total++
entry.failed += failed
bucketMap[key] = entry
}
buckets := make([]CallsTimelineBucket, 0, len(bucketMap))
for bucketTime, counts := range bucketMap {
buckets = append(buckets, CallsTimelineBucket{
BucketTime: bucketTime,
Total: counts.total,
Failed: counts.failed,
})
}
sort.Slice(buckets, func(i, j int) bool {
return buckets[i].BucketTime.Before(buckets[j].BucketTime)
})
return buckets, nil
}
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
// 如果统计信息变为0,则删除该统计记录
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
// 先更新统计信息
query := `
UPDATE tool_stats SET
total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END,
success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END,
failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END,
updated_at = ?
WHERE tool_name = ?
`
_, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName)
if err != nil {
db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
return err
}
// 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录
checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?`
var newTotalCalls int
err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls)
if err != nil {
// 如果查询失败(记录不存在),直接返回
return nil
}
// 如果 total_calls 为 0,删除该统计记录
if newTotalCalls == 0 {
deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?`
_, err = db.Exec(deleteQuery, toolName)
if err != nil {
db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName))
// 不返回错误,因为主要操作(更新统计)已成功
} else {
db.logger.Info("已删除零统计记录", zap.String("toolName", toolName))
}
}
return nil
}
@@ -0,0 +1,28 @@
package database
import (
"fmt"
"strings"
)
// DedupeConsecutiveProcessDetails 去掉相邻且语义相同的过程详情(使用 DB 中 data 列原始 JSON 作指纹,避免 map 序列化键序不稳定)。
func DedupeConsecutiveProcessDetails(rows []ProcessDetail) []ProcessDetail {
if len(rows) < 2 {
return rows
}
out := make([]ProcessDetail, 0, len(rows))
var lastKey string
for _, d := range rows {
key := processDetailRowKey(d)
if len(out) > 0 && key != "" && key == lastKey {
continue
}
out = append(out, d)
lastKey = key
}
return out
}
func processDetailRowKey(d ProcessDetail) string {
return fmt.Sprintf("%s\x00%s\x00%s", d.EventType, strings.TrimSpace(d.Message), d.Data)
}
+528
View File
@@ -0,0 +1,528 @@
package database
import (
"database/sql"
"fmt"
"regexp"
"strings"
"time"
"github.com/google/uuid"
)
var factKeyPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._/-]*$`)
// ValidateFactKey 校验事实 key(项目内唯一标识)。
func ValidateFactKey(key string) error {
key = strings.TrimSpace(key)
if key == "" {
return fmt.Errorf("fact_key 不能为空")
}
if len(key) > 128 {
return fmt.Errorf("fact_key 过长(最多 128 字符)")
}
if !factKeyPattern.MatchString(key) {
return fmt.Errorf("fact_key 格式无效,仅允许小写字母、数字及 . _ / -,且须以小写字母或数字开头")
}
return nil
}
// Project 渗透测试项目(跨对话共享黑板)。
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
ScopeJSON string `json:"scope_json,omitempty"`
Status string `json:"status"` // active | archived
Pinned bool `json:"pinned"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ProjectFact 项目事实(黑板条目)。
type ProjectFact struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
FactKey string `json:"fact_key"`
Category string `json:"category"`
Summary string `json:"summary"`
Body string `json:"body"`
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
SourceConversationID string `json:"source_conversation_id,omitempty"`
SourceMessageID string `json:"source_message_id,omitempty"`
Pinned bool `json:"pinned"`
RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ProjectFactListFilter 事实列表筛选。
type ProjectFactListFilter struct {
Category string
Confidence string
Search string
RelatedVulnerabilityID string
ExcludeDeprecated bool // 为 true 时排除 confidence=deprecated
}
// CreateProject 创建项目。
func (db *DB) CreateProject(p *Project) (*Project, error) {
if p.ID == "" {
p.ID = uuid.New().String()
}
if strings.TrimSpace(p.Status) == "" {
p.Status = "active"
}
now := time.Now()
if p.CreatedAt.IsZero() {
p.CreatedAt = now
}
p.UpdatedAt = now
_, err := db.Exec(
`INSERT INTO projects (id, name, description, scope_json, status, pinned, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
p.ID, p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.CreatedAt, p.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("创建项目失败: %w", err)
}
return p, nil
}
// GetProject 获取项目。
func (db *DB) GetProject(id string) (*Project, error) {
var p Project
var pinned int
var createdAt, updatedAt string
err := db.QueryRow(
`SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
FROM projects WHERE id = ?`, id,
).Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("项目不存在")
}
return nil, fmt.Errorf("获取项目失败: %w", err)
}
p.Pinned = pinned != 0
p.CreatedAt = parseDBTime(createdAt)
p.UpdatedAt = parseDBTime(updatedAt)
return &p, nil
}
// CountProjects 统计项目数量。
func (db *DB) CountProjects(status, search string) (int, error) {
query := `SELECT COUNT(*) FROM projects WHERE 1=1`
args := []interface{}{}
if s := strings.TrimSpace(status); s != "" {
query += " AND status = ?"
args = append(args, s)
}
if q := strings.TrimSpace(search); q != "" {
pattern := "%" + q + "%"
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
args = append(args, pattern, pattern)
}
var count int
if err := db.QueryRow(query, args...).Scan(&count); err != nil {
return 0, fmt.Errorf("统计项目失败: %w", err)
}
return count, nil
}
// ListProjects 列出项目。
func (db *DB) ListProjects(status, search string, limit, offset int) ([]*Project, error) {
if limit <= 0 {
limit = 50
}
query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at
FROM projects WHERE 1=1`
args := []interface{}{}
if s := strings.TrimSpace(status); s != "" {
query += " AND status = ?"
args = append(args, s)
}
if q := strings.TrimSpace(search); q != "" {
pattern := "%" + q + "%"
query += " AND (name LIKE ? OR COALESCE(description,'') LIKE ?)"
args = append(args, pattern, pattern)
}
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("列出项目失败: %w", err)
}
defer rows.Close()
var out []*Project
for rows.Next() {
var p Project
var pinned int
var createdAt, updatedAt string
if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt); err != nil {
return nil, err
}
p.Pinned = pinned != 0
p.CreatedAt = parseDBTime(createdAt)
p.UpdatedAt = parseDBTime(updatedAt)
out = append(out, &p)
}
return out, rows.Err()
}
// UpdateProject 更新项目。
func (db *DB) UpdateProject(p *Project) error {
p.UpdatedAt = time.Now()
_, err := db.Exec(
`UPDATE projects SET name = ?, description = ?, scope_json = ?, status = ?, pinned = ?, updated_at = ? WHERE id = ?`,
p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.UpdatedAt, p.ID,
)
if err != nil {
return fmt.Errorf("更新项目失败: %w", err)
}
return nil
}
// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理;漏洞 project_id 置空)。
func (db *DB) DeleteProject(id string) error {
if _, err := db.Exec(`UPDATE vulnerabilities SET project_id = NULL WHERE project_id = ?`, id); err != nil {
return fmt.Errorf("解除漏洞项目关联失败: %w", err)
}
_, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id)
if err != nil {
return fmt.Errorf("删除项目失败: %w", err)
}
return nil
}
// GetConversationProjectID 返回对话绑定的项目 ID。
func (db *DB) GetConversationProjectID(conversationID string) (string, error) {
var pid sql.NullString
err := db.QueryRow(`SELECT project_id FROM conversations WHERE id = ?`, conversationID).Scan(&pid)
if err != nil {
if err == sql.ErrNoRows {
return "", fmt.Errorf("对话不存在")
}
return "", err
}
if pid.Valid {
return strings.TrimSpace(pid.String), nil
}
return "", nil
}
// SetConversationProjectID 设置对话所属项目(空字符串表示解除绑定)。
func (db *DB) SetConversationProjectID(conversationID, projectID string) error {
projectID = strings.TrimSpace(projectID)
if projectID != "" {
if _, err := db.GetProject(projectID); err != nil {
return err
}
}
var val interface{}
if projectID == "" {
val = nil
} else {
val = projectID
}
_, err := db.Exec(`UPDATE conversations SET project_id = ?, updated_at = ? WHERE id = ?`, val, time.Now(), conversationID)
if err != nil {
return fmt.Errorf("设置对话项目失败: %w", err)
}
return nil
}
// ListProjectFactsForIndex 列出用于黑板索引注入的事实(不含 deprecated,除非 includeDeprecated)。
func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) {
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ?`
args := []interface{}{projectID}
if !includeDeprecated {
query += " AND confidence != 'deprecated'"
}
query += " ORDER BY pinned DESC, updated_at DESC"
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanProjectFacts(rows)
}
// ListProjectFacts 分页列出项目事实。
func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, limit, offset int) ([]*ProjectFact, error) {
if limit <= 0 {
limit = 100
}
query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ?`
args := []interface{}{projectID}
if c := strings.TrimSpace(filter.Category); c != "" {
query += " AND category = ?"
args = append(args, c)
}
if c := strings.TrimSpace(filter.Confidence); c != "" {
query += " AND confidence = ?"
args = append(args, c)
}
if filter.ExcludeDeprecated {
query += " AND confidence != 'deprecated'"
}
if rid := strings.TrimSpace(filter.RelatedVulnerabilityID); rid != "" {
query += " AND related_vulnerability_id = ?"
args = append(args, rid)
}
if s := strings.TrimSpace(filter.Search); s != "" {
pat := "%" + s + "%"
query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)"
args = append(args, pat, pat, pat)
}
query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanProjectFacts(rows)
}
// GetProjectFactByKey 按 key 获取事实。
func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, error) {
row := db.QueryRow(
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE project_id = ? AND fact_key = ?`,
projectID, factKey,
)
return scanProjectFactRow(row)
}
// GetProjectFact 按 ID 获取事实。
func (db *DB) GetProjectFact(id string) (*ProjectFact, error) {
row := db.QueryRow(
`SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence,
COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned,
COALESCE(related_vulnerability_id,''), created_at, updated_at
FROM project_facts WHERE id = ?`, id,
)
return scanProjectFactRow(row)
}
// mergeFactBodyOnUpdate 更新时若 incoming body 为空则保留已有内容,避免仅改 summary 时丢失攻击链。
func mergeFactBodyOnUpdate(incoming, existing string) string {
if strings.TrimSpace(incoming) == "" {
return existing
}
return incoming
}
// UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。
func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
if err := ValidateFactKey(f.FactKey); err != nil {
return nil, err
}
if strings.TrimSpace(f.Category) == "" {
f.Category = "note"
}
if strings.TrimSpace(f.Confidence) == "" {
f.Confidence = "tentative"
}
now := time.Now()
existing, err := db.GetProjectFactByKey(f.ProjectID, f.FactKey)
if err == nil && existing != nil {
f.ID = existing.ID
f.CreatedAt = existing.CreatedAt
f.UpdatedAt = now
f.Body = mergeFactBodyOnUpdate(f.Body, existing.Body)
if strings.TrimSpace(f.Category) == "" {
f.Category = existing.Category
}
if strings.TrimSpace(f.Confidence) == "" {
f.Confidence = existing.Confidence
}
_, err = db.Exec(
`UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?,
source_conversation_id = COALESCE(?, source_conversation_id),
source_message_id = COALESCE(?, source_message_id),
pinned = ?, related_vulnerability_id = ?, updated_at = ?
WHERE id = ?`,
f.Category, f.Summary, f.Body, f.Confidence,
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID,
)
if err != nil {
return nil, fmt.Errorf("更新事实失败: %w", err)
}
return f, nil
}
if f.ID == "" {
f.ID = uuid.New().String()
}
f.CreatedAt = now
f.UpdatedAt = now
_, err = db.Exec(
`INSERT INTO project_facts (
id, project_id, fact_key, category, summary, body, confidence,
source_conversation_id, source_message_id, pinned, related_vulnerability_id,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence,
nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned),
nullIfEmpty(f.RelatedVulnerabilityID),
f.CreatedAt, f.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("创建事实失败: %w", err)
}
return f, nil
}
// DeprecateProjectFact 将事实标记为 deprecated。
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
res, err := db.Exec(
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
time.Now(), projectID, factKey,
)
if err != nil {
return err
}
n, _ := res.RowsAffected()
if n == 0 {
return fmt.Errorf("事实不存在")
}
return nil
}
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
confidence = strings.TrimSpace(strings.ToLower(confidence))
if confidence == "" {
confidence = "tentative"
}
if confidence != "confirmed" && confidence != "tentative" {
return fmt.Errorf("confidence 须为 confirmed 或 tentative")
}
existing, err := db.GetProjectFactByKey(projectID, factKey)
if err != nil {
return fmt.Errorf("事实不存在")
}
if strings.ToLower(strings.TrimSpace(existing.Confidence)) != "deprecated" {
return fmt.Errorf("事实未处于废弃状态")
}
_, err = db.Exec(
`UPDATE project_facts SET confidence = ?, updated_at = ? WHERE project_id = ? AND fact_key = ?`,
confidence, time.Now(), projectID, factKey,
)
return err
}
// DeleteProjectFact 删除事实。
func (db *DB) DeleteProjectFact(id string) error {
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
return err
}
func scanProjectFacts(rows *sql.Rows) ([]*ProjectFact, error) {
var out []*ProjectFact
for rows.Next() {
f, err := scanProjectFactFromRows(rows)
if err != nil {
return nil, err
}
out = append(out, f)
}
return out, rows.Err()
}
func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) {
var f ProjectFact
var pinned int
var createdAt, updatedAt string
err := row.Scan(
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
&f.SourceConversationID, &f.SourceMessageID, &pinned,
&f.RelatedVulnerabilityID, &createdAt, &updatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("事实不存在")
}
return nil, err
}
f.Pinned = pinned != 0
f.CreatedAt = parseDBTime(createdAt)
f.UpdatedAt = parseDBTime(updatedAt)
return &f, nil
}
func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) {
var f ProjectFact
var pinned int
var createdAt, updatedAt string
err := rows.Scan(
&f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence,
&f.SourceConversationID, &f.SourceMessageID, &pinned,
&f.RelatedVulnerabilityID, &createdAt, &updatedAt,
)
if err != nil {
return nil, err
}
f.Pinned = pinned != 0
f.CreatedAt = parseDBTime(createdAt)
f.UpdatedAt = parseDBTime(updatedAt)
return &f, nil
}
func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
func nullIfEmpty(s string) interface{} {
if strings.TrimSpace(s) == "" {
return nil
}
return s
}
func parseDBTime(s string) time.Time {
s = strings.TrimSpace(s)
if s == "" {
return time.Time{}
}
// go-sqlite3 读 DATETIME 常返回 RFC3339(含 T),写入时可能是空格分隔格式,需兼容多种形态
layouts := []string{
time.RFC3339Nano,
time.RFC3339,
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05-07:00",
"2006-01-02T15:04:05.999999999-07:00",
"2006-01-02T15:04:05-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05.999999999",
"2006-01-02T15:04:05",
}
for _, layout := range layouts {
if t, e := time.Parse(layout, s); e == nil {
return t
}
}
return time.Time{}
}
+91
View File
@@ -0,0 +1,91 @@
package database
import (
"fmt"
"strings"
"time"
)
// ProjectDashboardFact 仪表盘跨项目近期事实条目。
type ProjectDashboardFact struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
ProjectName string `json:"project_name"`
FactKey string `json:"fact_key"`
Category string `json:"category"`
Summary string `json:"summary"`
Confidence string `json:"confidence"`
Pinned bool `json:"pinned"`
UpdatedAt time.Time `json:"updated_at"`
}
// ProjectDashboardTotals 仪表盘项目事实汇总计数。
type ProjectDashboardTotals struct {
ActiveProjects int `json:"active_projects"`
TotalFacts int `json:"total_facts"`
}
// ProjectDashboardSummary 仪表盘项目情报摘要。
type ProjectDashboardSummary struct {
RecentFacts []ProjectDashboardFact `json:"recent_facts"`
Totals ProjectDashboardTotals `json:"totals"`
}
// GetProjectDashboardSummary 聚合跨项目近期事实(仅活跃项目、排除 deprecated)。
func (db *DB) GetProjectDashboardSummary(factLimit int) (*ProjectDashboardSummary, error) {
if factLimit <= 0 {
factLimit = 5
}
if factLimit > 50 {
factLimit = 50
}
out := &ProjectDashboardSummary{
RecentFacts: []ProjectDashboardFact{},
}
if err := db.QueryRow(`SELECT COUNT(*) FROM projects WHERE status = 'active'`).Scan(&out.Totals.ActiveProjects); err != nil {
return nil, fmt.Errorf("统计活跃项目失败: %w", err)
}
if err := db.QueryRow(
`SELECT COUNT(*) FROM project_facts f
INNER JOIN projects p ON p.id = f.project_id
WHERE f.confidence != 'deprecated' AND p.status = 'active'`,
).Scan(&out.Totals.TotalFacts); err != nil {
return nil, fmt.Errorf("统计事实失败: %w", err)
}
rows, err := db.Query(
`SELECT f.id, f.project_id, p.name, f.fact_key, f.category, f.summary, f.confidence, f.pinned, f.updated_at
FROM project_facts f
INNER JOIN projects p ON p.id = f.project_id
WHERE f.confidence != 'deprecated' AND p.status = 'active'
ORDER BY f.pinned DESC, f.updated_at DESC
LIMIT ?`,
factLimit,
)
if err != nil {
return nil, fmt.Errorf("查询近期事实失败: %w", err)
}
defer rows.Close()
for rows.Next() {
var item ProjectDashboardFact
var pinned int
var updatedAt string
if err := rows.Scan(
&item.ID, &item.ProjectID, &item.ProjectName, &item.FactKey,
&item.Category, &item.Summary, &item.Confidence, &pinned, &updatedAt,
); err != nil {
return nil, err
}
item.Pinned = pinned != 0
item.ProjectName = strings.TrimSpace(item.ProjectName)
item.UpdatedAt = parseDBTime(updatedAt)
out.RecentFacts = append(out.RecentFacts, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
@@ -0,0 +1,148 @@
package database
import (
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestUpsertProjectFact_preservesBodyOnEmptyUpdate(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "facts.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
proj, err := db.CreateProject(&Project{Name: "test-facts"})
if err != nil {
t.Fatal(err)
}
const body = "## 攻击链\n1. step\n```http\nGET / HTTP/1.1\n```\n"
_, err = db.UpsertProjectFact(&ProjectFact{
ProjectID: proj.ID,
FactKey: "finding/sqli-login",
Category: "finding",
Summary: "SQLi on /login",
Body: body,
})
if err != nil {
t.Fatal(err)
}
updated, err := db.UpsertProjectFact(&ProjectFact{
ProjectID: proj.ID,
FactKey: "finding/sqli-login",
Summary: "SQLi on /login (confirmed)",
Body: "",
})
if err != nil {
t.Fatal(err)
}
if updated.Summary != "SQLi on /login (confirmed)" {
t.Fatalf("summary=%q", updated.Summary)
}
if updated.Body != body {
t.Fatalf("returned body=%q want preserved attack chain", updated.Body)
}
fromDB, err := db.GetProjectFactByKey(proj.ID, "finding/sqli-login")
if err != nil {
t.Fatal(err)
}
if fromDB.Body != body {
t.Fatalf("stored body=%q want preserved", fromDB.Body)
}
}
func TestUpsertProjectFact_replacesBodyWhenProvided(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "facts.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
proj, err := db.CreateProject(&Project{Name: "test-facts"})
if err != nil {
t.Fatal(err)
}
_, err = db.UpsertProjectFact(&ProjectFact{
ProjectID: proj.ID,
FactKey: "target/primary",
Summary: "v1",
Body: "old body",
})
if err != nil {
t.Fatal(err)
}
const newBody = "new body with evidence"
updated, err := db.UpsertProjectFact(&ProjectFact{
ProjectID: proj.ID,
FactKey: "target/primary",
Summary: "v2",
Body: newBody,
})
if err != nil {
t.Fatal(err)
}
if updated.Body != newBody {
t.Fatalf("body=%q want %q", updated.Body, newBody)
}
}
func TestRestoreProjectFact(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "facts.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
defer db.Close()
proj, err := db.CreateProject(&Project{Name: "restore-test"})
if err != nil {
t.Fatal(err)
}
key := "target/restore-me"
_, err = db.UpsertProjectFact(&ProjectFact{
ProjectID: proj.ID,
FactKey: key,
Summary: "s",
Confidence: "confirmed",
})
if err != nil {
t.Fatal(err)
}
if err := db.DeprecateProjectFact(proj.ID, key); err != nil {
t.Fatal(err)
}
if err := db.RestoreProjectFact(proj.ID, key, "confirmed"); err != nil {
t.Fatal(err)
}
f, err := db.GetProjectFactByKey(proj.ID, key)
if err != nil {
t.Fatal(err)
}
if f.Confidence != "confirmed" {
t.Fatalf("confidence=%q want confirmed", f.Confidence)
}
if err := db.RestoreProjectFact(proj.ID, key, ""); err == nil {
t.Fatal("expected error when not deprecated")
}
}
func TestMergeFactBodyOnUpdate(t *testing.T) {
if got := mergeFactBodyOnUpdate("", "keep"); got != "keep" {
t.Fatalf("empty incoming: got %q", got)
}
if got := mergeFactBodyOnUpdate(" ", "keep"); got != "keep" {
t.Fatalf("whitespace incoming: got %q", got)
}
if got := mergeFactBodyOnUpdate("new", "old"); got != "new" {
t.Fatalf("non-empty incoming: got %q", got)
}
}
+121
View File
@@ -0,0 +1,121 @@
package database
import (
"database/sql"
"fmt"
"strings"
)
// ProjectStats 项目聚合统计。
type ProjectStats struct {
FactCount int `json:"fact_count"`
VulnCount int `json:"vuln_count"`
ConversationCount int `json:"conversation_count"`
SparseFactCount int `json:"sparse_fact_count"`
}
// GetProjectStatsCounts 统计项目下事实、漏洞、对话数量(不含 sparse,由 project 包补全)。
func (db *DB) GetProjectStatsCounts(projectID string) (*ProjectStats, error) {
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return nil, fmt.Errorf("project_id 不能为空")
}
if _, err := db.GetProject(projectID); err != nil {
return nil, err
}
stats := &ProjectStats{}
if err := db.QueryRow(
`SELECT COUNT(*) FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
projectID,
).Scan(&stats.FactCount); err != nil {
return nil, fmt.Errorf("统计事实失败: %w", err)
}
if err := db.QueryRow(
`SELECT COUNT(*) FROM vulnerabilities WHERE project_id = ?`,
projectID,
).Scan(&stats.VulnCount); err != nil {
return nil, fmt.Errorf("统计漏洞失败: %w", err)
}
if err := db.QueryRow(
`SELECT COUNT(*) FROM conversations WHERE project_id = ?`,
projectID,
).Scan(&stats.ConversationCount); err != nil {
return nil, fmt.Errorf("统计对话失败: %w", err)
}
return stats, nil
}
// ListProjectFactsForSparseCheck 返回用于待补全检测的事实字段(非 deprecated)。
func (db *DB) ListProjectFactsForSparseCheck(projectID string) ([]struct {
Category string
FactKey string
Body string
}, error) {
rows, err := db.Query(
`SELECT category, fact_key, COALESCE(body,'') FROM project_facts WHERE project_id = ? AND confidence != 'deprecated'`,
projectID,
)
if err != nil {
return nil, err
}
defer rows.Close()
var out []struct {
Category string
FactKey string
Body string
}
for rows.Next() {
var row struct {
Category string
FactKey string
Body string
}
if err := rows.Scan(&row.Category, &row.FactKey, &row.Body); err != nil {
return nil, err
}
out = append(out, row)
}
return out, rows.Err()
}
// ListConversationsByProjectID 列出绑定到项目的对话。
func (db *DB) ListConversationsByProjectID(projectID string, limit, offset int) ([]*Conversation, error) {
if limit <= 0 {
limit = 100
}
rows, err := db.Query(
`SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id
FROM conversations WHERE project_id = ? ORDER BY updated_at DESC LIMIT ? OFFSET ?`,
projectID, limit, offset,
)
if err != nil {
return nil, fmt.Errorf("查询项目对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var pid sql.NullString
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &pid); err != nil {
return nil, err
}
if pid.Valid {
conv.ProjectID = strings.TrimSpace(pid.String)
}
conv.CreatedAt = parseDBTime(createdAt)
conv.UpdatedAt = parseDBTime(updatedAt)
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, rows.Err()
}
// CountConversationsByProjectID 统计项目绑定对话数。
func (db *DB) CountConversationsByProjectID(projectID string) (int, error) {
var n int
err := db.QueryRow(`SELECT COUNT(*) FROM conversations WHERE project_id = ?`, projectID).Scan(&n)
return n, err
}
+93
View File
@@ -0,0 +1,93 @@
package database
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"go.uber.org/zap"
)
func TestParseDBTime_projectFactFormats(t *testing.T) {
cases := []string{
"2026-05-26 11:13:07.442143+08:00",
"2026-05-26 11:13:07",
"2026-05-26T11:13:07.442143+08:00",
}
for _, s := range cases {
got := parseDBTime(s)
if got.IsZero() {
t.Fatalf("parseDBTime(%q) returned zero", s)
}
}
}
func TestListProjectFacts_updatedAtJSON(t *testing.T) {
root, err := os.Getwd()
if err != nil {
t.Skip(err)
}
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
if _, err := os.Stat(dbPath); err != nil {
t.Skip("conversations.db not found")
}
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
projects, err := db.ListProjects("", "", 1, 0)
if err != nil || len(projects) == 0 {
t.Skip("no projects")
}
pid := projects[0].ID
list, err := db.ListProjectFacts(pid, ProjectFactListFilter{}, 5, 0)
if err != nil {
t.Fatal(err)
}
if len(list) == 0 {
t.Skip("no facts")
}
for _, f := range list {
if f.UpdatedAt.IsZero() {
t.Fatalf("fact %s UpdatedAt is zero after ListProjectFacts", f.FactKey)
}
b, err := json.Marshal(f)
if err != nil {
t.Fatal(err)
}
var m map[string]interface{}
if err := json.Unmarshal(b, &m); err != nil {
t.Fatal(err)
}
raw, ok := m["updated_at"].(string)
if !ok || raw == "" || raw[:4] == "0001" {
t.Fatalf("bad updated_at in JSON: %v", m["updated_at"])
}
}
}
func TestParseDBTime_zeroOnGarbage(t *testing.T) {
if !parseDBTime("").IsZero() {
t.Fatal("expected zero for empty")
}
}
// Ensure RFC3339 round-trip used by API is after year 2000.
func TestParseDBTime_marshalRoundTrip(t *testing.T) {
s := "2026-05-26 11:13:07.442143+08:00"
tm := parseDBTime(s)
b, err := json.Marshal(tm)
if err != nil {
t.Fatal(err)
}
var back time.Time
if err := json.Unmarshal(b, &back); err != nil {
t.Fatal(err)
}
if back.IsZero() {
t.Fatalf("unmarshal zero from %s", string(b))
}
}
+84
View File
@@ -0,0 +1,84 @@
package database
import (
"database/sql"
"fmt"
"strings"
"time"
)
// RobotSessionBinding 机器人会话绑定信息。
type RobotSessionBinding struct {
SessionKey string
ConversationID string
RoleName string
UpdatedAt time.Time
}
// GetRobotSessionBinding 按 session_key 获取机器人会话绑定。
func (db *DB) GetRobotSessionBinding(sessionKey string) (*RobotSessionBinding, error) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return nil, nil
}
var b RobotSessionBinding
var updatedAt string
err := db.QueryRow(
"SELECT session_key, conversation_id, role_name, updated_at FROM robot_user_sessions WHERE session_key = ?",
sessionKey,
).Scan(&b.SessionKey, &b.ConversationID, &b.RoleName, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询机器人会话绑定失败: %w", err)
}
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil {
b.UpdatedAt = t
} else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil {
b.UpdatedAt = t
} else {
b.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
if strings.TrimSpace(b.RoleName) == "" {
b.RoleName = "默认"
}
return &b, nil
}
// UpsertRobotSessionBinding 写入或更新机器人会话绑定(包含角色)。
func (db *DB) UpsertRobotSessionBinding(sessionKey, conversationID, roleName string) error {
sessionKey = strings.TrimSpace(sessionKey)
conversationID = strings.TrimSpace(conversationID)
roleName = strings.TrimSpace(roleName)
if sessionKey == "" || conversationID == "" {
return nil
}
if roleName == "" {
roleName = "默认"
}
_, err := db.Exec(`
INSERT INTO robot_user_sessions (session_key, conversation_id, role_name, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(session_key) DO UPDATE SET
conversation_id = excluded.conversation_id,
role_name = excluded.role_name,
updated_at = excluded.updated_at
`, sessionKey, conversationID, roleName, time.Now())
if err != nil {
return fmt.Errorf("写入机器人会话绑定失败: %w", err)
}
return nil
}
// DeleteRobotSessionBinding 删除机器人会话绑定。
func (db *DB) DeleteRobotSessionBinding(sessionKey string) error {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return nil
}
if _, err := db.Exec("DELETE FROM robot_user_sessions WHERE session_key = ?", sessionKey); err != nil {
return fmt.Errorf("删除机器人会话绑定失败: %w", err)
}
return nil
}
+142
View File
@@ -0,0 +1,142 @@
package database
import (
"database/sql"
"time"
"go.uber.org/zap"
)
// SkillStats Skills统计信息
type SkillStats struct {
SkillName string
TotalCalls int
SuccessCalls int
FailedCalls int
LastCallTime *time.Time
}
// SaveSkillStats 保存Skills统计信息
func (db *DB) SaveSkillStats(skillName string, stats *SkillStats) error {
var lastCallTime sql.NullTime
if stats.LastCallTime != nil {
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
}
query := `
INSERT OR REPLACE INTO skill_stats
(skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
skillName,
stats.TotalCalls,
stats.SuccessCalls,
stats.FailedCalls,
lastCallTime,
time.Now(),
)
if err != nil {
db.logger.Error("保存Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
return nil
}
// LoadSkillStats 加载所有Skills统计信息
func (db *DB) LoadSkillStats() (map[string]*SkillStats, error) {
query := `
SELECT skill_name, total_calls, success_calls, failed_calls, last_call_time
FROM skill_stats
`
rows, err := db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
stats := make(map[string]*SkillStats)
for rows.Next() {
var stat SkillStats
var lastCallTime sql.NullTime
err := rows.Scan(
&stat.SkillName,
&stat.TotalCalls,
&stat.SuccessCalls,
&stat.FailedCalls,
&lastCallTime,
)
if err != nil {
db.logger.Warn("加载Skills统计信息失败", zap.Error(err))
continue
}
if lastCallTime.Valid {
stat.LastCallTime = &lastCallTime.Time
}
stats[stat.SkillName] = &stat
}
return stats, nil
}
// UpdateSkillStats 更新Skills统计信息(累加模式)
func (db *DB) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
var lastCallTimeSQL sql.NullTime
if lastCallTime != nil {
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
}
query := `
INSERT INTO skill_stats (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(skill_name) DO UPDATE SET
total_calls = total_calls + ?,
success_calls = success_calls + ?,
failed_calls = failed_calls + ?,
last_call_time = COALESCE(?, last_call_time),
updated_at = ?
`
_, err := db.Exec(query,
skillName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
)
if err != nil {
db.logger.Error("更新Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
return nil
}
// ClearSkillStats 清空所有Skills统计信息
func (db *DB) ClearSkillStats() error {
query := `DELETE FROM skill_stats`
_, err := db.Exec(query)
if err != nil {
db.logger.Error("清空Skills统计信息失败", zap.Error(err))
return err
}
db.logger.Info("已清空所有Skills统计信息")
return nil
}
// ClearSkillStatsByName 清空指定skill的统计信息
func (db *DB) ClearSkillStatsByName(skillName string) error {
query := `DELETE FROM skill_stats WHERE skill_name = ?`
_, err := db.Exec(query, skillName)
if err != nil {
db.logger.Error("清空指定skill统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
db.logger.Info("已清空指定skill统计信息", zap.String("skillName", skillName))
return nil
}
+33
View File
@@ -0,0 +1,33 @@
package database
import (
"errors"
"strings"
"time"
)
// formatSQLiteUTC stores instants as UTC RFC3339 for consistent SQLite reads/writes.
func formatSQLiteUTC(t time.Time) string {
return t.UTC().Format(time.RFC3339Nano)
}
// sqliteEpochGE returns SQL comparing column to param as Unix seconds (timezone-safe).
func sqliteEpochGE(column, op string) string {
return "strftime('%s', " + column + ") " + op + " strftime('%s', ?)"
}
// ParseRFC3339Time parses API/query timestamps (RFC3339 or RFC3339Nano).
func ParseRFC3339Time(value string) (time.Time, error) {
value = strings.TrimSpace(value)
if value == "" {
return time.Time{}, errors.New("empty time value")
}
if t, err := time.Parse(time.RFC3339Nano, value); err == nil {
return t.UTC(), nil
}
t, err := time.Parse(time.RFC3339, value)
if err != nil {
return time.Time{}, err
}
return t.UTC(), nil
}
+440
View File
@@ -0,0 +1,440 @@
package database
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
type VulnerabilityListFilter struct {
ID string
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
ConversationID string
ProjectID string
Severity string
Status string
TaskID string
ConversationTag string
TaskTag string
}
func escapeVulnerabilityLikePattern(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return "%" + s + "%"
}
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
if f.ID != "" {
query += " AND id = ?"
args = append(args, f.ID)
}
if f.ConversationID != "" {
query += " AND conversation_id = ?"
args = append(args, f.ConversationID)
}
if f.ProjectID != "" {
query += " AND project_id = ?"
args = append(args, f.ProjectID)
}
if f.TaskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, f.TaskID, f.TaskID)
}
if f.ConversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, f.ConversationTag)
}
if f.TaskTag != "" {
query += " AND task_tag = ?"
args = append(args, f.TaskTag)
}
if f.Severity != "" {
query += " AND severity = ?"
args = append(args, f.Severity)
}
if f.Status != "" {
query += " AND status = ?"
args = append(args, f.Status)
}
search := strings.TrimSpace(f.Search)
if search != "" {
pattern := escapeVulnerabilityLikePattern(search)
query += ` AND (
LOWER(id) LIKE LOWER(?) OR
LOWER(title) LIKE LOWER(?) OR
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
)`
for i := 0; i < 11; i++ {
args = append(args, pattern)
}
}
return query, args
}
// Vulnerability 漏洞
type Vulnerability struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
ProjectID string `json:"project_id,omitempty"`
ConversationTag string `json:"conversation_tag,omitempty"`
TaskTag string `json:"task_tag,omitempty"`
TaskID string `json:"task_id,omitempty"`
TaskQueueID string `json:"task_queue_id,omitempty"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info
Status string `json:"status"` // open, confirmed, fixed, false_positive, ignored
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// CreateVulnerability 创建漏洞
func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
if vuln.ID == "" {
vuln.ID = uuid.New().String()
}
if vuln.Status == "" {
vuln.Status = "open"
}
now := time.Now()
if vuln.CreatedAt.IsZero() {
vuln.CreatedAt = now
}
vuln.UpdatedAt = now
if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID != "" {
if pid, err := db.GetConversationProjectID(vuln.ConversationID); err == nil {
vuln.ProjectID = pid
}
}
query := `
INSERT INTO vulnerabilities (
id, conversation_id, project_id, conversation_tag, task_tag, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(
query,
vuln.ID, nullIfEmpty(vuln.ConversationID), nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("创建漏洞失败: %w", err)
}
return vuln, nil
}
// GetVulnerability 获取漏洞
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability
query := `
SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status,
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at
FROM vulnerabilities
WHERE id = ?
`
err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("漏洞不存在")
}
return nil, fmt.Errorf("获取漏洞失败: %w", err)
}
return &vuln, nil
}
// ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
query := `
SELECT id, COALESCE(conversation_id,''), COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at
FROM vulnerabilities
WHERE 1=1
`
args := []interface{}{}
query, args = filter.appendWhere(query, args)
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询漏洞列表失败: %w", err)
}
defer rows.Close()
var vulnerabilities []*Vulnerability
for rows.Next() {
var vuln Vulnerability
err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
db.logger.Warn("扫描漏洞记录失败", zap.Error(err))
continue
}
vulnerabilities = append(vulnerabilities, &vuln)
}
return vulnerabilities, nil
}
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{}
query, args = filter.appendWhere(query, args)
var count int
err := db.QueryRow(query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("统计漏洞总数失败: %w", err)
}
return count, nil
}
// UpdateVulnerability 更新漏洞
func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
vuln.UpdatedAt = time.Now()
query := `
UPDATE vulnerabilities
SET project_id = ?, conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ?
WHERE id = ?
`
_, err := db.Exec(
query,
nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id,
)
if err != nil {
return fmt.Errorf("更新漏洞失败: %w", err)
}
return nil
}
// DeleteVulnerabilitiesByFilter 按筛选条件批量删除漏洞,返回实际删除条数
func (db *DB) DeleteVulnerabilitiesByFilter(filter VulnerabilityListFilter) (int64, error) {
tx, err := db.Begin()
if err != nil {
return 0, fmt.Errorf("开启事务失败: %w", err)
}
defer func() { _ = tx.Rollback() }()
where := "WHERE 1=1"
args := []interface{}{}
where, args = filter.appendWhere(where, args)
clearQuery := `UPDATE project_facts SET related_vulnerability_id = NULL
WHERE related_vulnerability_id IN (SELECT id FROM vulnerabilities ` + where + `)`
if _, err := tx.Exec(clearQuery, args...); err != nil {
return 0, fmt.Errorf("清理事实漏洞关联失败: %w", err)
}
deleteQuery := `DELETE FROM vulnerabilities ` + where
result, err := tx.Exec(deleteQuery, args...)
if err != nil {
return 0, fmt.Errorf("批量删除漏洞失败: %w", err)
}
deleted, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("获取删除条数失败: %w", err)
}
if err := tx.Commit(); err != nil {
return 0, fmt.Errorf("提交事务失败: %w", err)
}
return deleted, nil
}
// DeleteVulnerability 删除漏洞
func (db *DB) DeleteVulnerability(id string) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开启事务失败: %w", err)
}
defer func() { _ = tx.Rollback() }()
// 删除漏洞前先解除项目事实中的关联,避免前端继续显示已删除漏洞的短 ID。
if _, err := tx.Exec("UPDATE project_facts SET related_vulnerability_id = NULL WHERE related_vulnerability_id = ?", id); err != nil {
return fmt.Errorf("清理事实漏洞关联失败: %w", err)
}
if _, err := tx.Exec("DELETE FROM vulnerabilities WHERE id = ?", id); err != nil {
return fmt.Errorf("删除漏洞失败: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("提交事务失败: %w", err)
}
return nil
}
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
stats := make(map[string]interface{})
where := "WHERE 1=1"
args := []interface{}{}
where, args = filter.appendWhere(where, args)
// 总漏洞数
var totalCount int
query := "SELECT COUNT(*) FROM vulnerabilities " + where
err := db.QueryRow(query, args...).Scan(&totalCount)
if err != nil {
return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
}
stats["total"] = totalCount
// 按严重程度统计
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities " + where + " GROUP BY severity"
rows, err := db.Query(severityQuery, args...)
if err != nil {
return nil, fmt.Errorf("获取严重程度统计失败: %w", err)
}
defer rows.Close()
severityStats := make(map[string]int)
for rows.Next() {
var severity string
var count int
if err := rows.Scan(&severity, &count); err != nil {
continue
}
severityStats[severity] = count
}
stats["by_severity"] = severityStats
// 按状态统计
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities " + where + " GROUP BY status"
rows, err = db.Query(statusQuery, args...)
if err != nil {
return nil, fmt.Errorf("获取状态统计失败: %w", err)
}
defer rows.Close()
statusStats := make(map[string]int)
for rows.Next() {
var status string
var count int
if err := rows.Scan(&status, &count); err != nil {
continue
}
statusStats[status] = count
}
stats["by_status"] = statusStats
return stats, nil
}
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
collect := func(query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]string, 0)
for rows.Next() {
var val string
if err := rows.Scan(&val); err != nil {
continue
}
if val == "" {
continue
}
items = append(items, val)
}
return items, nil
}
vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
}
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id IS NOT NULL AND conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
}
taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务ID建议失败: %w", err)
}
queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询队列ID建议失败: %w", err)
}
conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询对话标签建议失败: %w", err)
}
taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
}
projectIDs, err := collect(`SELECT DISTINCT project_id FROM vulnerabilities WHERE project_id IS NOT NULL AND project_id <> '' ORDER BY created_at DESC LIMIT 200`)
if err != nil {
return nil, fmt.Errorf("查询项目ID建议失败: %w", err)
}
return map[string][]string{
"vulnerability_ids": vulnIDs,
"conversation_ids": conversationIDs,
"project_ids": projectIDs,
"task_ids": taskIDs,
"queue_ids": queueIDs,
"conversation_tags": conversationTags,
"task_tags": taskTags,
}, nil
}
+152
View File
@@ -0,0 +1,152 @@
package database
import (
"database/sql"
"time"
"go.uber.org/zap"
)
// WebShellConnection WebShell 连接配置
type WebShellConnection struct {
ID string `json:"id"`
URL string `json:"url"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"`
CmdParam string `json:"cmdParam"`
Remark string `json:"remark"`
Encoding string `json:"encoding"` // 目标响应编码:auto / utf-8 / gbk / gb18030,空值视为 auto
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空值/未知视为 auto
CreatedAt time.Time `json:"createdAt"`
}
// GetWebshellConnectionState 获取连接关联的持久化状态 JSON,不存在时返回 "{}"
func (db *DB) GetWebshellConnectionState(connectionID string) (string, error) {
var stateJSON string
err := db.QueryRow(`SELECT state_json FROM webshell_connection_states WHERE connection_id = ?`, connectionID).Scan(&stateJSON)
if err == sql.ErrNoRows {
return "{}", nil
}
if err != nil {
db.logger.Error("查询 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID))
return "", err
}
if stateJSON == "" {
stateJSON = "{}"
}
return stateJSON, nil
}
// UpsertWebshellConnectionState 保存连接关联的持久化状态 JSON
func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) error {
if stateJSON == "" {
stateJSON = "{}"
}
query := `
INSERT INTO webshell_connection_states (connection_id, state_json, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(connection_id) DO UPDATE SET
state_json = excluded.state_json,
updated_at = excluded.updated_at
`
if _, err := db.Exec(query, connectionID, stateJSON, time.Now()); err != nil {
db.logger.Error("保存 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID))
return err
}
return nil
}
// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序
func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
query := `
SELECT id, url, password, type, method, cmd_param, remark,
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
FROM webshell_connections
ORDER BY created_at DESC
`
rows, err := db.Query(query)
if err != nil {
db.logger.Error("查询 WebShell 连接列表失败", zap.Error(err))
return nil, err
}
defer rows.Close()
var list []WebShellConnection
for rows.Next() {
var c WebShellConnection
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
if err != nil {
db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err))
continue
}
list = append(list, c)
}
return list, rows.Err()
}
// GetWebshellConnection 根据 ID 获取一条连接
func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
query := `
SELECT id, url, password, type, method, cmd_param, remark,
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
FROM webshell_connections WHERE id = ?
`
var c WebShellConnection
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
db.logger.Error("查询 WebShell 连接失败", zap.Error(err), zap.String("id", id))
return nil, err
}
return &c, nil
}
// CreateWebshellConnection 创建 WebShell 连接
func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
query := `
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, encoding, os, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.CreatedAt)
if err != nil {
db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err
}
return nil
}
// UpdateWebshellConnection 更新 WebShell 连接
func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error {
query := `
UPDATE webshell_connections
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?, encoding = ?, os = ?
WHERE id = ?
`
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.ID)
if err != nil {
db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err
}
affected, _ := result.RowsAffected()
if affected == 0 {
return sql.ErrNoRows
}
return nil
}
// DeleteWebshellConnection 删除 WebShell 连接
func (db *DB) DeleteWebshellConnection(id string) error {
result, err := db.Exec(`DELETE FROM webshell_connections WHERE id = ?`, id)
if err != nil {
db.logger.Error("删除 WebShell 连接失败", zap.Error(err), zap.String("id", id))
return err
}
affected, _ := result.RowsAffected()
if affected == 0 {
return sql.ErrNoRows
}
return nil
}