Add files via upload

This commit is contained in:
公明
2025-12-20 17:36:40 +08:00
committed by GitHub
parent b659fb7445
commit abc4085c8a
21 changed files with 5234 additions and 46 deletions
+205
View File
@@ -0,0 +1,205 @@
package knowledge
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"cyberstrike-ai/internal/config"
	"cyberstrike-ai/internal/openai"

	"go.uber.org/zap"
)
// Embedder turns text into embedding vectors by calling an HTTP embeddings
// endpoint. It bundles the knowledge-base configuration, the general OpenAI
// configuration, an OpenAI client and a logger.
type Embedder struct {
openAIClient *openai.Client
config *config.KnowledgeConfig
openAIConfig *config.OpenAIConfig // used to obtain the API key
logger *zap.Logger
}
// NewEmbedder builds an Embedder from the knowledge-base configuration, the
// general OpenAI configuration, an OpenAI client and a logger. The arguments
// are stored as given; nothing is validated here.
func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder {
	embedder := new(Embedder)
	embedder.openAIClient = openAIClient
	embedder.config = cfg
	embedder.openAIConfig = openAIConfig
	embedder.logger = logger
	return embedder
}
// EmbeddingRequest is the JSON payload sent to the embeddings endpoint:
// the model name and the list of input texts.
type EmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
}
// EmbeddingResponse is the JSON payload decoded from the embeddings endpoint.
// Data carries the returned vectors; Error is non-nil when the response body
// contains an "error" object.
type EmbeddingResponse struct {
Data []EmbeddingData `json:"data"`
Error *EmbeddingError `json:"error,omitempty"`
}
// EmbeddingData is one entry of EmbeddingResponse.Data: an embedding vector
// and the "index" value reported alongside it.
type EmbeddingData struct {
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
// EmbeddingError is the "error" object of an embeddings response, holding the
// reported message and error type.
type EmbeddingError struct {
Message string `json:"message"`
Type string `json:"type"`
}
// embedLogPreview shortens s to at most limit bytes for use in logs and error
// messages, backing up to a UTF-8 rune boundary so multi-byte characters are
// never cut in half. Truncated output gets a "..." suffix.
func embedLogPreview(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := limit
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}

// EmbedText requests an embedding vector for a single text.
//
// The model and base URL come from the embedding configuration, defaulting to
// "text-embedding-3-small" and "https://api.openai.com/v1". The API key comes
// from the embedding configuration, falling back to the general OpenAI
// configuration. The request is a POST to <baseURL>/embeddings with a
// 30-second client timeout and honours ctx.
//
// It returns an error when the client is not initialised, no API key is
// configured, the request cannot be built or sent, the response body cannot
// be read or decoded, the response carries an "error" object, the status is
// not 200, or no embedding data is returned. On success the first returned
// vector is converted from float64 to float32.
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
	if e.openAIClient == nil {
		return nil, fmt.Errorf("OpenAI客户端未初始化")
	}

	// Use the configured embedding model, falling back to the default.
	model := e.config.Embedding.Model
	if model == "" {
		model = "text-embedding-3-small"
	}

	// Normalise the base URL: strip surrounding whitespace and one trailing slash.
	baseURL := strings.TrimSuffix(strings.TrimSpace(e.config.Embedding.BaseURL), "/")
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}

	body, err := json.Marshal(EmbeddingRequest{
		Model: model,
		Input: []string{text},
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	requestURL := baseURL + "/embeddings"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	// Prefer the embedding-specific API key; fall back to the OpenAI config.
	apiKey := strings.TrimSpace(e.config.Embedding.APIKey)
	if apiKey == "" && e.openAIConfig != nil {
		apiKey = e.openAIConfig.APIKey
	}
	if apiKey == "" {
		return nil, fmt.Errorf("API Key未配置")
	}
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)

	httpClient := &http.Client{
		Timeout: 30 * time.Second,
	}
	resp, err := httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}
	defer resp.Body.Close()

	// Read the whole body up front so it can be quoted in error messages.
	// A read failure is reported as such instead of surfacing later as a
	// misleading JSON parse error on a truncated payload.
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败 (URL: %s, 状态码: %d): %w", requestURL, resp.StatusCode, err)
	}

	requestBodyPreview := embedLogPreview(string(body), 200)
	bodyPreview := embedLogPreview(string(bodyBytes), 500)

	// Debug record of the exchange.
	e.logger.Debug("嵌入API请求",
		zap.String("url", httpReq.URL.String()),
		zap.String("model", model),
		zap.String("requestBody", requestBodyPreview),
		zap.Int("status", resp.StatusCode),
		zap.Int("bodySize", len(bodyBytes)),
		zap.String("contentType", resp.Header.Get("Content-Type")),
	)

	var embeddingResp EmbeddingResponse
	if err := json.Unmarshal(bodyBytes, &embeddingResp); err != nil {
		return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码: %d, 响应长度: %d字节): %w\n请求体: %s\n响应内容预览: %s",
			requestURL, resp.StatusCode, len(bodyBytes), err, requestBodyPreview, bodyPreview)
	}

	// An explicit API error object takes precedence over the bare status code.
	if embeddingResp.Error != nil {
		return nil, fmt.Errorf("OpenAI API错误 (状态码: %d): 类型=%s, 消息=%s",
			resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP请求失败 (URL: %s, 状态码: %d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview)
	}

	if len(embeddingResp.Data) == 0 {
		return nil, fmt.Errorf("未收到嵌入数据 (状态码: %d, 响应长度: %d字节)\n响应内容: %s",
			resp.StatusCode, len(bodyBytes), bodyPreview)
	}

	// Narrow the returned float64 vector to float32.
	source := embeddingResp.Data[0].Embedding
	embedding := make([]float32, len(source))
	for i, v := range source {
		embedding[i] = float32(v)
	}
	return embedding, nil
}
// EmbedTexts embeds every entry of texts with one EmbedText call per entry
// and returns the vectors in input order. An empty input yields (nil, nil).
// The first failure aborts the run and is returned wrapped with the index of
// the offending text.
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
	total := len(texts)
	if total == 0 {
		return nil, nil
	}

	// The embeddings API accepts batches, but texts are sent one at a time to
	// keep the code simple; a batched request would be more efficient.
	vectors := make([][]float32, 0, total)
	for pos := 0; pos < total; pos++ {
		vector, err := e.EmbedText(ctx, texts[pos])
		if err != nil {
			return nil, fmt.Errorf("嵌入文本[%d]失败: %w", pos, err)
		}
		vectors = append(vectors, vector)
	}
	return vectors, nil
}
+247
View File
@@ -0,0 +1,247 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Indexer splits knowledge items into chunks and stores an embedding vector
// for each chunk.
type Indexer struct {
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int // maximum number of (estimated) tokens per chunk
overlap int // number of tokens shared between adjacent chunks; NOTE(review): not read by the chunking code visible here
}
// NewIndexer builds an Indexer around the given database handle, embedder and
// logger, with a chunk size of 512 estimated tokens and an overlap of 50.
func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger) *Indexer {
	const (
		defaultChunkSize = 512 // estimated tokens per chunk
		defaultOverlap   = 50  // estimated tokens of overlap
	)

	indexer := &Indexer{chunkSize: defaultChunkSize, overlap: defaultOverlap}
	indexer.db = db
	indexer.embedder = embedder
	indexer.logger = logger
	return indexer
}
// ChunkText splits text into chunks for embedding.
//
// The text is first split at Markdown headers. A section whose estimated
// token count exceeds chunkSize is split into paragraphs, and a paragraph
// that is still too large is split into sentences, which are then re-joined
// with newlines greedily until adding the next sentence would exceed
// chunkSize. A single sentence larger than chunkSize is emitted as is.
func (idx *Indexer) ChunkText(text string) []string {
	fitsBudget := func(s string) bool {
		return idx.estimateTokens(s) <= idx.chunkSize
	}

	out := []string{}
	for _, section := range idx.splitByMarkdownHeaders(text) {
		if fitsBudget(section) {
			out = append(out, section)
			continue
		}

		// Section too large: fall back to paragraphs.
		for _, paragraph := range idx.splitByParagraphs(section) {
			if fitsBudget(paragraph) {
				out = append(out, paragraph)
				continue
			}

			// Paragraph too large: pack sentences into chunks.
			pending := ""
			for _, sentence := range idx.splitBySentences(paragraph) {
				if pending == "" {
					pending = sentence
					continue
				}
				joined := pending + "\n" + sentence
				if !fitsBudget(joined) {
					out = append(out, pending)
					pending = sentence
					continue
				}
				pending = joined
			}
			if pending != "" {
				out = append(out, pending)
			}
		}
	}
	return out
}
// markdownHeaderPattern matches a Markdown ATX header line ("#" through
// "######" followed by whitespace and text). It is compiled once at package
// scope rather than on every call.
var markdownHeaderPattern = regexp.MustCompile(`(?m)^#{1,6}\s+.+$`)

