package database import ( "database/sql" "encoding/json" "fmt" "time" "github.com/google/uuid" "go.uber.org/zap" ) // Conversation 对话 type Conversation struct { ID string `json:"id"` Title string `json:"title"` Pinned bool `json:"pinned"` CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` Messages []Message `json:"messages,omitempty"` } // Message 消息 type Message struct { ID string `json:"id"` ConversationID string `json:"conversationId"` Role string `json:"role"` Content string `json:"content"` MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"` CreatedAt time.Time `json:"createdAt"` } // CreateConversation 创建新对话 func (db *DB) CreateConversation(title string) (*Conversation, error) { return db.CreateConversationWithWebshell("", title) } // CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话) func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) { id := uuid.New().String() now := time.Now() var err error if webshellConnectionID != "" { _, err = db.Exec( "INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)", id, title, now, now, webshellConnectionID, ) } else { _, err = db.Exec( "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)", id, title, now, now, ) } if err != nil { return nil, fmt.Errorf("创建对话失败: %w", err) } return &Conversation{ ID: id, Title: title, CreatedAt: now, UpdatedAt: now, }, nil } // GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化) func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conversation, error) { if connectionID == "" { return nil, fmt.Errorf("connectionID is empty") } var conv Conversation var createdAt, updatedAt string var pinned int err := db.QueryRow( "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC LIMIT 1", connectionID, ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, fmt.Errorf("查询对话失败: %w", err) } conv.Pinned = pinned != 0 if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt); e == nil { conv.CreatedAt = t } else if t, e := time.Parse("2006-01-02 15:04:05", createdAt); e == nil { conv.CreatedAt = t } else { conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil { conv.UpdatedAt = t } else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil { conv.UpdatedAt = t } else { conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) } messages, err := db.GetMessages(conv.ID) if err != nil { return nil, fmt.Errorf("加载消息失败: %w", err) } conv.Messages = messages // 加载过程详情并附加到对应消息(与 GetConversation 一致,便于刷新后仍可查看执行过程) processDetailsMap, err := db.GetProcessDetailsByConversation(conv.ID) if err != nil { db.logger.Warn("加载过程详情失败", zap.Error(err)) processDetailsMap = make(map[string][]ProcessDetail) } for i := range conv.Messages { if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { detailsJSON := make([]map[string]interface{}, len(details)) for j, detail := range details { var data interface{} if detail.Data != "" { if err := json.Unmarshal([]byte(detail.Data), &data); err != nil { db.logger.Warn("解析过程详情数据失败", zap.Error(err)) } } detailsJSON[j] = map[string]interface{}{ "id": detail.ID, "messageId": detail.MessageID, "conversationId": detail.ConversationID, "eventType": detail.EventType, "message": detail.Message, "data": data, "createdAt": detail.CreatedAt, } } conv.Messages[i].ProcessDetails = detailsJSON } } return &conv, nil } // WebShellConversationItem 用于侧边栏列表,不含消息 type WebShellConversationItem struct { ID string `json:"id"` Title string `json:"title"` UpdatedAt time.Time `json:"updatedAt"` } // ListConversationsByWebshellConnectionID 列出该 WebShell 连接下的所有对话(按更新时间倒序),供侧边栏展示 func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]WebShellConversationItem, error) { if connectionID == "" { return nil, nil } rows, err := db.Query( "SELECT id, title, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC", connectionID, ) if err != nil { return nil, fmt.Errorf("查询对话列表失败: %w", err) } defer rows.Close() var list []WebShellConversationItem for rows.Next() { var item WebShellConversationItem var updatedAt string if err := rows.Scan(&item.ID, &item.Title, &updatedAt); err != nil { continue } if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil { item.UpdatedAt = t } else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil { item.UpdatedAt = t } else { item.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) } list = append(list, item) } return list, rows.Err() } // GetConversation 获取对话 func (db *DB) GetConversation(id string) (*Conversation, error) { var conv Conversation var createdAt, updatedAt string var pinned int err := db.QueryRow( "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", id, ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("对话不存在") } return nil, fmt.Errorf("查询对话失败: %w", err) } // 尝试多种时间格式解析 var err1, err2 error conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) if err1 != nil { conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) } if err1 != nil { conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) if err2 != nil { conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) } if err2 != nil { conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) } conv.Pinned = pinned != 0 // 加载消息 messages, err := db.GetMessages(id) if err != nil { return nil, fmt.Errorf("加载消息失败: %w", err) } conv.Messages = messages // 加载过程详情(按消息ID分组) processDetailsMap, err := db.GetProcessDetailsByConversation(id) if err != nil { db.logger.Warn("加载过程详情失败", zap.Error(err)) processDetailsMap = make(map[string][]ProcessDetail) } // 将过程详情附加到对应的消息上 for i := range conv.Messages { if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { // 将ProcessDetail转换为JSON格式,以便前端使用 detailsJSON := make([]map[string]interface{}, len(details)) for j, detail := range details { var data interface{} if detail.Data != "" { if err := json.Unmarshal([]byte(detail.Data), &data); err != nil { db.logger.Warn("解析过程详情数据失败", zap.Error(err)) } } detailsJSON[j] = map[string]interface{}{ "id": detail.ID, "messageId": detail.MessageID, "conversationId": detail.ConversationID, "eventType": detail.EventType, "message": detail.Message, "data": data, "createdAt": detail.CreatedAt, } } conv.Messages[i].ProcessDetails = detailsJSON } } return &conv, nil } // ListConversations 列出所有对话 func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) { var rows *sql.Rows var err error if search != "" { // 使用LIKE进行模糊搜索,搜索标题和消息内容 searchPattern := "%" + search + "%" // 使用DISTINCT避免重复,因为一个对话可能有多条消息匹配 rows, err = db.Query( `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at FROM conversations c LEFT JOIN messages m ON c.id = m.conversation_id WHERE c.title LIKE ? OR m.content LIKE ? ORDER BY c.updated_at DESC LIMIT ? OFFSET ?`, searchPattern, searchPattern, limit, offset, ) } else { rows, err = db.Query( "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", limit, offset, ) } if err != nil { return nil, fmt.Errorf("查询对话列表失败: %w", err) } defer rows.Close() var conversations []*Conversation for rows.Next() { var conv Conversation var createdAt, updatedAt string var pinned int if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil { return nil, fmt.Errorf("扫描对话失败: %w", err) } // 尝试多种时间格式解析 var err1, err2 error conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) if err1 != nil { conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) } if err1 != nil { conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) if err2 != nil { conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) } if err2 != nil { conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt) } conv.Pinned = pinned != 0 conversations = append(conversations, &conv) } return conversations, nil } // UpdateConversationTitle 更新对话标题 func (db *DB) UpdateConversationTitle(id, title string) error { // 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间 _, err := db.Exec( "UPDATE conversations SET title = ? WHERE id = ?", title, id, ) if err != nil { return fmt.Errorf("更新对话标题失败: %w", err) } return nil } // UpdateConversationTime 更新对话时间 func (db *DB) UpdateConversationTime(id string) error { _, err := db.Exec( "UPDATE conversations SET updated_at = ? WHERE id = ?", time.Now(), id, ) if err != nil { return fmt.Errorf("更新对话时间失败: %w", err) } return nil } // DeleteConversation 删除对话及其所有相关数据 // 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除: // - messages(消息) // - process_details(过程详情) // - attack_chain_nodes(攻击链节点) // - attack_chain_edges(攻击链边) // - vulnerabilities(漏洞) // - conversation_group_mappings(分组映射) // 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL func (db *DB) DeleteConversation(id string) error { // 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除) _, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id) if err != nil { db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err)) // 不返回错误,继续删除对话 } // 删除对话(外键CASCADE会自动删除其他相关数据) _, err = db.Exec("DELETE FROM conversations WHERE id = ?", id) if err != nil { return fmt.Errorf("删除对话失败: %w", err) } db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) return nil } // SaveReActData 保存最后一轮ReAct的输入和输出 func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error { _, err := db.Exec( "UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?", reactInput, reactOutput, time.Now(), conversationID, ) if err != nil { return fmt.Errorf("保存ReAct数据失败: %w", err) } return nil } // GetReActData 获取最后一轮ReAct的输入和输出 func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) { var input, output sql.NullString err = db.QueryRow( "SELECT last_react_input, last_react_output FROM conversations WHERE id = ?", conversationID, ).Scan(&input, &output) if err != nil { if err == sql.ErrNoRows { return "", "", fmt.Errorf("对话不存在") } return "", "", fmt.Errorf("获取ReAct数据失败: %w", err) } if input.Valid { reactInput = input.String } if output.Valid { reactOutput = output.String } return reactInput, reactOutput, nil } // AddMessage 添加消息 func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) { id := uuid.New().String() var mcpIDsJSON string if len(mcpExecutionIDs) > 0 { jsonData, err := json.Marshal(mcpExecutionIDs) if err != nil { db.logger.Warn("序列化MCP执行ID失败", zap.Error(err)) } else { mcpIDsJSON = string(jsonData) } } _, err := db.Exec( "INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)", id, conversationID, role, content, mcpIDsJSON, time.Now(), ) if err != nil { return nil, fmt.Errorf("添加消息失败: %w", err) } // 更新对话时间 if err := db.UpdateConversationTime(conversationID); err != nil { db.logger.Warn("更新对话时间失败", zap.Error(err)) } message := &Message{ ID: id, ConversationID: conversationID, Role: role, Content: content, MCPExecutionIDs: mcpExecutionIDs, CreatedAt: time.Now(), } return message, nil } // GetMessages 获取对话的所有消息 func (db *DB) GetMessages(conversationID string) ([]Message, error) { rows, err := db.Query( "SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", conversationID, ) if err != nil { return nil, fmt.Errorf("查询消息失败: %w", err) } defer rows.Close() var messages []Message for rows.Next() { var msg Message var mcpIDsJSON sql.NullString var createdAt string if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil { return nil, fmt.Errorf("扫描消息失败: %w", err) } // 尝试多种时间格式解析 var err error msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) if err != nil { msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) } if err != nil { msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } // 解析MCP执行ID if mcpIDsJSON.Valid && mcpIDsJSON.String != "" { if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil { db.logger.Warn("解析MCP执行ID失败", zap.Error(err)) } } messages = append(messages, msg) } return messages, nil } // ProcessDetail 过程详情事件 type ProcessDetail struct { ID string `json:"id"` MessageID string `json:"messageId"` ConversationID string `json:"conversationId"` EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error Message string `json:"message"` Data string `json:"data"` // JSON格式的数据 CreatedAt time.Time `json:"createdAt"` } // AddProcessDetail 添加过程详情事件 func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error { id := uuid.New().String() var dataJSON string if data != nil { jsonData, err := json.Marshal(data) if err != nil { db.logger.Warn("序列化过程详情数据失败", zap.Error(err)) } else { dataJSON = string(jsonData) } } _, err := db.Exec( "INSERT INTO process_details (id, message_id, conversation_id, event_type, message, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", id, messageID, conversationID, eventType, message, dataJSON, time.Now(), ) if err != nil { return fmt.Errorf("添加过程详情失败: %w", err) } return nil } // GetProcessDetails 获取消息的过程详情 func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) { rows, err := db.Query( "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC", messageID, ) if err != nil { return nil, fmt.Errorf("查询过程详情失败: %w", err) } defer rows.Close() var details []ProcessDetail for rows.Next() { var detail ProcessDetail var createdAt string if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil { return nil, fmt.Errorf("扫描过程详情失败: %w", err) } // 尝试多种时间格式解析 var err error detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) if err != nil { detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) } if err != nil { detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } details = append(details, detail) } return details, nil } // GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组) func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) { rows, err := db.Query( "SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC", conversationID, ) if err != nil { return nil, fmt.Errorf("查询过程详情失败: %w", err) } defer rows.Close() detailsMap := make(map[string][]ProcessDetail) for rows.Next() { var detail ProcessDetail var createdAt string if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil { return nil, fmt.Errorf("扫描过程详情失败: %w", err) } // 尝试多种时间格式解析 var err error detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) if err != nil { detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) } if err != nil { detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) } detailsMap[detail.MessageID] = append(detailsMap[detail.MessageID], detail) } return detailsMap, nil }