mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
461 lines
13 KiB
Go
461 lines
13 KiB
Go
package database
|
||
|
||
import (
|
||
"database/sql"
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Conversation 对话
|
||
type Conversation struct {
|
||
ID string `json:"id"`
|
||
Title string `json:"title"`
|
||
Pinned bool `json:"pinned"`
|
||
CreatedAt time.Time `json:"createdAt"`
|
||
UpdatedAt time.Time `json:"updatedAt"`
|
||
Messages []Message `json:"messages,omitempty"`
|
||
}
|
||
|
||
// Message is a single message belonging to a conversation.
type Message struct {
	// ID is the message's unique identifier.
	ID string `json:"id"`

	// ConversationID is the ID of the owning conversation.
	ConversationID string `json:"conversationId"`

	// Role is the role recorded for the message.
	Role string `json:"role"`

	// Content is the message text.
	Content string `json:"content"`

	// MCPExecutionIDs lists the MCP execution IDs attached to the message;
	// omitted from JSON when empty.
	MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`

	// ProcessDetails holds the message's process-detail events as generic
	// JSON objects; omitted from JSON when empty.
	ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`

	// CreatedAt is the creation timestamp.
	CreatedAt time.Time `json:"createdAt"`
}
|
||
|
||
// CreateConversation 创建新对话
|
||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
||
id := uuid.New().String()
|
||
now := time.Now()
|
||
|
||
_, err := db.Exec(
|
||
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
||
id, title, now, now,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||
}
|
||
|
||
return &Conversation{
|
||
ID: id,
|
||
Title: title,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}, nil
|
||
}
|
||
|
||
// GetConversation 获取对话
|
||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||
var conv Conversation
|
||
var createdAt, updatedAt string
|
||
var pinned int
|
||
|
||
err := db.QueryRow(
|
||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
||
id,
|
||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||
if err != nil {
|
||
if err == sql.ErrNoRows {
|
||
return nil, fmt.Errorf("对话不存在")
|
||
}
|
||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||
}
|
||
|
||
// 尝试多种时间格式解析
|
||
var err1, err2 error
|
||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||
if err1 != nil {
|
||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||
}
|
||
if err1 != nil {
|
||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
}
|
||
|
||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||
if err2 != nil {
|
||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||
}
|
||
if err2 != nil {
|
||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||
}
|
||
|
||
conv.Pinned = pinned != 0
|
||
|
||
// 加载消息
|
||
messages, err := db.GetMessages(id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||
}
|
||
conv.Messages = messages
|
||
|
||
// 加载过程详情(按消息ID分组)
|
||
processDetailsMap, err := db.GetProcessDetailsByConversation(id)
|
||
if err != nil {
|
||
db.logger.Warn("加载过程详情失败", zap.Error(err))
|
||
processDetailsMap = make(map[string][]ProcessDetail)
|
||
}
|
||
|
||
// 将过程详情附加到对应的消息上
|
||
for i := range conv.Messages {
|
||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||
// 将ProcessDetail转换为JSON格式,以便前端使用
|
||
detailsJSON := make([]map[string]interface{}, len(details))
|
||
for j, detail := range details {
|
||
var data interface{}
|
||
if detail.Data != "" {
|
||
if err := json.Unmarshal([]byte(detail.Data), &data); err != nil {
|
||
db.logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||
}
|
||
}
|
||
detailsJSON[j] = map[string]interface{}{
|
||
"id": detail.ID,
|
||
"messageId": detail.MessageID,
|
||
"conversationId": detail.ConversationID,
|
||
"eventType": detail.EventType,
|
||
"message": detail.Message,
|
||
"data": data,
|
||
"createdAt": detail.CreatedAt,
|
||
}
|
||
}
|
||
conv.Messages[i].ProcessDetails = detailsJSON
|
||
}
|
||
}
|
||
|
||
return &conv, nil
|
||
}
|
||
|
||
// ListConversations 列出所有对话
|
||
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
|
||
var rows *sql.Rows
|
||
var err error
|
||
|
||
if search != "" {
|
||
// 使用LIKE进行模糊搜索,搜索标题和消息内容
|
||
searchPattern := "%" + search + "%"
|
||
// 使用DISTINCT避免重复,因为一个对话可能有多条消息匹配
|
||
rows, err = db.Query(
|
||
`SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
|
||
FROM conversations c
|
||
LEFT JOIN messages m ON c.id = m.conversation_id
|
||
WHERE c.title LIKE ? OR m.content LIKE ?
|
||
ORDER BY c.updated_at DESC
|
||
LIMIT ? OFFSET ?`,
|
||
searchPattern, searchPattern, limit, offset,
|
||
)
|
||
} else {
|
||
rows, err = db.Query(
|
||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||
limit, offset,
|
||
)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询对话列表失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var conversations []*Conversation
|
||
for rows.Next() {
|
||
var conv Conversation
|
||
var createdAt, updatedAt string
|
||
var pinned int
|
||
|
||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil {
|
||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||
}
|
||
|
||
// 尝试多种时间格式解析
|
||
var err1, err2 error
|
||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||
if err1 != nil {
|
||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||
}
|
||
if err1 != nil {
|
||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
}
|
||
|
||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||
if err2 != nil {
|
||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||
}
|
||
if err2 != nil {
|
||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||
}
|
||
|
||
conv.Pinned = pinned != 0
|
||
|
||
conversations = append(conversations, &conv)
|
||
}
|
||
|
||
return conversations, nil
|
||
}
|
||
|
||
// UpdateConversationTitle 更新对话标题
|
||
func (db *DB) UpdateConversationTitle(id, title string) error {
|
||
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
|
||
_, err := db.Exec(
|
||
"UPDATE conversations SET title = ? WHERE id = ?",
|
||
title, id,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("更新对话标题失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// UpdateConversationTime 更新对话时间
|
||
func (db *DB) UpdateConversationTime(id string) error {
|
||
_, err := db.Exec(
|
||
"UPDATE conversations SET updated_at = ? WHERE id = ?",
|
||
time.Now(), id,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("更新对话时间失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DeleteConversation 删除对话
|
||
func (db *DB) DeleteConversation(id string) error {
|
||
_, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||
if err != nil {
|
||
return fmt.Errorf("删除对话失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// SaveReActData 保存最后一轮ReAct的输入和输出
|
||
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error {
|
||
_, err := db.Exec(
|
||
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
|
||
reactInput, reactOutput, time.Now(), conversationID,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("保存ReAct数据失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetReActData 获取最后一轮ReAct的输入和输出
|
||
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) {
|
||
var input, output sql.NullString
|
||
err = db.QueryRow(
|
||
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
|
||
conversationID,
|
||
).Scan(&input, &output)
|
||
if err != nil {
|
||
if err == sql.ErrNoRows {
|
||
return "", "", fmt.Errorf("对话不存在")
|
||
}
|
||
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err)
|
||
}
|
||
|
||
if input.Valid {
|
||
reactInput = input.String
|
||
}
|
||
if output.Valid {
|
||
reactOutput = output.String
|
||
}
|
||
|
||
return reactInput, reactOutput, nil
|
||
}
|
||
|
||
// AddMessage 添加消息
|
||
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
|
||
id := uuid.New().String()
|
||
|
||
var mcpIDsJSON string
|
||
if len(mcpExecutionIDs) > 0 {
|
||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||
if err != nil {
|
||
db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
|
||
} else {
|
||
mcpIDsJSON = string(jsonData)
|
||
}
|
||
}
|
||
|
||
_, err := db.Exec(
|
||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||
id, conversationID, role, content, mcpIDsJSON, time.Now(),
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||
}
|
||
|
||
// 更新对话时间
|
||
if err := db.UpdateConversationTime(conversationID); err != nil {
|
||
db.logger.Warn("更新对话时间失败", zap.Error(err))
|
||
}
|
||
|
||
message := &Message{
|
||
ID: id,
|
||
ConversationID: conversationID,
|
||
Role: role,
|
||
Content: content,
|
||
MCPExecutionIDs: mcpExecutionIDs,
|
||
CreatedAt: time.Now(),
|
||
}
|
||
|
||
return message, nil
|
||
}
|
||
|
||
// GetMessages 获取对话的所有消息
|
||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||
rows, err := db.Query(
|
||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||
conversationID,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var messages []Message
|
||
for rows.Next() {
|
||
var msg Message
|
||
var mcpIDsJSON sql.NullString
|
||
var createdAt string
|
||
|
||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
|
||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||
}
|
||
|
||
// 尝试多种时间格式解析
|
||
var err error
|
||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||
if err != nil {
|
||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||
}
|
||
if err != nil {
|
||
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
}
|
||
|
||
// 解析MCP执行ID
|
||
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
messages = append(messages, msg)
|
||
}
|
||
|
||
return messages, nil
|
||
}
|
||
|
||
// ProcessDetail is a single process-detail event recorded for a message.
type ProcessDetail struct {
	// ID is the event's unique identifier.
	ID string `json:"id"`

	// MessageID is the ID of the message the event belongs to.
	MessageID string `json:"messageId"`

	// ConversationID is the ID of the conversation the event belongs to.
	ConversationID string `json:"conversationId"`

	// EventType is the kind of event: iteration, thinking,
	// tool_calls_detected, tool_call, tool_result, progress, or error.
	EventType string `json:"eventType"`

	// Message is the event's text.
	Message string `json:"message"`

	// Data is the event payload, encoded as a JSON string.
	Data string `json:"data"`

	// CreatedAt is the creation timestamp.
	CreatedAt time.Time `json:"createdAt"`
}
|
||
|
||
// AddProcessDetail 添加过程详情事件
|
||
func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error {
|
||
id := uuid.New().String()
|
||
|
||
var dataJSON string
|
||
if data != nil {
|
||
jsonData, err := json.Marshal(data)
|
||
if err != nil {
|
||
db.logger.Warn("序列化过程详情数据失败", zap.Error(err))
|
||
} else {
|
||
dataJSON = string(jsonData)
|
||
}
|
||
}
|
||
|
||
_, err := db.Exec(
|
||
"INSERT INTO process_details (id, message_id, conversation_id, event_type, message, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||
id, messageID, conversationID, eventType, message, dataJSON, time.Now(),
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("添加过程详情失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetProcessDetails 获取消息的过程详情
|
||
func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
|
||
rows, err := db.Query(
|
||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC",
|
||
messageID,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询过程详情失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var details []ProcessDetail
|
||
for rows.Next() {
|
||
var detail ProcessDetail
|
||
var createdAt string
|
||
|
||
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
|
||
return nil, fmt.Errorf("扫描过程详情失败: %w", err)
|
||
}
|
||
|
||
// 尝试多种时间格式解析
|
||
var err error
|
||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||
if err != nil {
|
||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||
}
|
||
if err != nil {
|
||
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
}
|
||
|
||
details = append(details, detail)
|
||
}
|
||
|
||
return details, nil
|
||
}
|
||
|
||
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
|
||
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
|
||
rows, err := db.Query(
|
||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC",
|
||
conversationID,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询过程详情失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
detailsMap := make(map[string][]ProcessDetail)
|
||
for rows.Next() {
|
||
var detail ProcessDetail
|
||
var createdAt string
|
||
|
||
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
|
||
return nil, fmt.Errorf("扫描过程详情失败: %w", err)
|
||
}
|
||
|
||
// 尝试多种时间格式解析
|
||
var err error
|
||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||
if err != nil {
|
||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||
}
|
||
if err != nil {
|
||
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||
}
|
||
|
||
detailsMap[detail.MessageID] = append(detailsMap[detail.MessageID], detail)
|
||
}
|
||
|
||
return detailsMap, nil
|
||
}
|