mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-01 03:51:46 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,256 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Conversation 对话
|
||||
type Conversation struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
||||
id, title, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
ID: id,
|
||||
Title: title,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
// 加载消息
|
||||
messages, err := db.GetMessages(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||
}
|
||||
conv.Messages = messages
|
||||
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// ListConversations 列出所有对话
|
||||
func (db *DB) ListConversations(limit, offset int) ([]*Conversation, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询对话列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// UpdateConversationTitle 更新对话标题
|
||||
func (db *DB) UpdateConversationTitle(id, title string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
|
||||
title, time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新对话标题失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateConversationTime 更新对话时间
|
||||
func (db *DB) UpdateConversationTime(id string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET updated_at = ? WHERE id = ?",
|
||||
time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新对话时间失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteConversation 删除对话
|
||||
func (db *DB) DeleteConversation(id string) error {
|
||||
_, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMessage 添加消息
|
||||
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
|
||||
id := uuid.New().String()
|
||||
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||
if err != nil {
|
||||
db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
|
||||
} else {
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, mcpIDsJSON, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新对话时间
|
||||
if err := db.UpdateConversationTime(conversationID); err != nil {
|
||||
db.logger.Warn("更新对话时间失败", zap.Error(err))
|
||||
}
|
||||
|
||||
message := &Message{
|
||||
ID: id,
|
||||
ConversationID: conversationID,
|
||||
Role: role,
|
||||
Content: content,
|
||||
MCPExecutionIDs: mcpExecutionIDs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// GetMessages 获取对话的所有消息
|
||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err != nil {
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err != nil {
|
||||
msg.CreatedAt, err = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
// 解析MCP执行ID
|
||||
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||||
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DB 数据库连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: db,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 初始化表
|
||||
if err := database.initTables(); err != nil {
|
||||
return nil, fmt.Errorf("初始化表失败: %w", err)
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// initTables 初始化数据库表
|
||||
func (db *DB) initTables() error {
|
||||
// 创建对话表
|
||||
createConversationsTable := `
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
// 创建消息表
|
||||
createMessagesTable := `
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
mcp_execution_ids TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建索引
|
||||
createIndexes := `
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
return fmt.Errorf("创建conversations表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createMessagesTable); err != nil {
|
||||
return fmt.Errorf("创建messages表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
|
||||
db.logger.Info("数据库表初始化完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user