Add files via upload

This commit is contained in:
公明
2025-11-08 18:56:23 +08:00
committed by GitHub
commit add33e1cf7
24 changed files with 5228 additions and 0 deletions
+256
View File
@@ -0,0 +1,256 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Conversation 对话
type Conversation struct {
ID string `json:"id"`
Title string `json:"title"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
Messages []Message `json:"messages,omitempty"`
}
// Message 消息
type Message struct {
ID string `json:"id"`
ConversationID string `json:"conversationId"`
Role string `json:"role"`
Content string `json:"content"`
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// CreateConversation 创建新对话
func (db *DB) CreateConversation(title string) (*Conversation, error) {
id := uuid.New().String()
now := time.Now()
_, err := db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
id, title, now, now,
)
if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err)
}
return &Conversation{
ID: id,
Title: title,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// GetConversation 获取对话
func (db *DB) GetConversation(id string) (*Conversation, error) {
var conv Conversation
var createdAt, updatedAt string
err := db.QueryRow(
"SELECT id, title, created_at, updated_at FROM conversations WHERE id = ?",
id,
).Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("对话不存在")
}
return nil, fmt.Errorf("查询对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
}
// 加载消息
messages, err := db.GetMessages(id)
if err != nil {
return nil, fmt.Errorf("加载消息失败: %w", err)
}
conv.Messages = messages
return &conv, nil
}
// ListConversations 列出所有对话
func (db *DB) ListConversations(limit, offset int) ([]*Conversation, error) {
rows, err := db.Query(
"SELECT id, title, created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
limit, offset,
)
if err != nil {
return nil, fmt.Errorf("查询对话列表失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
if err := rows.Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
}
conversations = append(conversations, &conv)
}
return conversations, nil
}
// UpdateConversationTitle 更新对话标题
func (db *DB) UpdateConversationTitle(id, title string) error {
_, err := db.Exec(
"UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
title, time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新对话标题失败: %w", err)
}
return nil
}
// UpdateConversationTime 更新对话时间
func (db *DB) UpdateConversationTime(id string) error {
_, err := db.Exec(
"UPDATE conversations SET updated_at = ? WHERE id = ?",
time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新对话时间失败: %w", err)
}
return nil
}
// DeleteConversation 删除对话
func (db *DB) DeleteConversation(id string) error {
_, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除对话失败: %w", err)
}
return nil
}
// AddMessage 添加消息
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
id := uuid.New().String()
var mcpIDsJSON string
if len(mcpExecutionIDs) > 0 {
jsonData, err := json.Marshal(mcpExecutionIDs)
if err != nil {
db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
} else {
mcpIDsJSON = string(jsonData)
}
}
_, err := db.Exec(
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
id, conversationID, role, content, mcpIDsJSON, time.Now(),
)
if err != nil {
return nil, fmt.Errorf("添加消息失败: %w", err)
}
// 更新对话时间
if err := db.UpdateConversationTime(conversationID); err != nil {
db.logger.Warn("更新对话时间失败", zap.Error(err))
}
message := &Message{
ID: id,
ConversationID: conversationID,
Role: role,
Content: content,
MCPExecutionIDs: mcpExecutionIDs,
CreatedAt: time.Now(),
}
return message, nil
}
// GetMessages 获取对话的所有消息
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
rows, err := db.Query(
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
conversationID,
)
if err != nil {
return nil, fmt.Errorf("查询消息失败: %w", err)
}
defer rows.Close()
var messages []Message
for rows.Next() {
var msg Message
var mcpIDsJSON sql.NullString
var createdAt string
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
return nil, fmt.Errorf("扫描消息失败: %w", err)
}
// 尝试多种时间格式解析
var err error
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err != nil {
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err != nil {
msg.CreatedAt, err = time.Parse(time.RFC3339, createdAt)
}
// 解析MCP执行ID
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
}
}
messages = append(messages, msg)
}
return messages, nil
}
+90
View File
@@ -0,0 +1,90 @@
package database
import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
"go.uber.org/zap"
)
// DB 数据库连接
type DB struct {
*sql.DB
logger *zap.Logger
}
// NewDB 创建数据库连接
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
if err != nil {
return nil, fmt.Errorf("打开数据库失败: %w", err)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
database := &DB{
DB: db,
logger: logger,
}
// 初始化表
if err := database.initTables(); err != nil {
return nil, fmt.Errorf("初始化表失败: %w", err)
}
return database, nil
}
// initTables 初始化数据库表
func (db *DB) initTables() error {
// 创建对话表
createConversationsTable := `
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
);`
// 创建消息表
createMessagesTable := `
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
mcp_execution_ids TEXT,
created_at DATETIME NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);`
// 创建索引
createIndexes := `
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at);
`
if _, err := db.Exec(createConversationsTable); err != nil {
return fmt.Errorf("创建conversations表失败: %w", err)
}
if _, err := db.Exec(createMessagesTable); err != nil {
return fmt.Errorf("创建messages表失败: %w", err)
}
if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
db.logger.Info("数据库表初始化完成")
return nil
}
// Close 关闭数据库连接
func (db *DB) Close() error {
return db.DB.Close()
}