mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 02:06:46 +02:00
852 lines
23 KiB
Go
852 lines
23 KiB
Go
package knowledge
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/fs"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Manager 知识库管理器
|
|
type Manager struct {
|
|
db *sql.DB
|
|
basePath string
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewManager 创建新的知识库管理器
|
|
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
|
|
return &Manager{
|
|
db: db,
|
|
basePath: basePath,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// ScanKnowledgeBase 扫描知识库目录,更新数据库
|
|
// 返回需要索引的知识项ID列表(新添加的或更新的)
|
|
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
|
|
if m.basePath == "" {
|
|
return nil, fmt.Errorf("知识库路径未配置")
|
|
}
|
|
|
|
// 确保目录存在
|
|
if err := os.MkdirAll(m.basePath, 0755); err != nil {
|
|
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
|
|
}
|
|
|
|
var itemsToIndex []string
|
|
|
|
// 遍历知识库目录
|
|
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 跳过目录和非markdown文件
|
|
if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
|
|
return nil
|
|
}
|
|
|
|
// 计算相对路径和分类
|
|
relPath, err := filepath.Rel(m.basePath, path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 第一个目录名作为分类(风险类型)
|
|
parts := strings.Split(relPath, string(filepath.Separator))
|
|
category := "未分类"
|
|
if len(parts) > 1 {
|
|
category = parts[0]
|
|
}
|
|
|
|
// 文件名为标题
|
|
title := strings.TrimSuffix(filepath.Base(path), ".md")
|
|
|
|
// 读取文件内容
|
|
content, err := os.ReadFile(path)
|
|
if err != nil {
|
|
m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
|
|
return nil // 继续处理其他文件
|
|
}
|
|
|
|
// 检查是否已存在
|
|
var existingID string
|
|
var existingContent string
|
|
var existingUpdatedAt time.Time
|
|
err = m.db.QueryRow(
|
|
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
|
|
path,
|
|
).Scan(&existingID, &existingContent, &existingUpdatedAt)
|
|
|
|
if err == sql.ErrNoRows {
|
|
// 创建新项
|
|
id := uuid.New().String()
|
|
now := time.Now()
|
|
_, err = m.db.Exec(
|
|
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
|
id, category, title, path, string(content), now, now,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("插入知识项失败: %w", err)
|
|
}
|
|
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
|
|
// 新添加的项需要索引
|
|
itemsToIndex = append(itemsToIndex, id)
|
|
} else if err == nil {
|
|
// 检查内容是否有变化
|
|
contentChanged := existingContent != string(content)
|
|
if contentChanged {
|
|
// 更新现有项
|
|
_, err = m.db.Exec(
|
|
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
|
|
category, title, string(content), time.Now(), existingID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("更新知识项失败: %w", err)
|
|
}
|
|
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
|
|
// 内容已更新的项需要重新索引
|
|
itemsToIndex = append(itemsToIndex, existingID)
|
|
} else {
|
|
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
|
|
}
|
|
} else {
|
|
return fmt.Errorf("查询知识项失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return itemsToIndex, nil
|
|
}
|
|
|
|
// GetCategories 获取所有分类(风险类型)
|
|
func (m *Manager) GetCategories() ([]string, error) {
|
|
rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询分类失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var categories []string
|
|
for rows.Next() {
|
|
var category string
|
|
if err := rows.Scan(&category); err != nil {
|
|
return nil, fmt.Errorf("扫描分类失败: %w", err)
|
|
}
|
|
categories = append(categories, category)
|
|
}
|
|
|
|
return categories, nil
|
|
}
|
|
|
|
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
|
|
// limit: 每页分类数量(0表示不限制)
|
|
// offset: 偏移量(按分类偏移)
|
|
func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) {
|
|
// 首先获取所有分类(带数量统计)
|
|
rows, err := m.db.Query(`
|
|
SELECT category, COUNT(*) as item_count
|
|
FROM knowledge_base_items
|
|
GROUP BY category
|
|
ORDER BY category
|
|
`)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("查询分类失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
// 收集所有分类信息
|
|
type categoryInfo struct {
|
|
name string
|
|
itemCount int
|
|
}
|
|
var allCategories []categoryInfo
|
|
for rows.Next() {
|
|
var info categoryInfo
|
|
if err := rows.Scan(&info.name, &info.itemCount); err != nil {
|
|
return nil, 0, fmt.Errorf("扫描分类失败: %w", err)
|
|
}
|
|
allCategories = append(allCategories, info)
|
|
}
|
|
|
|
totalCategories := len(allCategories)
|
|
|
|
// 应用分页(按分类分页)
|
|
var paginatedCategories []categoryInfo
|
|
if limit > 0 {
|
|
start := offset
|
|
end := offset + limit
|
|
if start >= totalCategories {
|
|
paginatedCategories = []categoryInfo{}
|
|
} else {
|
|
if end > totalCategories {
|
|
end = totalCategories
|
|
}
|
|
paginatedCategories = allCategories[start:end]
|
|
}
|
|
} else {
|
|
paginatedCategories = allCategories
|
|
}
|
|
|
|
// 为每个分类获取其下的知识项(只返回摘要,不包含完整内容)
|
|
result := make([]*CategoryWithItems, 0, len(paginatedCategories))
|
|
for _, catInfo := range paginatedCategories {
|
|
// 获取该分类下的所有知识项
|
|
items, _, err := m.GetItemsSummary(catInfo.name, 0, 0)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err)
|
|
}
|
|
|
|
result = append(result, &CategoryWithItems{
|
|
Category: catInfo.name,
|
|
ItemCount: catInfo.itemCount,
|
|
Items: items,
|
|
})
|
|
}
|
|
|
|
return result, totalCategories, nil
|
|
}
|
|
|
|
// GetItems 获取知识项列表(完整内容,用于向后兼容)
|
|
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
|
|
return m.GetItemsWithOptions(category, 0, 0, true)
|
|
}
|
|
|
|
// GetItemsWithOptions 获取知识项列表(支持分页和可选内容)
|
|
// category: 分类筛选(空字符串表示所有分类)
|
|
// limit: 每页数量(0表示不限制)
|
|
// offset: 偏移量
|
|
// includeContent: 是否包含完整内容(false时只返回摘要)
|
|
func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
|
|
// 构建SQL查询
|
|
var query string
|
|
var args []interface{}
|
|
|
|
if includeContent {
|
|
query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items"
|
|
} else {
|
|
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
|
}
|
|
|
|
if category != "" {
|
|
query += " WHERE category = ?"
|
|
args = append(args, category)
|
|
}
|
|
|
|
query += " ORDER BY category, title"
|
|
|
|
if limit > 0 {
|
|
query += " LIMIT ?"
|
|
args = append(args, limit)
|
|
if offset > 0 {
|
|
query += " OFFSET ?"
|
|
args = append(args, offset)
|
|
}
|
|
}
|
|
|
|
rows, err = m.db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []*KnowledgeItem
|
|
for rows.Next() {
|
|
item := &KnowledgeItem{}
|
|
var createdAt, updatedAt string
|
|
|
|
if includeContent {
|
|
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
|
|
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
|
}
|
|
} else {
|
|
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
|
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
|
}
|
|
// 不包含内容时,Content为空字符串
|
|
item.Content = ""
|
|
}
|
|
|
|
// 解析时间 - 支持多种格式
|
|
timeFormats := []string{
|
|
"2006-01-02 15:04:05.999999999-07:00",
|
|
"2006-01-02 15:04:05.999999999",
|
|
"2006-01-02T15:04:05.999999999Z07:00",
|
|
"2006-01-02T15:04:05Z",
|
|
"2006-01-02 15:04:05",
|
|
time.RFC3339,
|
|
time.RFC3339Nano,
|
|
}
|
|
|
|
// 解析创建时间
|
|
if createdAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, createdAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.CreatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// 解析更新时间
|
|
if updatedAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, updatedAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.UpdatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// 如果更新时间为空,使用创建时间
|
|
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
|
item.UpdatedAt = item.CreatedAt
|
|
}
|
|
|
|
items = append(items, item)
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
// GetItemsCount 获取知识项总数
|
|
func (m *Manager) GetItemsCount(category string) (int, error) {
|
|
var count int
|
|
var err error
|
|
|
|
if category != "" {
|
|
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count)
|
|
} else {
|
|
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count)
|
|
}
|
|
|
|
if err != nil {
|
|
return 0, fmt.Errorf("查询知识项总数失败: %w", err)
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配)
|
|
func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) {
|
|
if keyword == "" {
|
|
return nil, fmt.Errorf("搜索关键字不能为空")
|
|
}
|
|
|
|
// 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写)
|
|
var query string
|
|
var args []interface{}
|
|
|
|
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
|
|
// 使用%keyword%进行模糊匹配
|
|
searchPattern := "%" + keyword + "%"
|
|
|
|
query = `
|
|
SELECT id, category, title, file_path, created_at, updated_at
|
|
FROM knowledge_base_items
|
|
WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?))
|
|
`
|
|
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern)
|
|
|
|
// 如果指定了分类,添加分类过滤
|
|
if category != "" {
|
|
query += " AND category = ?"
|
|
args = append(args, category)
|
|
}
|
|
|
|
query += " ORDER BY category, title"
|
|
|
|
rows, err := m.db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("搜索知识项失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []*KnowledgeItemSummary
|
|
for rows.Next() {
|
|
item := &KnowledgeItemSummary{}
|
|
var createdAt, updatedAt string
|
|
|
|
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
|
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
|
}
|
|
|
|
// 解析时间
|
|
timeFormats := []string{
|
|
"2006-01-02 15:04:05.999999999-07:00",
|
|
"2006-01-02 15:04:05.999999999",
|
|
"2006-01-02T15:04:05.999999999Z07:00",
|
|
"2006-01-02T15:04:05Z",
|
|
"2006-01-02 15:04:05",
|
|
time.RFC3339,
|
|
time.RFC3339Nano,
|
|
}
|
|
|
|
if createdAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, createdAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.CreatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if updatedAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, updatedAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.UpdatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
|
item.UpdatedAt = item.CreatedAt
|
|
}
|
|
|
|
items = append(items, item)
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页)
|
|
func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) {
|
|
// 获取总数
|
|
total, err := m.GetItemsCount(category)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// 获取列表数据(不包含内容)
|
|
var rows *sql.Rows
|
|
var query string
|
|
var args []interface{}
|
|
|
|
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
|
|
|
if category != "" {
|
|
query += " WHERE category = ?"
|
|
args = append(args, category)
|
|
}
|
|
|
|
query += " ORDER BY category, title"
|
|
|
|
if limit > 0 {
|
|
query += " LIMIT ?"
|
|
args = append(args, limit)
|
|
if offset > 0 {
|
|
query += " OFFSET ?"
|
|
args = append(args, offset)
|
|
}
|
|
}
|
|
|
|
rows, err = m.db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("查询知识项失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var items []*KnowledgeItemSummary
|
|
for rows.Next() {
|
|
item := &KnowledgeItemSummary{}
|
|
var createdAt, updatedAt string
|
|
|
|
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
|
return nil, 0, fmt.Errorf("扫描知识项失败: %w", err)
|
|
}
|
|
|
|
// 解析时间
|
|
timeFormats := []string{
|
|
"2006-01-02 15:04:05.999999999-07:00",
|
|
"2006-01-02 15:04:05.999999999",
|
|
"2006-01-02T15:04:05.999999999Z07:00",
|
|
"2006-01-02T15:04:05Z",
|
|
"2006-01-02 15:04:05",
|
|
time.RFC3339,
|
|
time.RFC3339Nano,
|
|
}
|
|
|
|
if createdAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, createdAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.CreatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if updatedAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, updatedAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.UpdatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
|
item.UpdatedAt = item.CreatedAt
|
|
}
|
|
|
|
items = append(items, item)
|
|
}
|
|
|
|
return items, total, nil
|
|
}
|
|
|
|
// GetItem 获取单个知识项
|
|
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
|
|
item := &KnowledgeItem{}
|
|
var createdAt, updatedAt string
|
|
err := m.db.QueryRow(
|
|
"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
|
|
id,
|
|
).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, fmt.Errorf("知识项不存在")
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
|
}
|
|
|
|
// 解析时间 - 支持多种格式
|
|
timeFormats := []string{
|
|
"2006-01-02 15:04:05.999999999-07:00",
|
|
"2006-01-02 15:04:05.999999999",
|
|
"2006-01-02T15:04:05.999999999Z07:00",
|
|
"2006-01-02T15:04:05Z",
|
|
"2006-01-02 15:04:05",
|
|
time.RFC3339,
|
|
time.RFC3339Nano,
|
|
}
|
|
|
|
// 解析创建时间
|
|
if createdAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, createdAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.CreatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// 解析更新时间
|
|
if updatedAt != "" {
|
|
for _, format := range timeFormats {
|
|
parsed, err := time.Parse(format, updatedAt)
|
|
if err == nil && !parsed.IsZero() {
|
|
item.UpdatedAt = parsed
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// 如果更新时间为空,使用创建时间
|
|
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
|
item.UpdatedAt = item.CreatedAt
|
|
}
|
|
|
|
return item, nil
|
|
}
|
|
|
|
// CreateItem 创建知识项
|
|
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
|
|
id := uuid.New().String()
|
|
now := time.Now()
|
|
|
|
// 构建文件路径
|
|
filePath := filepath.Join(m.basePath, category, title+".md")
|
|
|
|
// 确保目录存在
|
|
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
|
return nil, fmt.Errorf("创建目录失败: %w", err)
|
|
}
|
|
|
|
// 写入文件
|
|
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
|
return nil, fmt.Errorf("写入文件失败: %w", err)
|
|
}
|
|
|
|
// 插入数据库
|
|
_, err := m.db.Exec(
|
|
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
|
id, category, title, filePath, content, now, now,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("插入知识项失败: %w", err)
|
|
}
|
|
|
|
return &KnowledgeItem{
|
|
ID: id,
|
|
Category: category,
|
|
Title: title,
|
|
FilePath: filePath,
|
|
Content: content,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}, nil
|
|
}
|
|
|
|
// UpdateItem 更新知识项
|
|
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
|
|
// 获取现有项
|
|
item, err := m.GetItem(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 构建新文件路径
|
|
newFilePath := filepath.Join(m.basePath, category, title+".md")
|
|
|
|
// 如果路径改变,需要移动文件
|
|
if item.FilePath != newFilePath {
|
|
// 确保新目录存在
|
|
if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
|
|
return nil, fmt.Errorf("创建目录失败: %w", err)
|
|
}
|
|
|
|
// 移动文件
|
|
if err := os.Rename(item.FilePath, newFilePath); err != nil {
|
|
return nil, fmt.Errorf("移动文件失败: %w", err)
|
|
}
|
|
|
|
// 删除旧目录(如果为空)
|
|
oldDir := filepath.Dir(item.FilePath)
|
|
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
|
|
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
|
if oldDir != m.basePath {
|
|
if err := os.Remove(oldDir); err != nil {
|
|
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 写入文件
|
|
if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
|
|
return nil, fmt.Errorf("写入文件失败: %w", err)
|
|
}
|
|
|
|
// 更新数据库
|
|
_, err = m.db.Exec(
|
|
"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
|
|
category, title, newFilePath, content, time.Now(), id,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("更新知识项失败: %w", err)
|
|
}
|
|
|
|
// 删除旧的向量嵌入(需要重新索引)
|
|
_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
|
|
if err != nil {
|
|
m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
|
|
}
|
|
|
|
return m.GetItem(id)
|
|
}
|
|
|
|
// DeleteItem 删除知识项
|
|
func (m *Manager) DeleteItem(id string) error {
|
|
// 获取文件路径
|
|
var filePath string
|
|
err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
|
|
if err != nil {
|
|
return fmt.Errorf("查询知识项失败: %w", err)
|
|
}
|
|
|
|
// 删除文件
|
|
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
|
m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
|
|
}
|
|
|
|
// 删除数据库记录(级联删除向量)
|
|
_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
|
|
if err != nil {
|
|
return fmt.Errorf("删除知识项失败: %w", err)
|
|
}
|
|
|
|
// 删除空目录(如果为空)
|
|
dir := filepath.Dir(filePath)
|
|
if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 {
|
|
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
|
if dir != m.basePath {
|
|
if err := os.Remove(dir); err != nil {
|
|
m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LogRetrieval 记录检索日志
|
|
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
|
|
id := uuid.New().String()
|
|
itemsJSON, _ := json.Marshal(retrievedItems)
|
|
|
|
_, err := m.db.Exec(
|
|
"INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
|
id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(),
|
|
)
|
|
return err
|
|
}
|
|
|
|
// GetIndexStatus 获取索引状态
|
|
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
|
|
// 获取总知识项数
|
|
var totalItems int
|
|
err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询总知识项数失败: %w", err)
|
|
}
|
|
|
|
// 获取已索引的知识项数(有向量嵌入的)
|
|
var indexedItems int
|
|
err = m.db.QueryRow(`
|
|
SELECT COUNT(DISTINCT item_id)
|
|
FROM knowledge_embeddings
|
|
`).Scan(&indexedItems)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询已索引项数失败: %w", err)
|
|
}
|
|
|
|
// 计算进度百分比
|
|
var progressPercent float64
|
|
if totalItems > 0 {
|
|
progressPercent = float64(indexedItems) / float64(totalItems) * 100
|
|
} else {
|
|
progressPercent = 100.0
|
|
}
|
|
|
|
// 判断是否完成
|
|
isComplete := indexedItems >= totalItems && totalItems > 0
|
|
|
|
return map[string]interface{}{
|
|
"total_items": totalItems,
|
|
"indexed_items": indexedItems,
|
|
"progress_percent": progressPercent,
|
|
"is_complete": isComplete,
|
|
}, nil
|
|
}
|
|
|
|
// GetRetrievalLogs 获取检索日志
|
|
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
|
|
if messageID != "" {
|
|
rows, err = m.db.Query(
|
|
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
|
|
messageID, limit,
|
|
)
|
|
} else if conversationID != "" {
|
|
rows, err = m.db.Query(
|
|
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
|
|
conversationID, limit,
|
|
)
|
|
} else {
|
|
rows, err = m.db.Query(
|
|
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
|
|
limit,
|
|
)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询检索日志失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var logs []*RetrievalLog
|
|
for rows.Next() {
|
|
log := &RetrievalLog{}
|
|
var createdAt string
|
|
var itemsJSON sql.NullString
|
|
if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil {
|
|
return nil, fmt.Errorf("扫描检索日志失败: %w", err)
|
|
}
|
|
|
|
// 解析时间 - 支持多种格式
|
|
var err error
|
|
timeFormats := []string{
|
|
"2006-01-02 15:04:05.999999999-07:00",
|
|
"2006-01-02 15:04:05.999999999",
|
|
"2006-01-02T15:04:05.999999999Z07:00",
|
|
"2006-01-02T15:04:05Z",
|
|
"2006-01-02 15:04:05",
|
|
time.RFC3339,
|
|
time.RFC3339Nano,
|
|
}
|
|
|
|
for _, format := range timeFormats {
|
|
log.CreatedAt, err = time.Parse(format, createdAt)
|
|
if err == nil && !log.CreatedAt.IsZero() {
|
|
break
|
|
}
|
|
}
|
|
|
|
// 如果所有格式都失败,记录警告但继续处理
|
|
if log.CreatedAt.IsZero() {
|
|
m.logger.Warn("解析检索日志时间失败",
|
|
zap.String("timeStr", createdAt),
|
|
zap.Error(err),
|
|
)
|
|
// 使用当前时间作为fallback
|
|
log.CreatedAt = time.Now()
|
|
}
|
|
|
|
// 解析检索项
|
|
if itemsJSON.Valid {
|
|
json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems)
|
|
}
|
|
|
|
logs = append(logs, log)
|
|
}
|
|
|
|
return logs, nil
|
|
}
|
|
|
|
// DeleteRetrievalLog 删除检索日志
|
|
func (m *Manager) DeleteRetrievalLog(id string) error {
|
|
result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id)
|
|
if err != nil {
|
|
return fmt.Errorf("删除检索日志失败: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("获取删除行数失败: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("检索日志不存在")
|
|
}
|
|
|
|
return nil
|
|
}
|