// splitByMarkdownHeaders splits text into sections, each starting at a
// Markdown header line. Text before the first header becomes its own section.
// Sections are whitespace-trimmed and empty ones are dropped. When the text
// has no headers, or nothing remains after trimming, the original text is
// returned as the only section.
func (idx *Indexer) splitByMarkdownHeaders(text string) []string {
	headers := markdownHeaderPattern.FindAllStringIndex(text, -1)
	if len(headers) == 0 {
		return []string{text}
	}

	sections := make([]string, 0, len(headers)+1)
	appendSection := func(raw string) {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			sections = append(sections, trimmed)
		}
	}

	// Each section runs from one header start to the next.
	prev := 0
	for _, header := range headers {
		if start := header[0]; start > prev {
			appendSection(text[prev:start])
		}
		prev = header[0]
	}
	// Tail: from the last header to the end of the text.
	if prev < len(text) {
		appendSection(text[prev:])
	}

	if len(sections) == 0 {
		return []string{text}
	}
	return sections
}
// splitByParagraphs splits text on blank lines ("\n\n") and returns the
// whitespace-trimmed, non-empty paragraphs in order.
func (idx *Indexer) splitByParagraphs(text string) []string {
	paragraphs := []string{}
	for _, block := range strings.Split(text, "\n\n") {
		trimmed := strings.TrimSpace(block)
		if trimmed == "" {
			continue
		}
		paragraphs = append(paragraphs, trimmed)
	}
	return paragraphs
}
// sentenceBoundaryPattern matches the end of a sentence: ASCII terminators
// (. ! ?) followed by whitespace, or CJK terminators (。 ! ?), which are
// normally not followed by whitespace. It is compiled once at package scope.
var sentenceBoundaryPattern = regexp.MustCompile(`[.!?]+\s+|[。!?]+\s*`)

