package knowledge import ( "database/sql" "encoding/json" "fmt" "io/fs" "os" "path/filepath" "strings" "time" "github.com/google/uuid" "go.uber.org/zap" ) // Manager 知识库管理器 type Manager struct { db *sql.DB basePath string logger *zap.Logger } // NewManager 创建新的知识库管理器 func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager { return &Manager{ db: db, basePath: basePath, logger: logger, } } // ScanKnowledgeBase 扫描知识库目录,更新数据库 // 返回需要索引的知识项ID列表(新添加的或更新的) func (m *Manager) ScanKnowledgeBase() ([]string, error) { if m.basePath == "" { return nil, fmt.Errorf("知识库路径未配置") } // 确保目录存在 if err := os.MkdirAll(m.basePath, 0755); err != nil { return nil, fmt.Errorf("创建知识库目录失败: %w", err) } var itemsToIndex []string // 遍历知识库目录 err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } // 跳过目录和非markdown文件 if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") { return nil } // 计算相对路径和分类 relPath, err := filepath.Rel(m.basePath, path) if err != nil { return err } // 第一个目录名作为分类(风险类型) parts := strings.Split(relPath, string(filepath.Separator)) category := "未分类" if len(parts) > 1 { category = parts[0] } // 文件名为标题 title := strings.TrimSuffix(filepath.Base(path), ".md") // 读取文件内容 content, err := os.ReadFile(path) if err != nil { m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err)) return nil // 继续处理其他文件 } // 检查是否已存在 var existingID string var existingContent string var existingUpdatedAt time.Time err = m.db.QueryRow( "SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?", path, ).Scan(&existingID, &existingContent, &existingUpdatedAt) if err == sql.ErrNoRows { // 创建新项 id := uuid.New().String() now := time.Now() _, err = m.db.Exec( "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", id, category, title, path, string(content), now, now, ) if err != nil { return fmt.Errorf("插入知识项失败: %w", err) } m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category)) // 新添加的项需要索引 itemsToIndex = append(itemsToIndex, id) } else if err == nil { // 检查内容是否有变化 contentChanged := existingContent != string(content) if contentChanged { // 更新现有项 _, err = m.db.Exec( "UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?", category, title, string(content), time.Now(), existingID, ) if err != nil { return fmt.Errorf("更新知识项失败: %w", err) } m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title)) // 内容已更新的项需要重新索引 itemsToIndex = append(itemsToIndex, existingID) } else { m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title)) } } else { return fmt.Errorf("查询知识项失败: %w", err) } return nil }) if err != nil { return nil, err } return itemsToIndex, nil } // GetCategories 获取所有分类(风险类型) func (m *Manager) GetCategories() ([]string, error) { rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category") if err != nil { return nil, fmt.Errorf("查询分类失败: %w", err) } defer rows.Close() var categories []string for rows.Next() { var category string if err := rows.Scan(&category); err != nil { return nil, fmt.Errorf("扫描分类失败: %w", err) } categories = append(categories, category) } return categories, nil } // GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项) // limit: 每页分类数量(0表示不限制) // offset: 偏移量(按分类偏移) func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) { // 首先获取所有分类(带数量统计) rows, err := m.db.Query(` SELECT category, COUNT(*) as item_count FROM knowledge_base_items GROUP BY category ORDER BY category `) if err != nil { return nil, 0, fmt.Errorf("查询分类失败: %w", err) } defer rows.Close() // 收集所有分类信息 type categoryInfo struct { name string itemCount int } var allCategories []categoryInfo for rows.Next() { var info categoryInfo if err := rows.Scan(&info.name, &info.itemCount); err != nil { return nil, 0, fmt.Errorf("扫描分类失败: %w", err) } allCategories = append(allCategories, info) } totalCategories := len(allCategories) // 应用分页(按分类分页) var paginatedCategories []categoryInfo if limit > 0 { start := offset end := offset + limit if start >= totalCategories { paginatedCategories = []categoryInfo{} } else { if end > totalCategories { end = totalCategories } paginatedCategories = allCategories[start:end] } } else { paginatedCategories = allCategories } // 为每个分类获取其下的知识项(只返回摘要,不包含完整内容) result := make([]*CategoryWithItems, 0, len(paginatedCategories)) for _, catInfo := range paginatedCategories { // 获取该分类下的所有知识项 items, _, err := m.GetItemsSummary(catInfo.name, 0, 0) if err != nil { return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err) } result = append(result, &CategoryWithItems{ Category: catInfo.name, ItemCount: catInfo.itemCount, Items: items, }) } return result, totalCategories, nil } // GetItems 获取知识项列表(完整内容,用于向后兼容) func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { return m.GetItemsWithOptions(category, 0, 0, true) } // GetItemsWithOptions 获取知识项列表(支持分页和可选内容) // category: 分类筛选(空字符串表示所有分类) // limit: 每页数量(0表示不限制) // offset: 偏移量 // includeContent: 是否包含完整内容(false时只返回摘要) func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) { var rows *sql.Rows var err error // 构建SQL查询 var query string var args []interface{} if includeContent { query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items" } else { query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items" } if category != "" { query += " WHERE category = ?" args = append(args, category) } query += " ORDER BY category, title" if limit > 0 { query += " LIMIT ?" args = append(args, limit) if offset > 0 { query += " OFFSET ?" args = append(args, offset) } } rows, err = m.db.Query(query, args...) if err != nil { return nil, fmt.Errorf("查询知识项失败: %w", err) } defer rows.Close() var items []*KnowledgeItem for rows.Next() { item := &KnowledgeItem{} var createdAt, updatedAt string if includeContent { if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil { return nil, fmt.Errorf("扫描知识项失败: %w", err) } } else { if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { return nil, fmt.Errorf("扫描知识项失败: %w", err) } // 不包含内容时,Content为空字符串 item.Content = "" } // 解析时间 - 支持多种格式 timeFormats := []string{ "2006-01-02 15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999Z07:00", "2006-01-02T15:04:05Z", "2006-01-02 15:04:05", time.RFC3339, time.RFC3339Nano, } // 解析创建时间 if createdAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, createdAt) if err == nil && !parsed.IsZero() { item.CreatedAt = parsed break } } } // 解析更新时间 if updatedAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, updatedAt) if err == nil && !parsed.IsZero() { item.UpdatedAt = parsed break } } } // 如果更新时间为空,使用创建时间 if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { item.UpdatedAt = item.CreatedAt } items = append(items, item) } return items, nil } // GetItemsCount 获取知识项总数 func (m *Manager) GetItemsCount(category string) (int, error) { var count int var err error if category != "" { err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count) } else { err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count) } if err != nil { return 0, fmt.Errorf("查询知识项总数失败: %w", err) } return count, nil } // SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配) func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) { if keyword == "" { return nil, fmt.Errorf("搜索关键字不能为空") } // 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写) var query string var args []interface{} // SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数 // 使用%keyword%进行模糊匹配 searchPattern := "%" + keyword + "%" query = ` SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?)) ` args = append(args, searchPattern, searchPattern, searchPattern, searchPattern) // 如果指定了分类,添加分类过滤 if category != "" { query += " AND category = ?" args = append(args, category) } query += " ORDER BY category, title" rows, err := m.db.Query(query, args...) if err != nil { return nil, fmt.Errorf("搜索知识项失败: %w", err) } defer rows.Close() var items []*KnowledgeItemSummary for rows.Next() { item := &KnowledgeItemSummary{} var createdAt, updatedAt string if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { return nil, fmt.Errorf("扫描知识项失败: %w", err) } // 解析时间 timeFormats := []string{ "2006-01-02 15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999Z07:00", "2006-01-02T15:04:05Z", "2006-01-02 15:04:05", time.RFC3339, time.RFC3339Nano, } if createdAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, createdAt) if err == nil && !parsed.IsZero() { item.CreatedAt = parsed break } } } if updatedAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, updatedAt) if err == nil && !parsed.IsZero() { item.UpdatedAt = parsed break } } } if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { item.UpdatedAt = item.CreatedAt } items = append(items, item) } return items, nil } // GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页) func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) { // 获取总数 total, err := m.GetItemsCount(category) if err != nil { return nil, 0, err } // 获取列表数据(不包含内容) var rows *sql.Rows var query string var args []interface{} query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items" if category != "" { query += " WHERE category = ?" args = append(args, category) } query += " ORDER BY category, title" if limit > 0 { query += " LIMIT ?" args = append(args, limit) if offset > 0 { query += " OFFSET ?" args = append(args, offset) } } rows, err = m.db.Query(query, args...) if err != nil { return nil, 0, fmt.Errorf("查询知识项失败: %w", err) } defer rows.Close() var items []*KnowledgeItemSummary for rows.Next() { item := &KnowledgeItemSummary{} var createdAt, updatedAt string if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { return nil, 0, fmt.Errorf("扫描知识项失败: %w", err) } // 解析时间 timeFormats := []string{ "2006-01-02 15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999Z07:00", "2006-01-02T15:04:05Z", "2006-01-02 15:04:05", time.RFC3339, time.RFC3339Nano, } if createdAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, createdAt) if err == nil && !parsed.IsZero() { item.CreatedAt = parsed break } } } if updatedAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, updatedAt) if err == nil && !parsed.IsZero() { item.UpdatedAt = parsed break } } } if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { item.UpdatedAt = item.CreatedAt } items = append(items, item) } return items, total, nil } // GetItem 获取单个知识项 func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { item := &KnowledgeItem{} var createdAt, updatedAt string err := m.db.QueryRow( "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?", id, ).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt) if err == sql.ErrNoRows { return nil, fmt.Errorf("知识项不存在") } if err != nil { return nil, fmt.Errorf("查询知识项失败: %w", err) } // 解析时间 - 支持多种格式 timeFormats := []string{ "2006-01-02 15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999Z07:00", "2006-01-02T15:04:05Z", "2006-01-02 15:04:05", time.RFC3339, time.RFC3339Nano, } // 解析创建时间 if createdAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, createdAt) if err == nil && !parsed.IsZero() { item.CreatedAt = parsed break } } } // 解析更新时间 if updatedAt != "" { for _, format := range timeFormats { parsed, err := time.Parse(format, updatedAt) if err == nil && !parsed.IsZero() { item.UpdatedAt = parsed break } } } // 如果更新时间为空,使用创建时间 if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { item.UpdatedAt = item.CreatedAt } return item, nil } // CreateItem 创建知识项 func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) { id := uuid.New().String() now := time.Now() // 构建文件路径 filePath := filepath.Join(m.basePath, category, title+".md") // 确保目录存在 if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { return nil, fmt.Errorf("创建目录失败: %w", err) } // 写入文件 if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { return nil, fmt.Errorf("写入文件失败: %w", err) } // 插入数据库 _, err := m.db.Exec( "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", id, category, title, filePath, content, now, now, ) if err != nil { return nil, fmt.Errorf("插入知识项失败: %w", err) } return &KnowledgeItem{ ID: id, Category: category, Title: title, FilePath: filePath, Content: content, CreatedAt: now, UpdatedAt: now, }, nil } // UpdateItem 更新知识项 func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) { // 获取现有项 item, err := m.GetItem(id) if err != nil { return nil, err } // 构建新文件路径 newFilePath := filepath.Join(m.basePath, category, title+".md") // 如果路径改变,需要移动文件 if item.FilePath != newFilePath { // 确保新目录存在 if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil { return nil, fmt.Errorf("创建目录失败: %w", err) } // 移动文件 if err := os.Rename(item.FilePath, newFilePath); err != nil { return nil, fmt.Errorf("移动文件失败: %w", err) } // 删除旧目录(如果为空) oldDir := filepath.Dir(item.FilePath) if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 { // 只有当目录不是知识库根目录时才删除(避免删除根目录) if oldDir != m.basePath { if err := os.Remove(oldDir); err != nil { m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err)) } } } } // 写入文件 if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil { return nil, fmt.Errorf("写入文件失败: %w", err) } // 更新数据库 _, err = m.db.Exec( "UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?", category, title, newFilePath, content, time.Now(), id, ) if err != nil { return nil, fmt.Errorf("更新知识项失败: %w", err) } // 删除旧的向量嵌入(需要重新索引) _, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id) if err != nil { m.logger.Warn("删除旧向量嵌入失败", zap.Error(err)) } return m.GetItem(id) } // DeleteItem 删除知识项 func (m *Manager) DeleteItem(id string) error { // 获取文件路径 var filePath string err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath) if err != nil { return fmt.Errorf("查询知识项失败: %w", err) } // 删除文件 if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err)) } // 删除数据库记录(级联删除向量) _, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id) if err != nil { return fmt.Errorf("删除知识项失败: %w", err) } // 删除空目录(如果为空) dir := filepath.Dir(filePath) if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 { // 只有当目录不是知识库根目录时才删除(避免删除根目录) if dir != m.basePath { if err := os.Remove(dir); err != nil { m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err)) } } } return nil } // LogRetrieval 记录检索日志 func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error { id := uuid.New().String() itemsJSON, _ := json.Marshal(retrievedItems) _, err := m.db.Exec( "INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(), ) return err } // GetIndexStatus 获取索引状态 func (m *Manager) GetIndexStatus() (map[string]interface{}, error) { // 获取总知识项数 var totalItems int err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) if err != nil { return nil, fmt.Errorf("查询总知识项数失败: %w", err) } // 获取已索引的知识项数(有向量嵌入的) var indexedItems int err = m.db.QueryRow(` SELECT COUNT(DISTINCT item_id) FROM knowledge_embeddings `).Scan(&indexedItems) if err != nil { return nil, fmt.Errorf("查询已索引项数失败: %w", err) } // 计算进度百分比 var progressPercent float64 if totalItems > 0 { progressPercent = float64(indexedItems) / float64(totalItems) * 100 } else { progressPercent = 100.0 } // 判断是否完成 isComplete := indexedItems >= totalItems && totalItems > 0 return map[string]interface{}{ "total_items": totalItems, "indexed_items": indexedItems, "progress_percent": progressPercent, "is_complete": isComplete, }, nil } // GetRetrievalLogs 获取检索日志 func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) { var rows *sql.Rows var err error if messageID != "" { rows, err = m.db.Query( "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?", messageID, limit, ) } else if conversationID != "" { rows, err = m.db.Query( "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?", conversationID, limit, ) } else { rows, err = m.db.Query( "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?", limit, ) } if err != nil { return nil, fmt.Errorf("查询检索日志失败: %w", err) } defer rows.Close() var logs []*RetrievalLog for rows.Next() { log := &RetrievalLog{} var createdAt string var itemsJSON sql.NullString if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil { return nil, fmt.Errorf("扫描检索日志失败: %w", err) } // 解析时间 - 支持多种格式 var err error timeFormats := []string{ "2006-01-02 15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999Z07:00", "2006-01-02T15:04:05Z", "2006-01-02 15:04:05", time.RFC3339, time.RFC3339Nano, } for _, format := range timeFormats { log.CreatedAt, err = time.Parse(format, createdAt) if err == nil && !log.CreatedAt.IsZero() { break } } // 如果所有格式都失败,记录警告但继续处理 if log.CreatedAt.IsZero() { m.logger.Warn("解析检索日志时间失败", zap.String("timeStr", createdAt), zap.Error(err), ) // 使用当前时间作为fallback log.CreatedAt = time.Now() } // 解析检索项 if itemsJSON.Valid { json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems) } logs = append(logs, log) } return logs, nil } // DeleteRetrievalLog 删除检索日志 func (m *Manager) DeleteRetrievalLog(id string) error { result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id) if err != nil { return fmt.Errorf("删除检索日志失败: %w", err) } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf("获取删除行数失败: %w", err) } if rowsAffected == 0 { return fmt.Errorf("检索日志不存在") } return nil }