mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
836 lines
22 KiB
Go
836 lines
22 KiB
Go
package knowledge
|
||
|
||
import (
|
||
"database/sql"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io/fs"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// Manager 知识库管理器
|
||
type Manager struct {
|
||
db *sql.DB
|
||
basePath string
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewManager 创建新的知识库管理器
|
||
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
|
||
return &Manager{
|
||
db: db,
|
||
basePath: basePath,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// ScanKnowledgeBase 扫描知识库目录,更新数据库
|
||
// 返回需要索引的知识项ID列表(新添加的或更新的)
|
||
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
|
||
if m.basePath == "" {
|
||
return nil, fmt.Errorf("知识库路径未配置")
|
||
}
|
||
|
||
// 确保目录存在
|
||
if err := os.MkdirAll(m.basePath, 0755); err != nil {
|
||
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
|
||
}
|
||
|
||
var itemsToIndex []string
|
||
|
||
// 遍历知识库目录
|
||
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 跳过目录和非markdown文件
|
||
if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
|
||
return nil
|
||
}
|
||
|
||
// 计算相对路径和分类
|
||
relPath, err := filepath.Rel(m.basePath, path)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 第一个目录名作为分类(风险类型)
|
||
parts := strings.Split(relPath, string(filepath.Separator))
|
||
category := "未分类"
|
||
if len(parts) > 1 {
|
||
category = parts[0]
|
||
}
|
||
|
||
// 文件名为标题
|
||
title := strings.TrimSuffix(filepath.Base(path), ".md")
|
||
|
||
// 读取文件内容
|
||
content, err := os.ReadFile(path)
|
||
if err != nil {
|
||
m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
|
||
return nil // 继续处理其他文件
|
||
}
|
||
|
||
// 检查是否已存在
|
||
var existingID string
|
||
var existingContent string
|
||
var existingUpdatedAt time.Time
|
||
err = m.db.QueryRow(
|
||
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
|
||
path,
|
||
).Scan(&existingID, &existingContent, &existingUpdatedAt)
|
||
|
||
if err == sql.ErrNoRows {
|
||
// 创建新项
|
||
id := uuid.New().String()
|
||
now := time.Now()
|
||
_, err = m.db.Exec(
|
||
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||
id, category, title, path, string(content), now, now,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("插入知识项失败: %w", err)
|
||
}
|
||
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
|
||
// 新添加的项需要索引
|
||
itemsToIndex = append(itemsToIndex, id)
|
||
} else if err == nil {
|
||
// 检查内容是否有变化
|
||
contentChanged := existingContent != string(content)
|
||
if contentChanged {
|
||
// 更新现有项
|
||
_, err = m.db.Exec(
|
||
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
|
||
category, title, string(content), time.Now(), existingID,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("更新知识项失败: %w", err)
|
||
}
|
||
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
|
||
// 内容已更新的项需要重新索引
|
||
itemsToIndex = append(itemsToIndex, existingID)
|
||
} else {
|
||
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
|
||
}
|
||
} else {
|
||
return fmt.Errorf("查询知识项失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return itemsToIndex, nil
|
||
}
|
||
|
||
// GetCategories 获取所有分类(风险类型)
|
||
func (m *Manager) GetCategories() ([]string, error) {
|
||
rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询分类失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var categories []string
|
||
for rows.Next() {
|
||
var category string
|
||
if err := rows.Scan(&category); err != nil {
|
||
return nil, fmt.Errorf("扫描分类失败: %w", err)
|
||
}
|
||
categories = append(categories, category)
|
||
}
|
||
|
||
return categories, nil
|
||
}
|
||
|
||
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
|
||
// limit: 每页分类数量(0表示不限制)
|
||
// offset: 偏移量(按分类偏移)
|
||
func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) {
|
||
// 首先获取所有分类(带数量统计)
|
||
rows, err := m.db.Query(`
|
||
SELECT category, COUNT(*) as item_count
|
||
FROM knowledge_base_items
|
||
GROUP BY category
|
||
ORDER BY category
|
||
`)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("查询分类失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
// 收集所有分类信息
|
||
type categoryInfo struct {
|
||
name string
|
||
itemCount int
|
||
}
|
||
var allCategories []categoryInfo
|
||
for rows.Next() {
|
||
var info categoryInfo
|
||
if err := rows.Scan(&info.name, &info.itemCount); err != nil {
|
||
return nil, 0, fmt.Errorf("扫描分类失败: %w", err)
|
||
}
|
||
allCategories = append(allCategories, info)
|
||
}
|
||
|
||
totalCategories := len(allCategories)
|
||
|
||
// 应用分页(按分类分页)
|
||
var paginatedCategories []categoryInfo
|
||
if limit > 0 {
|
||
start := offset
|
||
end := offset + limit
|
||
if start >= totalCategories {
|
||
paginatedCategories = []categoryInfo{}
|
||
} else {
|
||
if end > totalCategories {
|
||
end = totalCategories
|
||
}
|
||
paginatedCategories = allCategories[start:end]
|
||
}
|
||
} else {
|
||
paginatedCategories = allCategories
|
||
}
|
||
|
||
// 为每个分类获取其下的知识项(只返回摘要,不包含完整内容)
|
||
result := make([]*CategoryWithItems, 0, len(paginatedCategories))
|
||
for _, catInfo := range paginatedCategories {
|
||
// 获取该分类下的所有知识项
|
||
items, _, err := m.GetItemsSummary(catInfo.name, 0, 0)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err)
|
||
}
|
||
|
||
result = append(result, &CategoryWithItems{
|
||
Category: catInfo.name,
|
||
ItemCount: catInfo.itemCount,
|
||
Items: items,
|
||
})
|
||
}
|
||
|
||
return result, totalCategories, nil
|
||
}
|
||
|
||
// GetItems 获取知识项列表(完整内容,用于向后兼容)
|
||
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
|
||
return m.GetItemsWithOptions(category, 0, 0, true)
|
||
}
|
||
|
||
// GetItemsWithOptions 获取知识项列表(支持分页和可选内容)
|
||
// category: 分类筛选(空字符串表示所有分类)
|
||
// limit: 每页数量(0表示不限制)
|
||
// offset: 偏移量
|
||
// includeContent: 是否包含完整内容(false时只返回摘要)
|
||
func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) {
|
||
var rows *sql.Rows
|
||
var err error
|
||
|
||
// 构建SQL查询
|
||
var query string
|
||
var args []interface{}
|
||
|
||
if includeContent {
|
||
query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items"
|
||
} else {
|
||
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
||
}
|
||
|
||
if category != "" {
|
||
query += " WHERE category = ?"
|
||
args = append(args, category)
|
||
}
|
||
|
||
query += " ORDER BY category, title"
|
||
|
||
if limit > 0 {
|
||
query += " LIMIT ?"
|
||
args = append(args, limit)
|
||
if offset > 0 {
|
||
query += " OFFSET ?"
|
||
args = append(args, offset)
|
||
}
|
||
}
|
||
|
||
rows, err = m.db.Query(query, args...)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var items []*KnowledgeItem
|
||
for rows.Next() {
|
||
item := &KnowledgeItem{}
|
||
var createdAt, updatedAt string
|
||
|
||
if includeContent {
|
||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
|
||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||
}
|
||
} else {
|
||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||
}
|
||
// 不包含内容时,Content为空字符串
|
||
item.Content = ""
|
||
}
|
||
|
||
// 解析时间 - 支持多种格式
|
||
timeFormats := []string{
|
||
"2006-01-02 15:04:05.999999999-07:00",
|
||
"2006-01-02 15:04:05.999999999",
|
||
"2006-01-02T15:04:05.999999999Z07:00",
|
||
"2006-01-02T15:04:05Z",
|
||
"2006-01-02 15:04:05",
|
||
time.RFC3339,
|
||
time.RFC3339Nano,
|
||
}
|
||
|
||
// 解析创建时间
|
||
if createdAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, createdAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.CreatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 解析更新时间
|
||
if updatedAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, updatedAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.UpdatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果更新时间为空,使用创建时间
|
||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||
item.UpdatedAt = item.CreatedAt
|
||
}
|
||
|
||
items = append(items, item)
|
||
}
|
||
|
||
return items, nil
|
||
}
|
||
|
||
// GetItemsCount 获取知识项总数
|
||
func (m *Manager) GetItemsCount(category string) (int, error) {
|
||
var count int
|
||
var err error
|
||
|
||
if category != "" {
|
||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count)
|
||
} else {
|
||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count)
|
||
}
|
||
|
||
if err != nil {
|
||
return 0, fmt.Errorf("查询知识项总数失败: %w", err)
|
||
}
|
||
|
||
return count, nil
|
||
}
|
||
|
||
// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配)
|
||
func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) {
|
||
if keyword == "" {
|
||
return nil, fmt.Errorf("搜索关键字不能为空")
|
||
}
|
||
|
||
// 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写)
|
||
var query string
|
||
var args []interface{}
|
||
|
||
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
|
||
// 使用%keyword%进行模糊匹配
|
||
searchPattern := "%" + keyword + "%"
|
||
|
||
query = `
|
||
SELECT id, category, title, file_path, created_at, updated_at
|
||
FROM knowledge_base_items
|
||
WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?))
|
||
`
|
||
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern)
|
||
|
||
// 如果指定了分类,添加分类过滤
|
||
if category != "" {
|
||
query += " AND category = ?"
|
||
args = append(args, category)
|
||
}
|
||
|
||
query += " ORDER BY category, title"
|
||
|
||
rows, err := m.db.Query(query, args...)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("搜索知识项失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var items []*KnowledgeItemSummary
|
||
for rows.Next() {
|
||
item := &KnowledgeItemSummary{}
|
||
var createdAt, updatedAt string
|
||
|
||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||
}
|
||
|
||
// 解析时间
|
||
timeFormats := []string{
|
||
"2006-01-02 15:04:05.999999999-07:00",
|
||
"2006-01-02 15:04:05.999999999",
|
||
"2006-01-02T15:04:05.999999999Z07:00",
|
||
"2006-01-02T15:04:05Z",
|
||
"2006-01-02 15:04:05",
|
||
time.RFC3339,
|
||
time.RFC3339Nano,
|
||
}
|
||
|
||
if createdAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, createdAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.CreatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if updatedAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, updatedAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.UpdatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||
item.UpdatedAt = item.CreatedAt
|
||
}
|
||
|
||
items = append(items, item)
|
||
}
|
||
|
||
return items, nil
|
||
}
|
||
|
||
// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页)
|
||
func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) {
|
||
// 获取总数
|
||
total, err := m.GetItemsCount(category)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表数据(不包含内容)
|
||
var rows *sql.Rows
|
||
var query string
|
||
var args []interface{}
|
||
|
||
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
||
|
||
if category != "" {
|
||
query += " WHERE category = ?"
|
||
args = append(args, category)
|
||
}
|
||
|
||
query += " ORDER BY category, title"
|
||
|
||
if limit > 0 {
|
||
query += " LIMIT ?"
|
||
args = append(args, limit)
|
||
if offset > 0 {
|
||
query += " OFFSET ?"
|
||
args = append(args, offset)
|
||
}
|
||
}
|
||
|
||
rows, err = m.db.Query(query, args...)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("查询知识项失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var items []*KnowledgeItemSummary
|
||
for rows.Next() {
|
||
item := &KnowledgeItemSummary{}
|
||
var createdAt, updatedAt string
|
||
|
||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||
return nil, 0, fmt.Errorf("扫描知识项失败: %w", err)
|
||
}
|
||
|
||
// 解析时间
|
||
timeFormats := []string{
|
||
"2006-01-02 15:04:05.999999999-07:00",
|
||
"2006-01-02 15:04:05.999999999",
|
||
"2006-01-02T15:04:05.999999999Z07:00",
|
||
"2006-01-02T15:04:05Z",
|
||
"2006-01-02 15:04:05",
|
||
time.RFC3339,
|
||
time.RFC3339Nano,
|
||
}
|
||
|
||
if createdAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, createdAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.CreatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if updatedAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, updatedAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.UpdatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||
item.UpdatedAt = item.CreatedAt
|
||
}
|
||
|
||
items = append(items, item)
|
||
}
|
||
|
||
return items, total, nil
|
||
}
|
||
|
||
// GetItem 获取单个知识项
|
||
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
|
||
item := &KnowledgeItem{}
|
||
var createdAt, updatedAt string
|
||
err := m.db.QueryRow(
|
||
"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
|
||
id,
|
||
).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
|
||
|
||
if err == sql.ErrNoRows {
|
||
return nil, fmt.Errorf("知识项不存在")
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
||
}
|
||
|
||
// 解析时间 - 支持多种格式
|
||
timeFormats := []string{
|
||
"2006-01-02 15:04:05.999999999-07:00",
|
||
"2006-01-02 15:04:05.999999999",
|
||
"2006-01-02T15:04:05.999999999Z07:00",
|
||
"2006-01-02T15:04:05Z",
|
||
"2006-01-02 15:04:05",
|
||
time.RFC3339,
|
||
time.RFC3339Nano,
|
||
}
|
||
|
||
// 解析创建时间
|
||
if createdAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, createdAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.CreatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 解析更新时间
|
||
if updatedAt != "" {
|
||
for _, format := range timeFormats {
|
||
parsed, err := time.Parse(format, updatedAt)
|
||
if err == nil && !parsed.IsZero() {
|
||
item.UpdatedAt = parsed
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果更新时间为空,使用创建时间
|
||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||
item.UpdatedAt = item.CreatedAt
|
||
}
|
||
|
||
return item, nil
|
||
}
|
||
|
||
// CreateItem 创建知识项
|
||
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
|
||
id := uuid.New().String()
|
||
now := time.Now()
|
||
|
||
// 构建文件路径
|
||
filePath := filepath.Join(m.basePath, category, title+".md")
|
||
|
||
// 确保目录存在
|
||
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
||
return nil, fmt.Errorf("创建目录失败: %w", err)
|
||
}
|
||
|
||
// 写入文件
|
||
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||
return nil, fmt.Errorf("写入文件失败: %w", err)
|
||
}
|
||
|
||
// 插入数据库
|
||
_, err := m.db.Exec(
|
||
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||
id, category, title, filePath, content, now, now,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("插入知识项失败: %w", err)
|
||
}
|
||
|
||
return &KnowledgeItem{
|
||
ID: id,
|
||
Category: category,
|
||
Title: title,
|
||
FilePath: filePath,
|
||
Content: content,
|
||
CreatedAt: now,
|
||
UpdatedAt: now,
|
||
}, nil
|
||
}
|
||
|
||
// UpdateItem 更新知识项
|
||
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
|
||
// 获取现有项
|
||
item, err := m.GetItem(id)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 构建新文件路径
|
||
newFilePath := filepath.Join(m.basePath, category, title+".md")
|
||
|
||
// 如果路径改变,需要移动文件
|
||
if item.FilePath != newFilePath {
|
||
// 确保新目录存在
|
||
if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
|
||
return nil, fmt.Errorf("创建目录失败: %w", err)
|
||
}
|
||
|
||
// 移动文件
|
||
if err := os.Rename(item.FilePath, newFilePath); err != nil {
|
||
return nil, fmt.Errorf("移动文件失败: %w", err)
|
||
}
|
||
|
||
// 删除旧目录(如果为空)
|
||
oldDir := filepath.Dir(item.FilePath)
|
||
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
|
||
os.Remove(oldDir)
|
||
}
|
||
}
|
||
|
||
// 写入文件
|
||
if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
|
||
return nil, fmt.Errorf("写入文件失败: %w", err)
|
||
}
|
||
|
||
// 更新数据库
|
||
_, err = m.db.Exec(
|
||
"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
|
||
category, title, newFilePath, content, time.Now(), id,
|
||
)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("更新知识项失败: %w", err)
|
||
}
|
||
|
||
// 删除旧的向量嵌入(需要重新索引)
|
||
_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
|
||
if err != nil {
|
||
m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
|
||
}
|
||
|
||
return m.GetItem(id)
|
||
}
|
||
|
||
// DeleteItem 删除知识项
|
||
func (m *Manager) DeleteItem(id string) error {
|
||
// 获取文件路径
|
||
var filePath string
|
||
err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
|
||
if err != nil {
|
||
return fmt.Errorf("查询知识项失败: %w", err)
|
||
}
|
||
|
||
// 删除文件
|
||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||
m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
|
||
}
|
||
|
||
// 删除数据库记录(级联删除向量)
|
||
_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
|
||
if err != nil {
|
||
return fmt.Errorf("删除知识项失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// LogRetrieval 记录检索日志
|
||
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
|
||
id := uuid.New().String()
|
||
itemsJSON, _ := json.Marshal(retrievedItems)
|
||
|
||
_, err := m.db.Exec(
|
||
"INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||
id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(),
|
||
)
|
||
return err
|
||
}
|
||
|
||
// GetIndexStatus 获取索引状态
|
||
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
|
||
// 获取总知识项数
|
||
var totalItems int
|
||
err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询总知识项数失败: %w", err)
|
||
}
|
||
|
||
// 获取已索引的知识项数(有向量嵌入的)
|
||
var indexedItems int
|
||
err = m.db.QueryRow(`
|
||
SELECT COUNT(DISTINCT item_id)
|
||
FROM knowledge_embeddings
|
||
`).Scan(&indexedItems)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询已索引项数失败: %w", err)
|
||
}
|
||
|
||
// 计算进度百分比
|
||
var progressPercent float64
|
||
if totalItems > 0 {
|
||
progressPercent = float64(indexedItems) / float64(totalItems) * 100
|
||
} else {
|
||
progressPercent = 100.0
|
||
}
|
||
|
||
// 判断是否完成
|
||
isComplete := indexedItems >= totalItems && totalItems > 0
|
||
|
||
return map[string]interface{}{
|
||
"total_items": totalItems,
|
||
"indexed_items": indexedItems,
|
||
"progress_percent": progressPercent,
|
||
"is_complete": isComplete,
|
||
}, nil
|
||
}
|
||
|
||
// GetRetrievalLogs 获取检索日志
|
||
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
|
||
var rows *sql.Rows
|
||
var err error
|
||
|
||
if messageID != "" {
|
||
rows, err = m.db.Query(
|
||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
|
||
messageID, limit,
|
||
)
|
||
} else if conversationID != "" {
|
||
rows, err = m.db.Query(
|
||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
|
||
conversationID, limit,
|
||
)
|
||
} else {
|
||
rows, err = m.db.Query(
|
||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
|
||
limit,
|
||
)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询检索日志失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var logs []*RetrievalLog
|
||
for rows.Next() {
|
||
log := &RetrievalLog{}
|
||
var createdAt string
|
||
var itemsJSON sql.NullString
|
||
if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil {
|
||
return nil, fmt.Errorf("扫描检索日志失败: %w", err)
|
||
}
|
||
|
||
// 解析时间 - 支持多种格式
|
||
var err error
|
||
timeFormats := []string{
|
||
"2006-01-02 15:04:05.999999999-07:00",
|
||
"2006-01-02 15:04:05.999999999",
|
||
"2006-01-02T15:04:05.999999999Z07:00",
|
||
"2006-01-02T15:04:05Z",
|
||
"2006-01-02 15:04:05",
|
||
time.RFC3339,
|
||
time.RFC3339Nano,
|
||
}
|
||
|
||
for _, format := range timeFormats {
|
||
log.CreatedAt, err = time.Parse(format, createdAt)
|
||
if err == nil && !log.CreatedAt.IsZero() {
|
||
break
|
||
}
|
||
}
|
||
|
||
// 如果所有格式都失败,记录警告但继续处理
|
||
if log.CreatedAt.IsZero() {
|
||
m.logger.Warn("解析检索日志时间失败",
|
||
zap.String("timeStr", createdAt),
|
||
zap.Error(err),
|
||
)
|
||
// 使用当前时间作为fallback
|
||
log.CreatedAt = time.Now()
|
||
}
|
||
|
||
// 解析检索项
|
||
if itemsJSON.Valid {
|
||
json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems)
|
||
}
|
||
|
||
logs = append(logs, log)
|
||
}
|
||
|
||
return logs, nil
|
||
}
|
||
|
||
// DeleteRetrievalLog 删除检索日志
|
||
func (m *Manager) DeleteRetrievalLog(id string) error {
|
||
result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id)
|
||
if err != nil {
|
||
return fmt.Errorf("删除检索日志失败: %w", err)
|
||
}
|
||
|
||
rowsAffected, err := result.RowsAffected()
|
||
if err != nil {
|
||
return fmt.Errorf("获取删除行数失败: %w", err)
|
||
}
|
||
|
||
if rowsAffected == 0 {
|
||
return fmt.Errorf("检索日志不存在")
|
||
}
|
||
|
||
return nil
|
||
}
|