// splitBySentences splits text into whitespace-trimmed, non-empty sentences.
//
// Each sentence keeps its terminating punctuation, so no characters of the
// original text are lost. Both ASCII and CJK sentence terminators are
// recognised, so Chinese text is split as well.
func (idx *Indexer) splitBySentences(text string) []string {
	sentences := make([]string, 0)
	start := 0
	for _, boundary := range sentenceBoundaryPattern.FindAllStringIndex(text, -1) {
		// boundary[1] is just past the terminator and any trailing whitespace.
		if sentence := strings.TrimSpace(text[start:boundary[1]]); sentence != "" {
			sentences = append(sentences, sentence)
		}
		start = boundary[1]
	}
	if tail := strings.TrimSpace(text[start:]); tail != "" {
		sentences = append(sentences, tail)
	}
	return sentences
}
// estimateTokens gives a rough token count for text using the rule of thumb
// that one token is about four characters: the number of runes divided by 4,
// rounded down.
func (idx *Indexer) estimateTokens(text string) int {
	runes := 0
	for range text {
		runes++
	}
	return runes / 4
}
// IndexItem (re)builds the embedding index for one knowledge item.
//
// It loads the item's content, deletes the item's existing embeddings, splits
// the content into chunks and stores one embedding row per chunk. Indexing is
// best-effort per chunk: a chunk that fails to embed, serialise or store is
// logged and skipped. The final log entry reports how many chunks were
// actually stored.
//
// It returns an error when the item cannot be loaded, the old embeddings
// cannot be deleted, or ctx is cancelled before all chunks are processed.
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
	// Load the item content.
	var content string
	err := idx.db.QueryRowContext(ctx, "SELECT content FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content)
	if err != nil {
		return fmt.Errorf("获取知识项失败: %w", err)
	}

	// Drop embeddings from any previous indexing run.
	if _, err := idx.db.ExecContext(ctx, "DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil {
		return fmt.Errorf("删除旧向量失败: %w", err)
	}

	chunks := idx.ChunkText(content)
	idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))

	indexed := 0
	for i, chunk := range chunks {
		// Stop promptly on cancellation instead of failing every remaining chunk.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("索引已取消: %w", err)
		}

		embedding, err := idx.embedder.EmbedText(ctx, chunk)
		if err != nil {
			// Build the preview only when needed, cutting on a rune boundary
			// so the logged text stays valid UTF-8.
			chunkPreview := chunk
			if len(chunkPreview) > 200 {
				cut := 200
				for cut > 0 && chunk[cut]&0xC0 == 0x80 {
					cut--
				}
				chunkPreview = chunk[:cut] + "..."
			}
			idx.logger.Warn("向量化失败",
				zap.String("itemId", itemID),
				zap.Int("chunkIndex", i),
				zap.Int("chunkLength", len(chunk)),
				zap.String("chunkPreview", chunkPreview),
				zap.Error(err),
			)
			continue
		}

		// Serialise the vector; a failure here (e.g. NaN/Inf components)
		// must not result in an empty embedding being stored.
		embeddingJSON, err := json.Marshal(embedding)
		if err != nil {
			idx.logger.Warn("序列化向量失败", zap.String("itemId", itemID), zap.Int("chunkIndex", i), zap.Error(err))
			continue
		}

		chunkID := uuid.New().String()
		_, err = idx.db.ExecContext(ctx,
			"INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, created_at) VALUES (?, ?, ?, ?, ?, datetime('now'))",
			chunkID, itemID, i, chunk, string(embeddingJSON),
		)
		if err != nil {
			idx.logger.Warn("保存向量失败", zap.String("itemId", itemID), zap.Int("chunkIndex", i), zap.Error(err))
			continue
		}
		indexed++
	}

	idx.logger.Info("知识项索引完成",
		zap.String("itemId", itemID),
		zap.Int("chunks", len(chunks)),
		zap.Int("indexed", indexed),
	)
	return nil
}
// HasIndex reports whether the knowledge_embeddings table contains at least
// one row. It returns an error when the count query fails.
func (idx *Indexer) HasIndex() (bool, error) {
	var rowCount int
	row := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings")
	if err := row.Scan(&rowCount); err != nil {
		return false, fmt.Errorf("检查索引失败: %w", err)
	}
	hasRows := rowCount > 0
	return hasRows, nil
}
// RebuildIndex re-indexes every knowledge item.
//
// All item IDs are collected first, then IndexItem is called for each. A
// failure on one item is logged and does not stop the rebuild. It returns an
// error when the IDs cannot be listed or ctx is cancelled before the rebuild
// finishes.
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
	rows, err := idx.db.QueryContext(ctx, "SELECT id FROM knowledge_base_items")
	if err != nil {
		return fmt.Errorf("查询知识项失败: %w", err)
	}
	defer rows.Close()

	var itemIDs []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return fmt.Errorf("扫描知识项ID失败: %w", err)
		}
		itemIDs = append(itemIDs, id)
	}
	// rows.Next returning false may mean an iteration error rather than the
	// end of the result set; without this check the list could be truncated
	// silently.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("遍历知识项ID失败: %w", err)
	}

	idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))

	for i, itemID := range itemIDs {
		// Abort on cancellation instead of attempting every remaining item.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("重建索引已取消: %w", err)
		}
		if err := idx.IndexItem(ctx, itemID); err != nil {
			idx.logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
			continue
		}
		idx.logger.Debug("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)))
	}

	idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)))
	return nil
}
+447
View File
@@ -0,0 +1,447 @@
package knowledge
import (
"database/sql"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Manager manages knowledge-base items: it keeps the Markdown files under
// basePath and the rows in the database in step, and records retrieval logs.
type Manager struct {
db *sql.DB
basePath string
logger *zap.Logger
}
// NewManager builds a Manager that stores items in db and keeps their
// Markdown files under basePath.
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
	manager := new(Manager)
	manager.db = db
	manager.basePath = basePath
	manager.logger = logger
	return manager
}
// ScanKnowledgeBase walks the knowledge-base directory and mirrors every
// Markdown file into the knowledge_base_items table.
//
// The base directory is created if missing. For each file whose name ends in
// ".md" (case-insensitive) the first path component below basePath becomes
// the category ("未分类" for files directly in basePath) and the file name
// without its extension becomes the title. A file already known by its path
// is updated, otherwise a new row is inserted.
//
// A file that cannot be read is logged and skipped; walk errors and database
// errors abort the scan and are returned.
func (m *Manager) ScanKnowledgeBase() error {
	if m.basePath == "" {
		return fmt.Errorf("知识库路径未配置")
	}

	// Make sure the directory exists.
	if err := os.MkdirAll(m.basePath, 0755); err != nil {
		return fmt.Errorf("创建知识库目录失败: %w", err)
	}

	return filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}

		// Skip directories and non-Markdown files.
		if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
			return nil
		}

		relPath, err := filepath.Rel(m.basePath, path)
		if err != nil {
			return err
		}

		// The first directory name is the category (risk type).
		category := "未分类"
		if parts := strings.Split(relPath, string(filepath.Separator)); len(parts) > 1 {
			category = parts[0]
		}

		// The title is the file name without its extension. The extension is
		// removed by length rather than by the literal ".md" so that ".MD"
		// and ".Md", which the case-insensitive filter above lets through,
		// are stripped as well.
		base := filepath.Base(path)
		title := base[:len(base)-len(filepath.Ext(base))]

		content, err := os.ReadFile(path)
		if err != nil {
			m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
			return nil // keep processing the remaining files
		}

		// Look the file up by path to decide between insert and update.
		var existingID string
		err = m.db.QueryRow(
			"SELECT id FROM knowledge_base_items WHERE file_path = ?",
			path,
		).Scan(&existingID)

		switch {
		case err == sql.ErrNoRows:
			id := uuid.New().String()
			now := time.Now()
			_, err = m.db.Exec(
				"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
				id, category, title, path, string(content), now, now,
			)
			if err != nil {
				return fmt.Errorf("插入知识项失败: %w", err)
			}
			m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
		case err == nil:
			_, err = m.db.Exec(
				"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
				category, title, string(content), time.Now(), existingID,
			)
			if err != nil {
				return fmt.Errorf("更新知识项失败: %w", err)
			}
			m.logger.Debug("更新知识项", zap.String("id", existingID), zap.String("title", title))
		default:
			return fmt.Errorf("查询知识项失败: %w", err)
		}

		return nil
	})
}
// GetCategories returns the distinct categories (risk types) of all knowledge
// items, sorted by name. It returns an error when the query, a row scan or
// the row iteration fails.
func (m *Manager) GetCategories() ([]string, error) {
	rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
	if err != nil {
		return nil, fmt.Errorf("查询分类失败: %w", err)
	}
	defer rows.Close()

	var categories []string
	for rows.Next() {
		var category string
		if err := rows.Scan(&category); err != nil {
			return nil, fmt.Errorf("扫描分类失败: %w", err)
		}
		categories = append(categories, category)
	}
	// Surface iteration errors instead of returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历分类失败: %w", err)
	}
	return categories, nil
}
// itemTimeLayouts lists the timestamp layouts accepted when reading
// created_at/updated_at of knowledge items, tried in order. The first and
// last entries are the two layouts this code has always accepted; the others
// are the additional layouts GetRetrievalLogs accepts for the same kind of
// column.
var itemTimeLayouts = []string{
	"2006-01-02 15:04:05.999999999-07:00",
	"2006-01-02 15:04:05.999999999",
	time.RFC3339Nano,
	"2006-01-02T15:04:05Z",
	"2006-01-02 15:04:05",
}

// parseItemTimestamp parses a stored timestamp using itemTimeLayouts and
// returns the zero time when no layout matches.
func parseItemTimestamp(value string) time.Time {
	for _, layout := range itemTimeLayouts {
		if parsed, err := time.Parse(layout, value); err == nil && !parsed.IsZero() {
			return parsed
		}
	}
	return time.Time{}
}

// GetItems lists knowledge items, optionally restricted to one category.
//
// With a non-empty category the items of that category are returned ordered
// by title; otherwise all items are returned ordered by category and title.
// Timestamps that cannot be parsed are left as the zero time. It returns an
// error when the query, a row scan or the row iteration fails.
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
	var rows *sql.Rows
	var err error
	if category != "" {
		rows, err = m.db.Query(
			"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE category = ? ORDER BY title",
			category,
		)
	} else {
		rows, err = m.db.Query(
			"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items ORDER BY category, title",
		)
	}
	if err != nil {
		return nil, fmt.Errorf("查询知识项失败: %w", err)
	}
	defer rows.Close()

	var items []*KnowledgeItem
	for rows.Next() {
		item := &KnowledgeItem{}
		var createdAt, updatedAt string
		if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
			return nil, fmt.Errorf("扫描知识项失败: %w", err)
		}
		item.CreatedAt = parseItemTimestamp(createdAt)
		item.UpdatedAt = parseItemTimestamp(updatedAt)
		items = append(items, item)
	}
	// Surface iteration errors instead of returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历知识项失败: %w", err)
	}
	return items, nil
}

// GetItem loads a single knowledge item by ID. It returns an error when no
// item has that ID or the query fails. Timestamps that cannot be parsed are
// left as the zero time.
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
	item := &KnowledgeItem{}
	var createdAt, updatedAt string
	err := m.db.QueryRow(
		"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
		id,
	).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
	if err == sql.ErrNoRows {
		return nil, fmt.Errorf("知识项不存在")
	}
	if err != nil {
		return nil, fmt.Errorf("查询知识项失败: %w", err)
	}

	item.CreatedAt = parseItemTimestamp(createdAt)
	item.UpdatedAt = parseItemTimestamp(updatedAt)
	return item, nil
}
// itemFilePath builds the Markdown file path for an item and rejects any
// category/title combination (for example one containing "..") that would
// place the file outside the knowledge-base directory.
func (m *Manager) itemFilePath(category, title string) (string, error) {
	filePath := filepath.Join(m.basePath, category, title+".md")
	rel, err := filepath.Rel(m.basePath, filePath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("非法的分类或标题: 文件路径超出知识库目录")
	}
	return filePath, nil
}

// CreateItem creates a knowledge item: it writes content to
// <basePath>/<category>/<title>.md (creating directories as needed) and
// inserts the matching database row.
//
// It returns an error when the resulting path would escape the knowledge-base
// directory, or when the directory, the file or the database row cannot be
// written.
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
	id := uuid.New().String()
	now := time.Now()

	filePath, err := m.itemFilePath(category, title)
	if err != nil {
		return nil, err
	}

	// Make sure the category directory exists.
	if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
		return nil, fmt.Errorf("创建目录失败: %w", err)
	}

	if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
		return nil, fmt.Errorf("写入文件失败: %w", err)
	}

	_, err = m.db.Exec(
		"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
		id, category, title, filePath, content, now, now,
	)
	if err != nil {
		return nil, fmt.Errorf("插入知识项失败: %w", err)
	}

	return &KnowledgeItem{
		ID:        id,
		Category:  category,
		Title:     title,
		FilePath:  filePath,
		Content:   content,
		CreatedAt: now,
		UpdatedAt: now,
	}, nil
}

// UpdateItem changes the category, title and content of an existing item.
//
// When category or title change, the backing file is moved to its new
// location and the old directory is removed if it ends up empty. If the old
// file no longer exists the move is skipped and the file is simply written at
// the new location. The item's embeddings are deleted afterwards because they
// no longer match the content; the item has to be re-indexed.
//
// It returns an error when the item does not exist, the new path would escape
// the knowledge-base directory, or a file or database operation fails.
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
	item, err := m.GetItem(id)
	if err != nil {
		return nil, err
	}

	newFilePath, err := m.itemFilePath(category, title)
	if err != nil {
		return nil, err
	}

	// Move the file when its path changed.
	if item.FilePath != newFilePath {
		if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
			return nil, fmt.Errorf("创建目录失败: %w", err)
		}

		// A missing source file is not fatal: the content is rewritten below.
		if err := os.Rename(item.FilePath, newFilePath); err != nil && !os.IsNotExist(err) {
			return nil, fmt.Errorf("移动文件失败: %w", err)
		}

		// Remove the old directory if the move left it empty (best effort).
		oldDir := filepath.Dir(item.FilePath)
		if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
			os.Remove(oldDir)
		}
	}

	if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
		return nil, fmt.Errorf("写入文件失败: %w", err)
	}

	_, err = m.db.Exec(
		"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
		category, title, newFilePath, content, time.Now(), id,
	)
	if err != nil {
		return nil, fmt.Errorf("更新知识项失败: %w", err)
	}

	// Stale embeddings must go; failure here is logged but not fatal.
	_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
	if err != nil {
		m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
	}

	return m.GetItem(id)
}
// DeleteItem removes a knowledge item: its backing file, its embeddings and
// its database row.
//
// A failure to delete the file or the embeddings is logged and does not stop
// the deletion. It returns an error when the item cannot be looked up or its
// row cannot be deleted.
func (m *Manager) DeleteItem(id string) error {
	// Look up the file path.
	var filePath string
	err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
	if err != nil {
		return fmt.Errorf("查询知识项失败: %w", err)
	}

	// Remove the file; a file that is already gone is fine.
	if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
		m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
	}

	// Delete the embeddings explicitly, as UpdateItem does, rather than
	// relying on a cascading delete.
	// NOTE(review): the schema is not visible here; this is harmless if a
	// cascade is configured and prevents orphaned rows if it is not.
	if _, err := m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id); err != nil {
		m.logger.Warn("删除向量嵌入失败", zap.String("itemId", id), zap.Error(err))
	}

	_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
	if err != nil {
		return fmt.Errorf("删除知识项失败: %w", err)
	}

	return nil
}
// LogRetrieval stores one retrieval-log row: a fresh UUID, the conversation
// and message IDs, the query, the risk type, the retrieved item IDs encoded
// as JSON, and the current time. It returns the error of the insert, if any.
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
	const insertLog = "INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)"

	logID := uuid.New().String()
	// The marshal error is discarded, as it always has been.
	encodedItems, _ := json.Marshal(retrievedItems)

	_, err := m.db.Exec(insertLog,
		logID, conversationID, messageID, query, riskType, string(encodedItems), time.Now(),
	)
	return err
}
// GetIndexStatus reports indexing progress as a map with four keys:
// "total_items" (number of knowledge items), "indexed_items" (number of
// distinct item IDs that have embeddings), "progress_percent"
// (indexed/total*100, or 100 when there are no items) and "is_complete"
// (true when there is at least one item and indexed >= total).
// It returns an error when either count query fails.
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
	var totalItems, indexedItems int

	// Total number of knowledge items.
	if err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems); err != nil {
		return nil, fmt.Errorf("查询总知识项数失败: %w", err)
	}

	// Number of items that have at least one embedding.
	if err := m.db.QueryRow(`
SELECT COUNT(DISTINCT item_id)
FROM knowledge_embeddings
`).Scan(&indexedItems); err != nil {
		return nil, fmt.Errorf("查询已索引项数失败: %w", err)
	}

	// With nothing to index, progress counts as 100%.
	progressPercent := 100.0
	if totalItems > 0 {
		progressPercent = float64(indexedItems) / float64(totalItems) * 100
	}

	status := make(map[string]interface{}, 4)
	status["total_items"] = totalItems
	status["indexed_items"] = indexedItems
	status["progress_percent"] = progressPercent
	status["is_complete"] = totalItems > 0 && indexedItems >= totalItems
	return status, nil
}
// GetRetrievalLogs returns up to limit retrieval logs, newest first.
//
// A non-empty messageID filters by message; otherwise a non-empty
// conversationID filters by conversation; otherwise all logs are considered.
// A created_at value that matches none of the known layouts is logged and
// replaced by the current time. A retrieved_items value that is not valid
// JSON is logged and leaves RetrievedItems unset. It returns an error when
// the query, a row scan or the row iteration fails.
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
	var rows *sql.Rows
	var err error

	switch {
	case messageID != "":
		rows, err = m.db.Query(
			"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
			messageID, limit,
		)
	case conversationID != "":
		rows, err = m.db.Query(
			"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
			conversationID, limit,
		)
	default:
		rows, err = m.db.Query(
			"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
			limit,
		)
	}
	if err != nil {
		return nil, fmt.Errorf("查询检索日志失败: %w", err)
	}
	defer rows.Close()

	// Timestamp layouts to try, in order.
	timeFormats := []string{
		"2006-01-02 15:04:05.999999999-07:00",
		"2006-01-02 15:04:05.999999999",
		"2006-01-02T15:04:05.999999999Z07:00",
		"2006-01-02T15:04:05Z",
		"2006-01-02 15:04:05",
		time.RFC3339,
		time.RFC3339Nano,
	}

	var logs []*RetrievalLog
	for rows.Next() {
		entry := &RetrievalLog{}
		var createdAt string
		var itemsJSON sql.NullString
		if err := rows.Scan(&entry.ID, &entry.ConversationID, &entry.MessageID, &entry.Query, &entry.RiskType, &itemsJSON, &createdAt); err != nil {
			return nil, fmt.Errorf("扫描检索日志失败: %w", err)
		}

		var parseErr error
		for _, format := range timeFormats {
			entry.CreatedAt, parseErr = time.Parse(format, createdAt)
			if parseErr == nil && !entry.CreatedAt.IsZero() {
				break
			}
		}
		// No layout matched: warn and fall back to the current time.
		if entry.CreatedAt.IsZero() {
			m.logger.Warn("解析检索日志时间失败",
				zap.String("timeStr", createdAt),
				zap.Error(parseErr),
			)
			entry.CreatedAt = time.Now()
		}

		// Decode the retrieved item IDs; report corrupt data instead of
		// dropping it silently.
		if itemsJSON.Valid {
			if err := json.Unmarshal([]byte(itemsJSON.String), &entry.RetrievedItems); err != nil {
				m.logger.Warn("解析检索项失败",
					zap.String("logId", entry.ID),
					zap.Error(err),
				)
			}
		}

		logs = append(logs, entry)
	}
	// Surface iteration errors instead of returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历检索日志失败: %w", err)
	}

	return logs, nil
}
+230
View File
@@ -0,0 +1,230 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"strings"
"go.uber.org/zap"
)
// Retriever searches the knowledge base: it embeds a query with the embedder
// and compares it against the embeddings stored in the database.
type Retriever struct {
db *sql.DB
embedder *Embedder
config *RetrievalConfig
logger *zap.Logger
}
// RetrievalConfig holds the retrieval settings: the number of results to
// return (TopK), the minimum cosine similarity a chunk must reach
// (SimilarityThreshold), and the weight given to vector similarity in the
// hybrid score (HybridWeight).
type RetrievalConfig struct {
TopK int
SimilarityThreshold float64
HybridWeight float64
}
// NewRetriever builds a Retriever from a database handle, an embedder, the
// retrieval configuration and a logger. The arguments are stored as given.
func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever {
	retriever := new(Retriever)
	retriever.db = db
	retriever.embedder = embedder
	retriever.config = config
	retriever.logger = logger
	return retriever
}
// cosineSimilarity returns the cosine of the angle between vectors a and b.
// It returns 0 when the vectors differ in length or when either has zero
// magnitude (including the empty vector).
//
// Components are widened to float64 before multiplying, so that products of
// small components neither lose precision nor underflow to zero in float32.
func cosineSimilarity(a, b []float32) float64 {
	if len(a) != len(b) {
		return 0.0
	}

	var dotProduct, normA, normB float64
	for i := range a {
		x, y := float64(a[i]), float64(b[i])
		dotProduct += x * y
		normA += x * x
		normB += y * y
	}

	if normA == 0 || normB == 0 {
		return 0.0
	}

	return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
// bm25Score computes a simplified keyword-overlap score (not full BM25).
//
// Query and text are lower-cased and split on whitespace. For every query
// term that occurs in the text, its relative frequency in the text
// (occurrences / number of text terms) is added; the sum is divided by the
// number of query terms. It returns 0 when the query or the text contains no
// terms, so a whitespace-only query no longer produces NaN from 0/0.
func (r *Retriever) bm25Score(query, text string) float64 {
	queryTerms := strings.Fields(strings.ToLower(query))
	if len(queryTerms) == 0 {
		return 0
	}
	textTerms := strings.Fields(strings.ToLower(text))
	if len(textTerms) == 0 {
		return 0
	}

	// Count each text term once instead of rescanning the text per query term.
	frequency := make(map[string]int, len(textTerms))
	for _, term := range textTerms {
		frequency[term]++
	}

	score := 0.0
	for _, term := range queryTerms {
		if count := frequency[term]; count > 0 {
			score += float64(count) / float64(len(textTerms))
		}
	}

	return score / float64(len(queryTerms))
}
// Search finds the knowledge chunks most similar to req.Query.
//
// The query is embedded and compared by cosine similarity against every
// stored embedding, optionally restricted to items whose category equals
// req.RiskType. Chunks below the similarity threshold are dropped and the
// rest are returned ordered by similarity, highest first, capped at top-K.
// Each result also carries a hybrid score combining similarity with the
// keyword score.
//
// TopK and Threshold come from the request, then from the retriever
// configuration, then default to 5 and 0.7; the hybrid weight defaults to
// 0.7. A nil configuration is treated as empty. Rows that cannot be scanned
// or decoded are logged and skipped. It returns an error when the query is
// empty, cannot be embedded, or the database query or iteration fails.
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
	if req.Query == "" {
		return nil, fmt.Errorf("查询不能为空")
	}

	// Tolerate a missing configuration by falling back to the defaults.
	var cfg RetrievalConfig
	if r.config != nil {
		cfg = *r.config
	}

	topK := req.TopK
	if topK <= 0 {
		topK = cfg.TopK
	}
	if topK <= 0 {
		// Covers both an unset and a negative configured value; a negative
		// value would otherwise panic when slicing the candidates.
		topK = 5
	}

	threshold := req.Threshold
	if threshold <= 0 {
		threshold = cfg.SimilarityThreshold
	}
	if threshold == 0 {
		threshold = 0.7
	}

	queryEmbedding, err := r.embedder.EmbedText(ctx, req.Query)
	if err != nil {
		return nil, fmt.Errorf("向量化查询失败: %w", err)
	}

	// Load all embeddings, or only those of the requested risk type.
	var rows *sql.Rows
	if req.RiskType != "" {
		rows, err = r.db.Query(`
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
WHERE i.category = ?
`, req.RiskType)
	} else {
		rows, err = r.db.Query(`
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
`)
	}
	if err != nil {
		return nil, fmt.Errorf("查询向量失败: %w", err)
	}
	defer rows.Close()

	type candidate struct {
		chunk      *KnowledgeChunk
		item       *KnowledgeItem
		similarity float64
		bm25Score  float64
	}

	candidates := make([]candidate, 0)
	for rows.Next() {
		var chunkID, itemID, chunkText, embeddingJSON, category, title string
		var chunkIndex int
		if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &category, &title); err != nil {
			r.logger.Warn("扫描向量失败", zap.Error(err))
			continue
		}

		var embedding []float32
		if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil {
			r.logger.Warn("解析向量失败", zap.Error(err))
			continue
		}

		// Drop low-similarity chunks before computing the keyword score.
		similarity := cosineSimilarity(queryEmbedding, embedding)
		if similarity < threshold {
			continue
		}

		candidates = append(candidates, candidate{
			chunk: &KnowledgeChunk{
				ID:         chunkID,
				ItemID:     itemID,
				ChunkIndex: chunkIndex,
				ChunkText:  chunkText,
				Embedding:  embedding,
			},
			item: &KnowledgeItem{
				ID:       itemID,
				Category: category,
				Title:    title,
			},
			similarity: similarity,
			bm25Score:  r.bm25Score(req.Query, chunkText),
		})
	}
	// Surface iteration errors instead of answering from a partial scan.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历向量失败: %w", err)
	}

	hybridWeight := cfg.HybridWeight
	if hybridWeight == 0 {
		hybridWeight = 0.7
	}

	// Order by similarity, best first. Ranking deliberately uses similarity
	// only; the keyword score feeds the reported hybrid score. A partial
	// selection sort places just the first `limit` entries, and it runs even
	// when there are fewer candidates than topK, so the result is always
	// ordered.
	limit := topK
	if limit > len(candidates) {
		limit = len(candidates)
	}
	for i := 0; i < limit; i++ {
		best := i
		for j := i + 1; j < len(candidates); j++ {
			if candidates[j].similarity > candidates[best].similarity {
				best = j
			}
		}
		candidates[i], candidates[best] = candidates[best], candidates[i]
	}
	candidates = candidates[:limit]

	results := make([]*RetrievalResult, len(candidates))
	for i, cand := range candidates {
		// Hybrid score: weighted mix of similarity and the keyword score capped at 1.
		normalizedBM25 := math.Min(cand.bm25Score, 1.0)
		hybridScore := hybridWeight*cand.similarity + (1-hybridWeight)*normalizedBM25

		results[i] = &RetrievalResult{
			Chunk:      cand.chunk,
			Item:       cand.item,
			Similarity: cand.similarity,
			Score:      hybridScore,
		}
	}

	return results, nil
}
+191
View File
@@ -0,0 +1,191 @@
package knowledge
import (
"context"
"encoding/json"
"fmt"
"strings"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// RegisterKnowledgeTool registers the "search_knowledge_base" tool on the MCP
// server.
//
// The tool takes a required "query" string and an optional "risk_type"
// string, runs retriever.Search with TopK 5, and returns the matches as
// formatted text followed by an HTML-comment metadata trailer listing the
// distinct knowledge item IDs. Invalid input and search failures are returned
// as error tool results rather than Go errors. The manager parameter is
// accepted but not used here.
func RegisterKnowledgeTool(
	mcpServer *mcp.Server,
	retriever *Retriever,
	manager *Manager,
	logger *zap.Logger,
) {
	// Kept in the signature for possible future use (e.g. logging).
	_ = manager

	// Small constructors for the two kinds of single-text results.
	plainResult := func(text string) *mcp.ToolResult {
		return &mcp.ToolResult{
			Content: []mcp.Content{
				{Type: "text", Text: text},
			},
		}
	}
	failureResult := func(text string) *mcp.ToolResult {
		return &mcp.ToolResult{
			Content: []mcp.Content{
				{Type: "text", Text: text},
			},
			IsError: true,
		}
	}

	properties := map[string]interface{}{
		"query": map[string]interface{}{
			"type":        "string",
			"description": "搜索查询内容,描述你想要了解的安全知识主题",
		},
		"risk_type": map[string]interface{}{
			"type":        "string",
			"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等),如果不指定则搜索所有类型",
		},
	}

	tool := mcp.Tool{
		Name:             "search_knowledge_base",
		Description:      "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。",
		ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
		InputSchema: map[string]interface{}{
			"type":       "object",
			"properties": properties,
			"required":   []string{"query"},
		},
	}

	handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		query, isString := args["query"].(string)
		if !isString || query == "" {
			return failureResult("错误: 查询参数不能为空"), nil
		}

		// A missing or non-string risk_type leaves the filter empty.
		riskType, _ := args["risk_type"].(string)

		logger.Info("执行知识库检索",
			zap.String("query", query),
			zap.String("riskType", riskType),
		)

		results, err := retriever.Search(ctx, &SearchRequest{
			Query:    query,
			RiskType: riskType,
			TopK:     5,
		})
		if err != nil {
			logger.Error("知识库检索失败", zap.Error(err))
			return failureResult(fmt.Sprintf("检索失败: %v", err)), nil
		}

		if len(results) == 0 {
			return plainResult(fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query)), nil
		}

		// Render the results and collect distinct item IDs in first-seen order.
		var resultText strings.Builder
		fmt.Fprintf(&resultText, "找到 %d 条相关知识:\n\n", len(results))

		retrievedItemIDs := make([]string, 0, len(results))
		seenItems := make(map[string]struct{}, len(results))
		for pos, result := range results {
			fmt.Fprintf(&resultText, "--- 结果 %d (相似度: %.2f%%) ---\n", pos+1, result.Similarity*100)
			fmt.Fprintf(&resultText, "来源: [%s] %s\n", result.Item.Category, result.Item.Title)
			fmt.Fprintf(&resultText, "内容:\n%s\n\n", result.Chunk.ChunkText)

			if _, dup := seenItems[result.Item.ID]; dup {
				continue
			}
			seenItems[result.Item.ID] = struct{}{}
			retrievedItemIDs = append(retrievedItemIDs, result.Item.ID)
		}

		// Append machine-readable metadata inside an HTML comment so it does
		// not get in the way of reading the results.
		if len(retrievedItemIDs) > 0 {
			metadataJSON, _ := json.Marshal(map[string]interface{}{
				"_metadata": map[string]interface{}{
					"retrievedItemIDs": retrievedItemIDs,
				},
			})
			fmt.Fprintf(&resultText, "\n<!-- METADATA: %s -->", string(metadataJSON))
		}

		// Retrieval logging is not done here: conversation and message IDs
		// are not available at this level.
		return plainResult(resultText.String()), nil
	}

	mcpServer.RegisterTool(tool, handler)
	logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name))
}
// contains reports whether item occurs in slice.
func contains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] != item {
			continue
		}
		return true
	}
	return false
}
// GetRetrievalMetadata extracts the "query" and "risk_type" string arguments
// from a tool call. A key that is missing or does not hold a string yields
// the empty string for that value.
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
	// A failed type assertion leaves the zero value "".
	query, _ = args["query"].(string)
	riskType, _ = args["risk_type"].(string)
	return query, riskType
}
// FormatRetrievalResults renders retrieval results as a log-friendly string:
// one numbered line per result with category, title and similarity, followed
// by the JSON list of distinct knowledge item IDs.
//
// The IDs are listed in order of first appearance, so the output is
// deterministic for a given input. An empty result list yields a fixed
// "nothing found" message.
func FormatRetrievalResults(results []*RetrievalResult) string {
	if len(results) == 0 {
		return "未找到相关结果"
	}

	var builder strings.Builder
	fmt.Fprintf(&builder, "检索到 %d 条结果:\n", len(results))

	// Deduplicate with a set but keep a slice for stable ordering; ranging
	// over the map alone would give a different order on every call.
	seen := make(map[string]struct{}, len(results))
	ids := make([]string, 0, len(results))
	for i, result := range results {
		fmt.Fprintf(&builder, "%d. [%s] %s (相似度: %.2f%%)\n",
			i+1, result.Item.Category, result.Item.Title, result.Similarity*100)
		if _, dup := seen[result.Item.ID]; dup {
			continue
		}
		seen[result.Item.ID] = struct{}{}
		ids = append(ids, result.Item.ID)
	}

	// Marshalling a []string cannot fail, so the error is ignored.
	idsJSON, _ := json.Marshal(ids)
	fmt.Fprintf(&builder, "\n检索到的知识项ID: %s", string(idsJSON))

	return builder.String()
}
+67
View File
@@ -0,0 +1,67 @@
package knowledge
import (
"encoding/json"
"time"
)
// KnowledgeItem is one entry of the knowledge base, backed by a Markdown file
// and a row in the knowledge_base_items table.
type KnowledgeItem struct {
ID string `json:"id"`
Category string `json:"category"` // risk type (folder name)
Title string `json:"title"` // title (file name)
FilePath string `json:"filePath"` // path of the backing file
Content string `json:"content"` // file content
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// KnowledgeChunk is one chunk of a knowledge item's text, the unit that gets
// embedded.
type KnowledgeChunk struct {
ID string `json:"id"`
ItemID string `json:"itemId"`
ChunkIndex int `json:"chunkIndex"`
ChunkText string `json:"chunkText"`
Embedding []float32 `json:"-"` // embedding vector; excluded from JSON output
CreatedAt time.Time `json:"createdAt"`
}
// RetrievalResult is one search hit: the matching chunk, the item it belongs
// to, and its scores.
type RetrievalResult struct {
Chunk *KnowledgeChunk `json:"chunk"`
Item *KnowledgeItem `json:"item"`
Similarity float64 `json:"similarity"` // similarity score
Score float64 `json:"score"` // combined score (hybrid retrieval)
}
// RetrievalLog records one knowledge-base retrieval: the query, the optional
// conversation/message it belongs to, the optional risk type filter, and the
// IDs of the knowledge items that were retrieved.
type RetrievalLog struct {
	ID             string    `json:"id"`
	ConversationID string    `json:"conversationId,omitempty"`
	MessageID      string    `json:"messageId,omitempty"`
	Query          string    `json:"query"`
	RiskType       string    `json:"riskType,omitempty"`
	RetrievedItems []string  `json:"retrievedItems"` // IDs of the retrieved knowledge items
	CreatedAt      time.Time `json:"createdAt"`
}

// MarshalJSON encodes the log with CreatedAt formatted as RFC 3339 (second
// precision); all other fields use their struct tags.
//
// The method has a value receiver so the custom format applies whether a
// RetrievalLog is marshalled by value or through a pointer. With a pointer
// receiver, a value marshalled directly would skip this method and use the
// default time encoding.
func (r RetrievalLog) MarshalJSON() ([]byte, error) {
	// Alias has the same fields but no methods, which prevents infinite
	// recursion into MarshalJSON.
	type Alias RetrievalLog
	return json.Marshal(&struct {
		*Alias
		// Shadows Alias.CreatedAt, so the formatted string is what gets emitted.
		CreatedAt string `json:"createdAt"`
	}{
		Alias:     (*Alias)(&r),
		CreatedAt: r.CreatedAt.Format(time.RFC3339),
	})
}
// SearchRequest describes one knowledge-base search.
type SearchRequest struct {
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"` // optional: restrict the search to this risk type
TopK int `json:"topK,omitempty"` // number of top results to return, default 5
Threshold float64 `json:"threshold,omitempty"` // similarity threshold, default 0.7
}