From 389fc971c6255a7a38e23a5ff6cd8f523237b78b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sat, 18 Apr 2026 23:35:49 +0800 Subject: [PATCH] Add files via upload --- internal/app/app.go | 35 +- internal/config/config.go | 50 +- internal/database/database.go | 40 + internal/handler/config.go | 33 +- internal/handler/knowledge.go | 1 + internal/handler/openapi.go | 2 +- internal/knowledge/chunk_eino.go | 67 ++ internal/knowledge/eino_meta.go | 129 ++++ internal/knowledge/eino_meta_test.go | 14 + internal/knowledge/eino_retrieve_chain.go | 25 + .../knowledge/eino_retrieve_chain_test.go | 23 + internal/knowledge/eino_retriever_adapter.go | 202 +++++ internal/knowledge/eino_sqlite_indexer.go | 142 ++++ internal/knowledge/embedder.go | 440 +++++------ internal/knowledge/index_pipeline.go | 91 +++ internal/knowledge/index_pipeline_test.go | 21 + internal/knowledge/indexer.go | 716 ++++------------- internal/knowledge/retrieval_postprocess.go | 213 ++++++ .../knowledge/retrieval_postprocess_test.go | 62 ++ internal/knowledge/retriever.go | 719 +++++------------- internal/knowledge/schema_migrate.go | 51 ++ internal/knowledge/tool.go | 21 +- internal/knowledge/types.go | 11 +- 23 files changed, 1695 insertions(+), 1413 deletions(-) create mode 100644 internal/knowledge/chunk_eino.go create mode 100644 internal/knowledge/eino_meta.go create mode 100644 internal/knowledge/eino_meta_test.go create mode 100644 internal/knowledge/eino_retrieve_chain.go create mode 100644 internal/knowledge/eino_retrieve_chain_test.go create mode 100644 internal/knowledge/eino_retriever_adapter.go create mode 100644 internal/knowledge/eino_sqlite_indexer.go create mode 100644 internal/knowledge/index_pipeline.go create mode 100644 internal/knowledge/index_pipeline_test.go create mode 100644 internal/knowledge/retrieval_postprocess.go create mode 100644 internal/knowledge/retrieval_postprocess_test.go create mode 100644 internal/knowledge/schema_migrate.go diff --git a/internal/app/app.go b/internal/app/app.go index 28e79f75..69161824 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -19,7 +19,6 @@ import ( "cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp/builtin" - "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/robot" "cyberstrike-ai/internal/security" "cyberstrike-ai/internal/skills" @@ -185,22 +184,25 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL } - httpClient := &http.Client{ - Timeout: 30 * time.Minute, + embedder, err := knowledge.NewEmbedder(context.Background(), &cfg.Knowledge, &cfg.OpenAI, log.Logger) + if err != nil { + return nil, fmt.Errorf("初始化知识库嵌入器失败: %w", err) } - openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, log.Logger) - embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, log.Logger) // 创建检索器 retrievalConfig := &knowledge.RetrievalConfig{ TopK: cfg.Knowledge.Retrieval.TopK, SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold, - HybridWeight: cfg.Knowledge.Retrieval.HybridWeight, + SubIndexFilter: cfg.Knowledge.Retrieval.SubIndexFilter, + PostRetrieve: cfg.Knowledge.Retrieval.PostRetrieve, } knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger) - // 创建索引器 - knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger, &cfg.Knowledge.Indexing) + // 创建索引器(Eino Compose 链) + knowledgeIndexer, err = knowledge.NewIndexer(context.Background(), knowledgeDB, embedder, log.Logger, &cfg.Knowledge) + if err != nil { + return nil, fmt.Errorf("初始化知识库索引器失败: %w", err) + } // 注册知识检索工具到MCP服务器 knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger) @@ -1697,22 +1699,25 @@ func initializeKnowledge( cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL } - httpClient := &http.Client{ - Timeout: 30 * time.Minute, + embedder, err := knowledge.NewEmbedder(context.Background(), &cfg.Knowledge, &cfg.OpenAI, logger) + if err != nil { + return nil, fmt.Errorf("初始化知识库嵌入器失败: %w", err) } - openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, logger) - embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, logger) // 创建检索器 retrievalConfig := &knowledge.RetrievalConfig{ TopK: cfg.Knowledge.Retrieval.TopK, SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold, - HybridWeight: cfg.Knowledge.Retrieval.HybridWeight, + SubIndexFilter: cfg.Knowledge.Retrieval.SubIndexFilter, + PostRetrieve: cfg.Knowledge.Retrieval.PostRetrieve, } knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger) - // 创建索引器 - knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger, &cfg.Knowledge.Indexing) + // 创建索引器(Eino Compose 链) + knowledgeIndexer, err := knowledge.NewIndexer(context.Background(), knowledgeDB, embedder, logger, &cfg.Knowledge) + if err != nil { + return nil, fmt.Errorf("初始化知识库索引器失败: %w", err) + } // 注册知识检索工具到MCP服务器 knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger) diff --git a/internal/config/config.go b/internal/config/config.go index fac227ae..17831e71 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -754,16 +754,20 @@ func Default() *Config { Retrieval: RetrievalConfig{ TopK: 5, SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检 - HybridWeight: 0.7, }, Indexing: IndexingConfig{ - ChunkSize: 768, // 增加到 768,更好的上下文保持 - ChunkOverlap: 50, - MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额 - MaxRPM: 100, // 默认 100 RPM,避免 429 错误 - RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM - MaxRetries: 3, - RetryDelayMs: 1000, + ChunkStrategy: "markdown_then_recursive", + RequestTimeoutSeconds: 120, + ChunkSize: 768, // 增加到 768,更好的上下文保持 + ChunkOverlap: 50, + MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额 + BatchSize: 64, + PreferSourceFile: false, + MaxRPM: 100, // 默认 100 RPM,避免 429 错误 + RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM + MaxRetries: 3, + RetryDelayMs: 1000, + SubIndexes: nil, }, }, } @@ -780,11 +784,18 @@ type KnowledgeConfig struct { // IndexingConfig 索引构建配置(用于控制知识库索引构建时的行为) type IndexingConfig struct { + // ChunkStrategy: "markdown_then_recursive"(默认,Eino Markdown 标题切分后再递归切)或 "recursive"(仅递归切分) + ChunkStrategy string `yaml:"chunk_strategy,omitempty" json:"chunk_strategy,omitempty"` + // RequestTimeoutSeconds 嵌入 HTTP 客户端超时(秒),0 表示使用默认 120 + RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"` // 分块配置 ChunkSize int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"` // 每个块的最大 token 数(估算),默认 512 ChunkOverlap int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"` // 块之间的重叠 token 数,默认 50 MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // 单个知识项的最大块数量,0 表示不限制 + // PreferSourceFile 为 true 时优先用 Eino FileLoader 从 file_path 读原文再索引(与库内 content 不一致时以磁盘为准) + PreferSourceFile bool `yaml:"prefer_source_file,omitempty" json:"prefer_source_file,omitempty"` + // 速率限制配置(用于避免 API 速率限制) RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // 请求间隔时间(毫秒),0 表示不使用固定延迟 MaxRPM int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"` // 每分钟最大请求数,0 表示不限制 @@ -793,8 +804,10 @@ type IndexingConfig struct { MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // 最大重试次数,默认 3 RetryDelayMs int `yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // 重试间隔(毫秒),默认 1000 - // 批处理配置(用于批量嵌入,当前未使用,保留扩展) - BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"` // 批量处理大小,0 表示逐个处理 + // BatchSize 嵌入批大小(SQLite 索引写入),0 表示默认 64 + BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"` + // SubIndexes 传入 Eino indexer.WithSubIndexes(逻辑分区标记,随 Document 元数据传递) + SubIndexes []string `yaml:"sub_indexes,omitempty" json:"sub_indexes,omitempty"` } // EmbeddingConfig 嵌入配置 @@ -805,11 +818,24 @@ type EmbeddingConfig struct { APIKey string `yaml:"api_key" json:"api_key"` // API Key(从OpenAI配置继承) } +// PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。 +type PostRetrieveConfig struct { + // PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。 + PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"` + // MaxContextChars 返回文档内容总 Unicode 字符数上限(整段 chunk,不截断半段);0 表示不限制。 + MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"` + // MaxContextTokens 返回文档内容总 token 上限(tiktoken,按嵌入模型名映射,失败则 cl100k_base);0 表示不限制。 + MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"` +} + // RetrievalConfig 检索配置 type RetrievalConfig struct { TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K - SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值 - HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1) + SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值 + // SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。 + SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"` + // PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 [knowledge.DocumentReranker]。 + PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"` } // RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代) diff --git a/internal/database/database.go b/internal/database/database.go index 39593ec4..0e0ec524 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -718,6 +718,9 @@ func (db *DB) initKnowledgeTables() error { chunk_index INTEGER NOT NULL, chunk_text TEXT NOT NULL, embedding TEXT NOT NULL, + sub_indexes TEXT NOT NULL DEFAULT '', + embedding_model TEXT NOT NULL DEFAULT '', + embedding_dim INTEGER NOT NULL DEFAULT 0, created_at DATETIME NOT NULL, FOREIGN KEY (item_id) REFERENCES knowledge_base_items(id) ON DELETE CASCADE );` @@ -759,10 +762,47 @@ func (db *DB) initKnowledgeTables() error { return fmt.Errorf("创建索引失败: %w", err) } + if err := db.migrateKnowledgeEmbeddingsColumns(); err != nil { + return fmt.Errorf("迁移 knowledge_embeddings 列失败: %w", err) + } + db.logger.Info("知识库数据库表初始化完成") return nil } +// migrateKnowledgeEmbeddingsColumns 为已有库补充 sub_indexes、embedding_model、embedding_dim。 +func (db *DB) migrateKnowledgeEmbeddingsColumns() error { + var n int + if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil { + return err + } + if n == 0 { + return nil + } + migrations := []struct { + col string + stmt string + }{ + {"sub_indexes", `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`}, + {"embedding_model", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`}, + {"embedding_dim", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`}, + } + for _, m := range migrations { + var colCount int + q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?` + if err := db.QueryRow(q, m.col).Scan(&colCount); err != nil { + return err + } + if colCount > 0 { + continue + } + if _, err := db.Exec(m.stmt); err != nil { + return err + } + } + return nil +} + // Close 关闭数据库连接 func (db *DB) Close() error { return db.DB.Close() diff --git a/internal/handler/config.go b/internal/handler/config.go index 766099ea..54bb19f0 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -642,7 +642,6 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { zap.String("embedding_model", h.config.Knowledge.Embedding.Model), zap.Int("retrieval_top_k", h.config.Knowledge.Retrieval.TopK), zap.Float64("similarity_threshold", h.config.Knowledge.Retrieval.SimilarityThreshold), - zap.Float64("hybrid_weight", h.config.Knowledge.Retrieval.HybridWeight), ) } @@ -1051,13 +1050,13 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { retrievalConfig := &knowledge.RetrievalConfig{ TopK: h.config.Knowledge.Retrieval.TopK, SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, - HybridWeight: h.config.Knowledge.Retrieval.HybridWeight, + SubIndexFilter: h.config.Knowledge.Retrieval.SubIndexFilter, + PostRetrieve: h.config.Knowledge.Retrieval.PostRetrieve, } h.retrieverUpdater.UpdateConfig(retrievalConfig) h.logger.Info("检索器配置已更新", zap.Int("top_k", retrievalConfig.TopK), zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), - zap.Float64("hybrid_weight", retrievalConfig.HybridWeight), ) } @@ -1289,13 +1288,22 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { retrievalNode := ensureMap(knowledgeNode, "retrieval") setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK) setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold) - setFloatInMap(retrievalNode, "hybrid_weight", cfg.Retrieval.HybridWeight) + setStringInMap(retrievalNode, "sub_index_filter", cfg.Retrieval.SubIndexFilter) + postNode := ensureMap(retrievalNode, "post_retrieve") + setIntInMap(postNode, "prefetch_top_k", cfg.Retrieval.PostRetrieve.PrefetchTopK) + setIntInMap(postNode, "max_context_chars", cfg.Retrieval.PostRetrieve.MaxContextChars) + setIntInMap(postNode, "max_context_tokens", cfg.Retrieval.PostRetrieve.MaxContextTokens) // 更新索引配置 indexingNode := ensureMap(knowledgeNode, "indexing") + setStringInMap(indexingNode, "chunk_strategy", cfg.Indexing.ChunkStrategy) + setIntInMap(indexingNode, "request_timeout_seconds", cfg.Indexing.RequestTimeoutSeconds) setIntInMap(indexingNode, "chunk_size", cfg.Indexing.ChunkSize) setIntInMap(indexingNode, "chunk_overlap", cfg.Indexing.ChunkOverlap) setIntInMap(indexingNode, "max_chunks_per_item", cfg.Indexing.MaxChunksPerItem) + setBoolInMap(indexingNode, "prefer_source_file", cfg.Indexing.PreferSourceFile) + setIntInMap(indexingNode, "batch_size", cfg.Indexing.BatchSize) + setStringSliceInMap(indexingNode, "sub_indexes", cfg.Indexing.SubIndexes) setIntInMap(indexingNode, "max_rpm", cfg.Indexing.MaxRPM) setIntInMap(indexingNode, "rate_limit_delay_ms", cfg.Indexing.RateLimitDelayMs) setIntInMap(indexingNode, "max_retries", cfg.Indexing.MaxRetries) @@ -1397,6 +1405,21 @@ func setStringInMap(mapNode *yaml.Node, key, value string) { valueNode.Value = value } +func setStringSliceInMap(mapNode *yaml.Node, key string, values []string) { + _, valueNode := ensureKeyValue(mapNode, key) + valueNode.Kind = yaml.SequenceNode + valueNode.Tag = "!!seq" + valueNode.Style = 0 + valueNode.Content = nil + for _, v := range values { + valueNode.Content = append(valueNode.Content, &yaml.Node{ + Kind: yaml.ScalarNode, + Tag: "!!str", + Value: v, + }) + } +} + func setIntInMap(mapNode *yaml.Node, key string, value int) { _, valueNode := ensureKeyValue(mapNode, key) valueNode.Kind = yaml.ScalarNode @@ -1450,7 +1473,7 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) { valueNode.Kind = yaml.ScalarNode valueNode.Tag = "!!float" valueNode.Style = 0 - // 对于0.0到1.0之间的值(如hybrid_weight),使用%.1f确保0.0被明确序列化为"0.0" + // 对于0.0到1.0之间的值(如 similarity_threshold),使用%.1f确保0.0被明确序列化为"0.0" // 对于其他值,使用%g自动选择最合适的格式 if value >= 0.0 && value <= 1.0 { valueNode.Value = fmt.Sprintf("%.1f", value) diff --git a/internal/handler/knowledge.go b/internal/handler/knowledge.go index c92f46e7..76d7b974 100644 --- a/internal/handler/knowledge.go +++ b/internal/handler/knowledge.go @@ -482,6 +482,7 @@ func (h *KnowledgeHandler) Search(c *gin.Context) { return } + // Retriever.Search 经 Eino VectorEinoRetriever,与 MCP 工具链一致。 results, err := h.retriever.Search(c.Request.Context(), &req) if err != nil { h.logger.Error("搜索知识库失败", zap.Error(err)) diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index 6245b34f..5b1b80c0 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -4181,7 +4181,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "post": map[string]interface{}{ "tags": []string{"知识库"}, "summary": "搜索知识库", - "description": "在知识库中搜索相关内容。使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。\n**搜索说明**:\n- 支持语义相似度搜索(向量检索)\n- 支持关键词匹配(BM25)\n- 支持混合搜索(结合向量和关键词)\n- 可以按风险类型过滤(如:SQL注入、XSS、文件上传等)\n- 建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表\n**使用示例**:\n```json\n{\n \"query\": \"SQL注入漏洞的检测方法\",\n \"riskType\": \"SQL注入\",\n \"topK\": 5,\n \"threshold\": 0.7\n}\n```", + "description": "在知识库中搜索相关内容。基于向量检索,按查询与知识片段的语义相似度(余弦)返回最相关结果。\n**搜索说明**:\n- 语义相似度搜索:嵌入向量 + 余弦相似度,可配置相似度阈值与 TopK\n- 可按风险类型等元数据过滤(如:SQL注入、XSS、文件上传等)\n- 建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表\n**使用示例**:\n```json\n{\n \"query\": \"SQL注入漏洞的检测方法\",\n \"riskType\": \"SQL注入\",\n \"topK\": 5,\n \"threshold\": 0.7\n}\n```", "operationId": "searchKnowledge", "requestBody": map[string]interface{}{ "required": true, diff --git a/internal/knowledge/chunk_eino.go b/internal/knowledge/chunk_eino.go new file mode 100644 index 00000000..6592f350 --- /dev/null +++ b/internal/knowledge/chunk_eino.go @@ -0,0 +1,67 @@ +package knowledge + +import ( + "context" + "fmt" + "strings" + + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown" + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive" + "github.com/cloudwego/eino/components/document" + "github.com/pkoukk/tiktoken-go" +) + +func tokenizerLenFunc(embeddingModel string) func(string) int { + fallback := func(s string) int { + r := []rune(s) + if len(r) == 0 { + return 0 + } + return (len(r) + 3) / 4 + } + m := strings.TrimSpace(embeddingModel) + if m == "" { + return fallback + } + tok, err := tiktoken.EncodingForModel(m) + if err != nil { + return fallback + } + return func(s string) int { + return len(tok.Encode(s, nil, nil)) + } +} + +// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for +// embeddingModel when available, else rune/4 approximation. +func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) { + if chunkSize <= 0 { + return nil, fmt.Errorf("chunk size must be positive") + } + if overlap < 0 { + overlap = 0 + } + return recursive.NewSplitter(context.Background(), &recursive.Config{ + ChunkSize: chunkSize, + OverlapSize: overlap, + LenFunc: tokenizerLenFunc(embeddingModel), + Separators: []string{ + "\n\n", "\n## ", "\n### ", "\n#### ", "\n", + "。", "!", "?", ". ", "? ", "! ", + " ", + }, + }) +} + +// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#~####),适合技术/Markdown 知识库。 +func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) { + return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{ + Headers: map[string]string{ + "#": "h1", + "##": "h2", + "###": "h3", + "####": "h4", + }, + TrimHeaders: false, + }) +} diff --git a/internal/knowledge/eino_meta.go b/internal/knowledge/eino_meta.go new file mode 100644 index 00000000..2ae419c4 --- /dev/null +++ b/internal/knowledge/eino_meta.go @@ -0,0 +1,129 @@ +package knowledge + +import ( + "fmt" + "strings" +) + +// Document metadata keys for Eino schema.Document flowing through the RAG pipeline. +const ( + metaKBCategory = "kb_category" + metaKBTitle = "kb_title" + metaKBItemID = "kb_item_id" + metaKBChunkIndex = "kb_chunk_index" + metaSimilarity = "similarity" +) + +// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo]. +const ( + DSLRiskType = "risk_type" + DSLSimilarityThreshold = "similarity_threshold" + DSLSubIndexFilter = "sub_index_filter" +) + +// FormatEmbeddingInput matches the historical indexing format so existing embeddings +// stay comparable if users skip reindex; new indexes use the same string shape. +func FormatEmbeddingInput(category, title, chunkText string) string { + return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText) +} + +// FormatQueryEmbeddingText builds the string embedded at query time so it matches +// [FormatEmbeddingInput] for the same risk category (title left empty for queries). +func FormatQueryEmbeddingText(riskType, query string) string { + q := strings.TrimSpace(query) + rt := strings.TrimSpace(riskType) + if rt != "" { + return FormatEmbeddingInput(rt, "", q) + } + return q +} + +// MetaLookupString returns metadata string value or "" if absent. +func MetaLookupString(md map[string]any, key string) string { + if md == nil { + return "" + } + v, ok := md[key] + if !ok || v == nil { + return "" + } + switch t := v.(type) { + case string: + return t + default: + return strings.TrimSpace(fmt.Sprint(t)) + } +} + +// MetaStringOK returns trimmed non-empty string and true if present and non-empty. +func MetaStringOK(md map[string]any, key string) (string, bool) { + s := strings.TrimSpace(MetaLookupString(md, key)) + if s == "" { + return "", false + } + return s, true +} + +// RequireMetaString requires a non-empty string metadata field. +func RequireMetaString(md map[string]any, key string) (string, error) { + s, ok := MetaStringOK(md, key) + if !ok { + return "", fmt.Errorf("missing or empty metadata %q", key) + } + return s, nil +} + +// RequireMetaInt requires an integer metadata field. +func RequireMetaInt(md map[string]any, key string) (int, error) { + if md == nil { + return 0, fmt.Errorf("missing metadata key %q", key) + } + v, ok := md[key] + if !ok { + return 0, fmt.Errorf("missing metadata key %q", key) + } + switch t := v.(type) { + case int: + return t, nil + case int32: + return int(t), nil + case int64: + return int(t), nil + case float64: + return int(t), nil + default: + return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v) + } +} + +// DSLNumeric coerces DSL map values (e.g. from JSON) to float64. +func DSLNumeric(v any) (float64, bool) { + switch t := v.(type) { + case float64: + return t, true + case float32: + return float64(t), true + case int: + return float64(t), true + case int64: + return float64(t), true + case uint32: + return float64(t), true + case uint64: + return float64(t), true + default: + return 0, false + } +} + +// MetaFloat64OK reads a float metadata value. +func MetaFloat64OK(md map[string]any, key string) (float64, bool) { + if md == nil { + return 0, false + } + v, ok := md[key] + if !ok { + return 0, false + } + return DSLNumeric(v) +} diff --git a/internal/knowledge/eino_meta_test.go b/internal/knowledge/eino_meta_test.go new file mode 100644 index 00000000..ba3f60da --- /dev/null +++ b/internal/knowledge/eino_meta_test.go @@ -0,0 +1,14 @@ +package knowledge + +import "testing" + +func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) { + q := FormatQueryEmbeddingText("XSS", "payload") + want := FormatEmbeddingInput("XSS", "", "payload") + if q != want { + t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want) + } + if FormatQueryEmbeddingText("", "hello") != "hello" { + t.Fatalf("expected bare query without risk type") + } +} diff --git a/internal/knowledge/eino_retrieve_chain.go b/internal/knowledge/eino_retrieve_chain.go new file mode 100644 index 00000000..2d1b72eb --- /dev/null +++ b/internal/knowledge/eino_retrieve_chain.go @@ -0,0 +1,25 @@ +package knowledge + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。 +// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。 +func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) { + if r == nil { + return nil, fmt.Errorf("retriever is nil") + } + ch := compose.NewChain[string, []*schema.Document]() + ch.AppendRetriever(r.AsEinoRetriever()) + return ch.Compile(ctx) +} + +// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。 +func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) { + return BuildKnowledgeRetrieveChain(ctx, r) +} diff --git a/internal/knowledge/eino_retrieve_chain_test.go b/internal/knowledge/eino_retrieve_chain_test.go new file mode 100644 index 00000000..c74a6900 --- /dev/null +++ b/internal/knowledge/eino_retrieve_chain_test.go @@ -0,0 +1,23 @@ +package knowledge + +import ( + "context" + "testing" + + "go.uber.org/zap" +) + +func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) { + r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop()) + _, err := BuildKnowledgeRetrieveChain(context.Background(), r) + if err != nil { + t.Fatal(err) + } +} + +func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) { + _, err := BuildKnowledgeRetrieveChain(context.Background(), nil) + if err == nil { + t.Fatal("expected error for nil retriever") + } +} diff --git a/internal/knowledge/eino_retriever_adapter.go b/internal/knowledge/eino_retriever_adapter.go new file mode 100644 index 00000000..f5635121 --- /dev/null +++ b/internal/knowledge/eino_retriever_adapter.go @@ -0,0 +1,202 @@ +package knowledge + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity. +// +// Options: +// - [retriever.WithTopK] +// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string) +// +// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric. +// +// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then +// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig]. +type VectorEinoRetriever struct { + inner *Retriever +} + +// NewVectorEinoRetriever wraps r for Eino compose / tooling. +func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever { + if r == nil { + return nil + } + return &VectorEinoRetriever{inner: r} +} + +// GetType identifies this retriever for Eino callbacks. +func (h *VectorEinoRetriever) GetType() string { + return "SQLiteVectorKnowledgeRetriever" +} + +// Retrieve runs vector search and returns [schema.Document] rows. +func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) { + if h == nil || h.inner == nil { + return nil, fmt.Errorf("VectorEinoRetriever: nil retriever") + } + q := strings.TrimSpace(query) + if q == "" { + return nil, fmt.Errorf("查询不能为空") + } + + ro := retriever.GetCommonOptions(nil, opts...) + cfg := h.inner.config + + req := &SearchRequest{Query: q} + + if ro.TopK != nil && *ro.TopK > 0 { + req.TopK = *ro.TopK + } else if cfg != nil && cfg.TopK > 0 { + req.TopK = cfg.TopK + } else { + req.TopK = 5 + } + + req.Threshold = 0 + if ro.DSLInfo != nil { + if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok { + req.RiskType = strings.TrimSpace(rt) + } + if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok { + if f, ok2 := DSLNumeric(v); ok2 && f > 0 { + req.Threshold = f + } + } + if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok { + req.SubIndexFilter = strings.TrimSpace(sf) + } + } + if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" { + req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter) + } + if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 { + req.Threshold = cfg.SimilarityThreshold + } + if req.Threshold <= 0 { + req.Threshold = 0.7 + } + + finalTopK := req.TopK + var postPO *config.PostRetrieveConfig + if cfg != nil { + postPO = &cfg.PostRetrieve + } + fetchK := EffectivePrefetchTopK(finalTopK, postPO) + searchReq := *req + searchReq.TopK = fetchK + + ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever) + th := req.Threshold + st := &th + ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{ + Query: q, + TopK: finalTopK, + ScoreThreshold: st, + Extra: ro.DSLInfo, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out}) + }() + + results, err := h.inner.vectorSearch(ctx, &searchReq) + if err != nil { + return nil, err + } + out = retrievalResultsToDocuments(results) + + if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 { + reranked, rerr := rr.Rerank(ctx, q, out) + if rerr != nil { + if h.inner.logger != nil { + h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr)) + } + } else if len(reranked) > 0 { + out = reranked + } + } + + tokenModel := "" + if h.inner.embedder != nil { + tokenModel = h.inner.embedder.EmbeddingModelName() + } + out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK) + if err != nil { + return nil, err + } + return out, nil +} + +func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document { + out := make([]*schema.Document, 0, len(results)) + for _, res := range results { + if res == nil || res.Chunk == nil || res.Item == nil { + continue + } + d := &schema.Document{ + ID: res.Chunk.ID, + Content: res.Chunk.ChunkText, + MetaData: map[string]any{ + metaKBItemID: res.Item.ID, + metaKBCategory: res.Item.Category, + metaKBTitle: res.Item.Title, + metaKBChunkIndex: res.Chunk.ChunkIndex, + metaSimilarity: res.Similarity, + }, + } + d.WithScore(res.Score) + out = append(out, d) + } + return out +} + +func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) { + out := make([]*RetrievalResult, 0, len(docs)) + for i, d := range docs { + if d == nil { + continue + } + itemID, err := RequireMetaString(d.MetaData, metaKBItemID) + if err != nil { + return nil, fmt.Errorf("document %d: %w", i, err) + } + cat := MetaLookupString(d.MetaData, metaKBCategory) + title := MetaLookupString(d.MetaData, metaKBTitle) + chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex) + if err != nil { + return nil, fmt.Errorf("document %d: %w", i, err) + } + sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity) + item := &KnowledgeItem{ID: itemID, Category: cat, Title: title} + chunk := &KnowledgeChunk{ + ID: d.ID, + ItemID: itemID, + ChunkIndex: chunkIdx, + ChunkText: d.Content, + } + out = append(out, &RetrievalResult{ + Chunk: chunk, + Item: item, + Similarity: sim, + Score: d.Score(), + }) + } + return out, nil +} + +var _ retriever.Retriever = (*VectorEinoRetriever)(nil) diff --git a/internal/knowledge/eino_sqlite_indexer.go b/internal/knowledge/eino_sqlite_indexer.go new file mode 100644 index 00000000..a0bbdcdc --- /dev/null +++ b/internal/knowledge/eino_sqlite_indexer.go @@ -0,0 +1,142 @@ +package knowledge + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" + "github.com/google/uuid" +) + +// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema. +type SQLiteIndexer struct { + db *sql.DB + batchSize int + embeddingModel string +} + +// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call. +// batchSize is the embedding batch size; if <= 0, default 64 is used. +// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty). +func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer { + return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)} +} + +// GetType implements eino callback run info. +func (s *SQLiteIndexer) GetType() string { + return "SQLiteKnowledgeIndexer" +} + +// Store embeds documents and inserts rows. Each doc must carry MetaData: +// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only. +func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { + options := indexer.GetCommonOptions(nil, opts...) + if options.Embedding == nil { + return nil, fmt.Errorf("sqlite indexer: embedding is required") + } + if len(docs) == 0 { + return nil, nil + } + + ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer) + ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs}) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids}) + }() + + subIdxStr := strings.Join(options.SubIndexes, ",") + + texts := make([]string, len(docs)) + for i, d := range docs { + if d == nil { + return nil, fmt.Errorf("sqlite indexer: nil document at %d", i) + } + cat := MetaLookupString(d.MetaData, metaKBCategory) + title := MetaLookupString(d.MetaData, metaKBTitle) + texts[i] = FormatEmbeddingInput(cat, title, d.Content) + } + + bs := s.batchSize + if bs <= 0 { + bs = 64 + } + + var allVecs [][]float64 + for start := 0; start < len(texts); start += bs { + end := start + bs + if end > len(texts) { + end = len(texts) + } + batch := texts[start:end] + vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch) + if embedErr != nil { + return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr) + } + if len(vecs) != len(batch) { + return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch)) + } + allVecs = append(allVecs, vecs...) + } + + embedDim := 0 + if len(allVecs) > 0 { + embedDim = len(allVecs[0]) + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err) + } + defer tx.Rollback() + + ids = make([]string, 0, len(docs)) + for i, d := range docs { + chunkID := uuid.New().String() + itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID) + if metaErr != nil { + return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) + } + chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex) + if metaErr != nil { + return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) + } + vec := allVecs[i] + if embedDim > 0 && len(vec) != embedDim { + return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim) + } + vec32 := make([]float32, len(vec)) + for j, v := range vec { + vec32[j] = float32(v) + } + embeddingJSON, jsonErr := json.Marshal(vec32) + if jsonErr != nil { + return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr) + } + _, err = tx.ExecContext(ctx, + `INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`, + chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim, + ) + if err != nil { + return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err) + } + ids = append(ids, chunkID) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("sqlite indexer: commit: %w", err) + } + return ids, nil +} + +var _ indexer.Indexer = (*SQLiteIndexer)(nil) diff --git a/internal/knowledge/embedder.go b/internal/knowledge/embedder.go index ff62ea30..d9ce8afa 100644 --- a/internal/knowledge/embedder.go +++ b/internal/knowledge/embedder.go @@ -2,7 +2,6 @@ package knowledge import ( "context" - "encoding/json" "fmt" "net/http" "strings" @@ -10,43 +9,47 @@ import ( "time" "cyberstrike-ai/internal/config" - "cyberstrike-ai/internal/openai" + einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai" + "github.com/cloudwego/eino/components/embedding" "go.uber.org/zap" "golang.org/x/time/rate" ) -// Embedder 文本嵌入器 +// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。 type Embedder struct { - openAIClient *openai.Client - config *config.KnowledgeConfig - openAIConfig *config.OpenAIConfig // 用于获取 API Key - logger *zap.Logger - rateLimiter *rate.Limiter // 速率限制器 - rateLimitDelay time.Duration // 请求间隔时间 - maxRetries int // 最大重试次数 - retryDelay time.Duration // 重试间隔 - mu sync.Mutex // 保护 rateLimiter + eino embedding.Embedder + config *config.KnowledgeConfig + logger *zap.Logger + + rateLimiter *rate.Limiter + rateLimitDelay time.Duration + maxRetries int + retryDelay time.Duration + mu sync.Mutex } -// NewEmbedder 创建新的嵌入器 -func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder { - // 初始化速率限制器 +// NewEmbedder 基于 Eino eino-ext OpenAI Embedder;openAIConfig 用于在知识库未单独配置 key 时回退 API Key。 +func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) { + if cfg == nil { + return nil, fmt.Errorf("knowledge config is nil") + } + var rateLimiter *rate.Limiter var rateLimitDelay time.Duration - - // 如果配置了 MaxRPM,根据 RPM 计算速率限制 if cfg.Indexing.MaxRPM > 0 { rpm := cfg.Indexing.MaxRPM rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm) - logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm)) + if logger != nil { + logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm)) + } } else if cfg.Indexing.RateLimitDelayMs > 0 { - // 如果没有配置 MaxRPM 但配置了固定延迟,使用固定延迟模式 rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond - logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay)) + if logger != nil { + logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay)) + } } - // 重试配置 maxRetries := 3 retryDelay := 1000 * time.Millisecond if cfg.Indexing.MaxRetries > 0 { @@ -56,268 +59,193 @@ func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond } - return &Embedder{ - openAIClient: openAIClient, - config: cfg, - openAIConfig: openAIConfig, - logger: logger, - rateLimiter: rateLimiter, - rateLimitDelay: rateLimitDelay, - maxRetries: maxRetries, - retryDelay: retryDelay, - } -} - -// EmbeddingRequest OpenAI 嵌入请求 -type EmbeddingRequest struct { - Model string `json:"model"` - Input []string `json:"input"` -} - -// EmbeddingResponse OpenAI 嵌入响应 -type EmbeddingResponse struct { - Data []EmbeddingData `json:"data"` - Error *EmbeddingError `json:"error,omitempty"` -} - -// EmbeddingData 嵌入数据 -type EmbeddingData struct { - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -// EmbeddingError 嵌入错误 -type EmbeddingError struct { - Message string `json:"message"` - Type string `json:"type"` -} - -// waitRateLimiter 等待速率限制器 -func (e *Embedder) waitRateLimiter() { - e.mu.Lock() - defer e.mu.Unlock() - - if e.rateLimiter != nil { - // 等待令牌 - ctx := context.Background() - if err := e.rateLimiter.Wait(ctx); err != nil { - e.logger.Warn("速率限制器等待失败", zap.Error(err)) - } - } - - if e.rateLimitDelay > 0 { - time.Sleep(e.rateLimitDelay) - } -} - -// EmbedText 对文本进行嵌入(带重试和速率限制) -func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) { - if e.openAIClient == nil { - return nil, fmt.Errorf("OpenAI 客户端未初始化") - } - - var lastErr error - for attempt := 0; attempt < e.maxRetries; attempt++ { - // 速率限制 - if attempt > 0 { - // 重试时等待更长时间 - waitTime := e.retryDelay * time.Duration(attempt) - e.logger.Debug("重试前等待", zap.Int("attempt", attempt+1), zap.Duration("waitTime", waitTime)) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(waitTime): - } - } else { - e.waitRateLimiter() - } - - result, err := e.doEmbedText(ctx, text) - if err == nil { - return result, nil - } - - lastErr = err - - // 检查是否是可重试的错误(429 速率限制、5xx 服务器错误、网络错误) - if !e.isRetryableError(err) { - return nil, err - } - - e.logger.Debug("嵌入请求失败,准备重试", - zap.Int("attempt", attempt+1), - zap.Int("maxRetries", e.maxRetries), - zap.Error(err)) - } - - return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr) -} - -// doEmbedText 执行实际的嵌入请求(内部方法) -func (e *Embedder) doEmbedText(ctx context.Context, text string) ([]float32, error) { - // 使用配置的嵌入模型 - model := e.config.Embedding.Model + model := strings.TrimSpace(cfg.Embedding.Model) if model == "" { model = "text-embedding-3-small" } - req := EmbeddingRequest{ - Model: model, - Input: []string{text}, - } - - // 清理 baseURL:去除前后空格和尾部斜杠 - baseURL := strings.TrimSpace(e.config.Embedding.BaseURL) + baseURL := strings.TrimSpace(cfg.Embedding.BaseURL) baseURL = strings.TrimSuffix(baseURL, "/") if baseURL == "" { baseURL = "https://api.openai.com/v1" } - // 构建请求 - body, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("序列化请求失败:%w", err) - } - - requestURL := baseURL + "/embeddings" - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(string(body))) - if err != nil { - return nil, fmt.Errorf("创建请求失败:%w", err) - } - - httpReq.Header.Set("Content-Type", "application/json") - - // 使用配置的 API Key,如果没有则使用 OpenAI 配置的 - apiKey := strings.TrimSpace(e.config.Embedding.APIKey) - if apiKey == "" && e.openAIConfig != nil { - apiKey = e.openAIConfig.APIKey + apiKey := strings.TrimSpace(cfg.Embedding.APIKey) + if apiKey == "" && openAIConfig != nil { + apiKey = strings.TrimSpace(openAIConfig.APIKey) } if apiKey == "" { - return nil, fmt.Errorf("API Key 未配置") + return nil, fmt.Errorf("embedding API key 未配置") } - httpReq.Header.Set("Authorization", "Bearer "+apiKey) - // 发送请求 - httpClient := &http.Client{ - Timeout: 30 * time.Second, + timeout := 120 * time.Second + if cfg.Indexing.RequestTimeoutSeconds > 0 { + timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second } - resp, err := httpClient.Do(httpReq) + httpClient := &http.Client{Timeout: timeout} + + inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{ + APIKey: apiKey, + BaseURL: baseURL, + ByAzure: false, + Model: model, + HTTPClient: httpClient, + }) if err != nil { - return nil, fmt.Errorf("发送请求失败:%w", err) - } - defer resp.Body.Close() - - // 读取响应体以便在错误时输出详细信息 - bodyBytes := make([]byte, 0) - buf := make([]byte, 4096) - for { - n, err := resp.Body.Read(buf) - if n > 0 { - bodyBytes = append(bodyBytes, buf[:n]...) - } - if err != nil { - break - } + return nil, fmt.Errorf("eino OpenAI embedder: %w", err) } - // 记录请求和响应信息(用于调试) - requestBodyPreview := string(body) - if len(requestBodyPreview) > 200 { - requestBodyPreview = requestBodyPreview[:200] + "..." - } - e.logger.Debug("嵌入 API 请求", - zap.String("url", httpReq.URL.String()), - zap.String("model", model), - zap.String("requestBody", requestBodyPreview), - zap.Int("status", resp.StatusCode), - zap.Int("bodySize", len(bodyBytes)), - zap.String("contentType", resp.Header.Get("Content-Type")), - ) - - var embeddingResp EmbeddingResponse - if err := json.Unmarshal(bodyBytes, &embeddingResp); err != nil { - // 输出详细的错误信息 - bodyPreview := string(bodyBytes) - if len(bodyPreview) > 500 { - bodyPreview = bodyPreview[:500] + "..." - } - return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码:%d, 响应长度:%d字节): %w\n请求体:%s\n响应内容预览:%s", - requestURL, resp.StatusCode, len(bodyBytes), err, requestBodyPreview, bodyPreview) - } - - if embeddingResp.Error != nil { - return nil, fmt.Errorf("OpenAI API 错误 (状态码:%d): 类型=%s, 消息=%s", - resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message) - } - - if resp.StatusCode != http.StatusOK { - bodyPreview := string(bodyBytes) - if len(bodyPreview) > 500 { - bodyPreview = bodyPreview[:500] + "..." - } - return nil, fmt.Errorf("HTTP 请求失败 (URL: %s, 状态码:%d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview) - } - - if len(embeddingResp.Data) == 0 { - bodyPreview := string(bodyBytes) - if len(bodyPreview) > 500 { - bodyPreview = bodyPreview[:500] + "..." - } - return nil, fmt.Errorf("未收到嵌入数据 (状态码:%d, 响应长度:%d字节)\n响应内容:%s", - resp.StatusCode, len(bodyBytes), bodyPreview) - } - - // 转换为 float32 - embedding := make([]float32, len(embeddingResp.Data[0].Embedding)) - for i, v := range embeddingResp.Data[0].Embedding { - embedding[i] = float32(v) - } - - return embedding, nil + return &Embedder{ + eino: inner, + config: cfg, + logger: logger, + rateLimiter: rateLimiter, + rateLimitDelay: rateLimitDelay, + maxRetries: maxRetries, + retryDelay: retryDelay, + }, nil } -// isRetryableError 判断是否是可重试的错误 -func (e *Embedder) isRetryableError(err error) bool { - if err == nil { - return false +// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。 +func (e *Embedder) EmbeddingModelName() string { + if e == nil || e.config == nil { + return "" } - - errStr := err.Error() - - // 429 速率限制错误 - if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") { - return true + s := strings.TrimSpace(e.config.Embedding.Model) + if s != "" { + return s } - - // 5xx 服务器错误 - if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") || - strings.Contains(errStr, "503") || strings.Contains(errStr, "504") { - return true - } - - // 网络错误 - if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") || - strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") { - return true - } - - return false + return "text-embedding-3-small" } -// EmbedTexts 批量嵌入文本 -func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { +func (e *Embedder) waitRateLimiter() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.rateLimiter != nil { + ctx := context.Background() + if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil { + e.logger.Warn("速率限制器等待失败", zap.Error(err)) + } + } + if e.rateLimitDelay > 0 { + time.Sleep(e.rateLimitDelay) + } +} + +// EmbedText 单条嵌入(float32,与历史存储格式一致)。 +func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) { + vecs, err := e.EmbedStrings(ctx, []string{text}) + if err != nil { + return nil, err + } + if len(vecs) != 1 { + return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs)) + } + return vecs[0], nil +} + +// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。 +func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) { + if e == nil || e.eino == nil { + return nil, fmt.Errorf("embedder not initialized") + } if len(texts) == 0 { return nil, nil } - embeddings := make([][]float32, len(texts)) - for i, text := range texts { - embedding, err := e.EmbedText(ctx, text) - if err != nil { - return nil, fmt.Errorf("嵌入文本 [%d] 失败:%w", i, err) + var lastErr error + for attempt := 0; attempt < e.maxRetries; attempt++ { + if attempt > 0 { + wait := e.retryDelay * time.Duration(attempt) + if e.logger != nil { + e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait)) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(wait): + } + } else { + e.waitRateLimiter() } - embeddings[i] = embedding - } - return embeddings, nil + raw, err := e.eino.EmbedStrings(ctx, texts, opts...) + if err == nil { + out := make([][]float32, len(raw)) + for i, row := range raw { + out[i] = make([]float32, len(row)) + for j, v := range row { + out[i][j] = float32(v) + } + } + return out, nil + } + lastErr = err + if !e.isRetryableError(err) { + return nil, err + } + if e.logger != nil { + e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err)) + } + } + return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr) +} + +// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。 +func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { + return e.EmbedStrings(ctx, texts) +} + +func (e *Embedder) isRetryableError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") { + return true + } + if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") || + strings.Contains(errStr, "503") || strings.Contains(errStr, "504") { + return true + } + if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") || + strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") { + return true + } + return false +} + +// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store. +type einoFloatEmbedder struct { + inner *Embedder +} + +func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + vec32, err := w.inner.EmbedStrings(ctx, texts, opts...) + if err != nil { + return nil, err + } + out := make([][]float64, len(vec32)) + for i, row := range vec32 { + out[i] = make([]float64, len(row)) + for j, v := range row { + out[i][j] = float64(v) + } + } + return out, nil +} + +func (w *einoFloatEmbedder) GetType() string { + return "CyberStrikeKnowledgeEmbedder" +} + +func (w *einoFloatEmbedder) IsCallbacksEnabled() bool { + return false +} + +// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path +// and produces float64 vectors expected by generic Eino indexer helpers. +func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder { + return &einoFloatEmbedder{inner: e} } diff --git a/internal/knowledge/index_pipeline.go b/internal/knowledge/index_pipeline.go new file mode 100644 index 00000000..de5d466e --- /dev/null +++ b/internal/knowledge/index_pipeline.go @@ -0,0 +1,91 @@ +package knowledge + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/schema" +) + +// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive". +func normalizeChunkStrategy(s string) string { + v := strings.TrimSpace(strings.ToLower(s)) + switch v { + case "recursive": + return "recursive" + case "markdown_then_recursive", "markdown_recursive", "markdown": + return "markdown_then_recursive" + case "": + return "markdown_then_recursive" + default: + return "markdown_then_recursive" + } +} + +func buildKnowledgeIndexChain( + ctx context.Context, + indexingCfg *config.IndexingConfig, + db *sql.DB, + recursive document.Transformer, + embeddingModel string, +) (compose.Runnable[[]*schema.Document, []string], error) { + if recursive == nil { + return nil, fmt.Errorf("recursive transformer is nil") + } + if db == nil { + return nil, fmt.Errorf("db is nil") + } + strategy := normalizeChunkStrategy("markdown_then_recursive") + batch := 64 + maxChunks := 0 + if indexingCfg != nil { + strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy) + if indexingCfg.BatchSize > 0 { + batch = indexingCfg.BatchSize + } + maxChunks = indexingCfg.MaxChunksPerItem + } + + si := NewSQLiteIndexer(db, batch, embeddingModel) + ch := compose.NewChain[[]*schema.Document, []string]() + if strategy != "recursive" { + md, err := newMarkdownHeaderSplitter(ctx) + if err != nil { + return nil, fmt.Errorf("markdown splitter: %w", err) + } + ch.AppendDocumentTransformer(md) + } + ch.AppendDocumentTransformer(recursive) + ch.AppendLambda(newChunkEnrichLambda(maxChunks)) + ch.AppendIndexer(si) + return ch.Compile(ctx) +} + +func newChunkEnrichLambda(maxChunks int) *compose.Lambda { + return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) { + _ = ctx + out := make([]*schema.Document, 0, len(docs)) + for _, d := range docs { + if d == nil || strings.TrimSpace(d.Content) == "" { + continue + } + out = append(out, d) + } + if maxChunks > 0 && len(out) > maxChunks { + out = out[:maxChunks] + } + for i, d := range out { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + d.MetaData[metaKBChunkIndex] = i + } + return out, nil + }) +} diff --git a/internal/knowledge/index_pipeline_test.go b/internal/knowledge/index_pipeline_test.go new file mode 100644 index 00000000..9e4b03fa --- /dev/null +++ b/internal/knowledge/index_pipeline_test.go @@ -0,0 +1,21 @@ +package knowledge + +import "testing" + +func TestNormalizeChunkStrategy(t *testing.T) { + cases := []struct { + in, want string + }{ + {"", "markdown_then_recursive"}, + {"recursive", "recursive"}, + {"RECURSIVE", "recursive"}, + {"markdown_then_recursive", "markdown_then_recursive"}, + {"markdown", "markdown_then_recursive"}, + {"unknown", "markdown_then_recursive"}, + } + for _, tc := range cases { + if got := normalizeChunkStrategy(tc.in); got != tc.want { + t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go index 4a0da3eb..390835c6 100644 --- a/internal/knowledge/indexer.go +++ b/internal/knowledge/indexer.go @@ -3,596 +3,203 @@ package knowledge import ( "context" "database/sql" - "encoding/json" "fmt" - "regexp" "strings" "sync" "time" "cyberstrike-ai/internal/config" - "github.com/google/uuid" + fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" "go.uber.org/zap" ) -// Indexer 索引器,负责将知识项分块并向量化 +// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。 type Indexer struct { - db *sql.DB - embedder *Embedder - logger *zap.Logger - chunkSize int // 每个块的最大 token 数(估算) - overlap int // 块之间的重叠 token 数 - maxChunks int // 单个知识项的最大块数量(0 表示不限制) + db *sql.DB + embedder *Embedder + logger *zap.Logger + chunkSize int + overlap int + indexingCfg *config.IndexingConfig + + indexChain compose.Runnable[[]*schema.Document, []string] + fileLoader *fileloader.FileLoader - // 错误跟踪 mu sync.RWMutex - lastError string // 最近一次错误信息 - lastErrorTime time.Time // 最近一次错误时间 - errorCount int // 连续错误计数 + lastError string + lastErrorTime time.Time + errorCount int - // 重建索引状态跟踪 rebuildMu sync.RWMutex - isRebuilding bool // 是否正在重建索引 - rebuildTotalItems int // 重建总项数 - rebuildCurrent int // 当前已处理项数 - rebuildFailed int // 重建失败项数 - rebuildStartTime time.Time // 重建开始时间 - rebuildLastItemID string // 最近处理的项 ID - rebuildLastChunks int // 最近处理的项的分块数 + isRebuilding bool + rebuildTotalItems int + rebuildCurrent int + rebuildFailed int + rebuildStartTime time.Time + rebuildLastItemID string + rebuildLastChunks int } -// NewIndexer 创建新的索引器 -func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger, indexingCfg *config.IndexingConfig) *Indexer { +// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。 +func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) { + if db == nil { + return nil, fmt.Errorf("db is nil") + } + if embedder == nil { + return nil, fmt.Errorf("embedder is nil") + } + if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil { + return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err) + } + if kcfg == nil { + kcfg = &config.KnowledgeConfig{} + } + indexingCfg := &kcfg.Indexing + chunkSize := 512 overlap := 50 - maxChunks := 0 - if indexingCfg != nil { - if indexingCfg.ChunkSize > 0 { - chunkSize = indexingCfg.ChunkSize - } - if indexingCfg.ChunkOverlap >= 0 { - overlap = indexingCfg.ChunkOverlap - } - if indexingCfg.MaxChunksPerItem > 0 { - maxChunks = indexingCfg.MaxChunksPerItem - } + if indexingCfg.ChunkSize > 0 { + chunkSize = indexingCfg.ChunkSize } + if indexingCfg.ChunkOverlap >= 0 { + overlap = indexingCfg.ChunkOverlap + } + + embedModel := embedder.EmbeddingModelName() + splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel) + if err != nil { + return nil, fmt.Errorf("eino recursive splitter: %w", err) + } + + chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel) + if err != nil { + return nil, fmt.Errorf("knowledge index chain: %w", err) + } + + var fl *fileloader.FileLoader + fl, err = fileloader.NewFileLoader(ctx, nil) + if err != nil { + if logger != nil { + logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err)) + } + fl = nil + err = nil + } + return &Indexer{ - db: db, - embedder: embedder, - logger: logger, - chunkSize: chunkSize, - overlap: overlap, - maxChunks: maxChunks, - } + db: db, + embedder: embedder, + logger: logger, + chunkSize: chunkSize, + overlap: overlap, + indexingCfg: indexingCfg, + indexChain: chain, + fileLoader: fl, + }, nil } -// ChunkText 将文本分块(支持重叠,保留标题上下文) -func (idx *Indexer) ChunkText(text string) []string { - // 按 Markdown 标题分割,获取带标题的块 - sections := idx.splitByMarkdownHeadersWithContent(text) - - // 处理每个块 - result := make([]string, 0) - for _, section := range sections { - // 构建父级标题路径(不包含最后一级标题,因为内容中已经包含) - // 例如:["# A", "## B", "### C"] -> "[# A > ## B]" - var parentHeaderPath string - if len(section.HeaderPath) > 1 { - parentHeaderPath = strings.Join(section.HeaderPath[:len(section.HeaderPath)-1], " > ") - } - - // 提取内容的第一行作为标题(如 "# Prompt Injection") - firstLine, remainingContent := extractFirstLine(section.Content) - - // 如果剩余内容为空或只有空白,说明这个块只有标题没有正文,跳过 - if strings.TrimSpace(remainingContent) == "" { - continue - } - - // 如果块太大,进一步分割 - if idx.estimateTokens(section.Content) <= idx.chunkSize { - // 块大小合适,添加父级标题前缀 - if parentHeaderPath != "" { - result = append(result, fmt.Sprintf("[%s] %s", parentHeaderPath, section.Content)) - } else { - result = append(result, section.Content) - } - } else { - // 块太大,按子标题或段落分割,保持标题上下文 - // 首先尝试按子标题分割(保留子标题结构) - subSections := idx.splitBySubHeaders(section.Content, firstLine, parentHeaderPath) - if len(subSections) > 1 { - // 成功按子标题分割,递归处理每个子块 - for _, sub := range subSections { - if idx.estimateTokens(sub) <= idx.chunkSize { - result = append(result, sub) - } else { - // 子块仍然太大,按段落分割(保留标题前缀) - paragraphs := idx.splitByParagraphsWithHeader(sub, parentHeaderPath) - for _, para := range paragraphs { - if idx.estimateTokens(para) <= idx.chunkSize { - result = append(result, para) - } else { - // 段落仍太大,按句子分割 - sentenceChunks := idx.splitBySentencesWithOverlap(para) - for _, chunk := range sentenceChunks { - result = append(result, chunk) - } - } - } - } - } - } else { - // 没有子标题,按段落分割(保留标题前缀) - paragraphs := idx.splitByParagraphsWithHeader(section.Content, parentHeaderPath) - for _, para := range paragraphs { - if idx.estimateTokens(para) <= idx.chunkSize { - result = append(result, para) - } else { - // 段落仍太大,按句子分割 - sentenceChunks := idx.splitBySentencesWithOverlap(para) - for _, chunk := range sentenceChunks { - result = append(result, chunk) - } - } - } - } - } +// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。 +func (idx *Indexer) RecompileIndexChain(ctx context.Context) error { + if idx == nil || idx.db == nil || idx.embedder == nil { + return fmt.Errorf("indexer 未初始化") } - - return result + if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil { + return err + } + embedModel := idx.embedder.EmbeddingModelName() + splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel) + if err != nil { + return fmt.Errorf("eino recursive splitter: %w", err) + } + chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel) + if err != nil { + return fmt.Errorf("knowledge index chain: %w", err) + } + idx.indexChain = chain + return nil } -// extractFirstLine 提取第一行内容和剩余内容 -func extractFirstLine(content string) (firstLine, remaining string) { - lines := strings.SplitN(content, "\n", 2) - if len(lines) == 0 { - return "", "" - } - if len(lines) == 1 { - return lines[0], "" - } - return lines[0], lines[1] -} - -// splitBySubHeaders 尝试按子标题分割内容(用于处理大块内容) -// headerPrefix 是父级标题路径,用于添加到每个子块 -func (idx *Indexer) splitBySubHeaders(content, headerPrefix, parentPath string) []string { - // 匹配 Markdown 子标题(## 及以上) - subHeaderRegex := regexp.MustCompile(`(?m)^#{2,6}\s+.+$`) - matches := subHeaderRegex.FindAllStringIndex(content, -1) - - if len(matches) == 0 { - // 没有子标题,返回原始内容 - return []string{content} - } - - result := make([]string, 0, len(matches)) - for i, match := range matches { - start := match[0] - nextStart := len(content) - if i+1 < len(matches) { - nextStart = matches[i+1][0] - } - - subContent := strings.TrimSpace(content[start:nextStart]) - - // 添加父级路径前缀 - if parentPath != "" { - result = append(result, fmt.Sprintf("[%s] %s", parentPath, subContent)) - } else { - result = append(result, subContent) - } - } - - return result -} - -// splitByParagraphsWithHeader 按段落分割,每个段落添加标题前缀(用于保持上下文) -func (idx *Indexer) splitByParagraphsWithHeader(content, parentPath string) []string { - // 提取第一行作为标题 - firstLine, _ := extractFirstLine(content) - - paragraphs := strings.Split(content, "\n\n") - result := make([]string, 0) - - for i, p := range paragraphs { - trimmed := strings.TrimSpace(p) - if trimmed == "" { - continue - } - - // 过滤掉只有标题的段落(没有实际内容) - if strings.TrimSpace(trimmed) == strings.TrimSpace(firstLine) { - continue - } - - // 第一个段落已经包含标题,不需要重复添加 - if i == 0 && strings.Contains(trimmed, firstLine) { - if parentPath != "" { - result = append(result, fmt.Sprintf("[%s] %s", parentPath, trimmed)) - } else { - result = append(result, trimmed) - } - } else { - // 其他段落添加标题前缀以保持上下文 - if parentPath != "" { - result = append(result, fmt.Sprintf("[%s] %s\n%s", parentPath, firstLine, trimmed)) - } else { - result = append(result, fmt.Sprintf("%s\n%s", firstLine, trimmed)) - } - } - } - - return result -} - -// Section 表示一个带标题路径的文本块 -type Section struct { - HeaderPath []string // 标题路径(如 ["# SQL 注入", "## 检测方法"]) - Content string // 块内容 -} - -// splitByMarkdownHeadersWithContent 按 Markdown 标题分割,返回带标题路径的块 -// 每个块的内容包含自己的标题,用于向量化检索 -// -// 例如,对于以下 Markdown: -// # Prompt Injection -// 引言内容 -// ## Summary -// 目录内容 -// -// 返回: -// [{HeaderPath: ["# Prompt Injection"], Content: "# Prompt Injection\n引言内容"}, -// {HeaderPath: ["# Prompt Injection", "## Summary"], Content: "## Summary\n目录内容"}] -func (idx *Indexer) splitByMarkdownHeadersWithContent(text string) []Section { - // 匹配 Markdown 标题 (# ## ### 等) - headerRegex := regexp.MustCompile(`(?m)^#{1,6}\s+.+$`) - - // 找到所有标题位置 - matches := headerRegex.FindAllStringIndex(text, -1) - if len(matches) == 0 { - // 没有标题,返回整个文本 - return []Section{{HeaderPath: []string{}, Content: text}} - } - - sections := make([]Section, 0, len(matches)) - currentHeaderPath := []string{} - - for i, match := range matches { - start := match[0] - end := match[1] - nextStart := len(text) - - // 找到下一个标题的位置 - if i+1 < len(matches) { - nextStart = matches[i+1][0] - } - - // 提取当前标题 - headerLine := strings.TrimSpace(text[start:end]) - - // 计算标题层级(# 的数量) - level := 0 - for _, ch := range headerLine { - if ch == '#' { - level++ - } else { - break - } - } - - // 更新标题路径:移除比当前层级深或等于的子标题,然后添加当前标题 - newPath := make([]string, 0, len(currentHeaderPath)+1) - for _, h := range currentHeaderPath { - hLevel := 0 - for _, ch := range h { - if ch == '#' { - hLevel++ - } else { - break - } - } - if hLevel < level { - newPath = append(newPath, h) - } - } - newPath = append(newPath, headerLine) - currentHeaderPath = newPath - - // 提取当前标题到下一个标题之间的内容(包含当前标题) - content := strings.TrimSpace(text[start:nextStart]) - - // 创建块,使用当前标题路径(包含当前标题) - sections = append(sections, Section{ - HeaderPath: append([]string(nil), currentHeaderPath...), - Content: content, - }) - } - - // 过滤空块 - result := make([]Section, 0, len(sections)) - for _, section := range sections { - if strings.TrimSpace(section.Content) != "" { - result = append(result, section) - } - } - - if len(result) == 0 { - return []Section{{HeaderPath: []string{}, Content: text}} - } - - return result -} - -// splitByParagraphs 按段落分割 -func (idx *Indexer) splitByParagraphs(text string) []string { - paragraphs := strings.Split(text, "\n\n") - result := make([]string, 0) - for _, p := range paragraphs { - if strings.TrimSpace(p) != "" { - result = append(result, strings.TrimSpace(p)) - } - } - return result -} - -// splitBySentences 按句子分割(用于内部,不包含重叠逻辑) -func (idx *Indexer) splitBySentences(text string) []string { - // 简单的句子分割(按句号、问号、感叹号,支持中英文) - // . ! ? = 英文标点 - // \u3002 = 。(中文句号) - // \uFF01 = !(中文叹号) - // \uFF1F = ?(中文问号) - sentenceRegex := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`) - sentences := sentenceRegex.Split(text, -1) - result := make([]string, 0) - for _, s := range sentences { - if strings.TrimSpace(s) != "" { - result = append(result, strings.TrimSpace(s)) - } - } - return result -} - -// splitBySentencesWithOverlap 按句子分割并应用重叠策略 -func (idx *Indexer) splitBySentencesWithOverlap(text string) []string { - if idx.overlap <= 0 { - // 如果没有重叠,使用简单分割 - return idx.splitBySentencesSimple(text) - } - - sentences := idx.splitBySentences(text) - if len(sentences) == 0 { - return []string{} - } - - result := make([]string, 0) - currentChunk := "" - - for _, sentence := range sentences { - testChunk := currentChunk - if testChunk != "" { - testChunk += "\n" - } - testChunk += sentence - - testTokens := idx.estimateTokens(testChunk) - - if testTokens > idx.chunkSize && currentChunk != "" { - // 当前块已达到大小限制,保存它 - result = append(result, currentChunk) - - // 从当前块的末尾提取重叠部分 - overlapText := idx.extractLastTokens(currentChunk, idx.overlap) - if overlapText != "" { - // 如果有重叠内容,作为下一个块的起始 - currentChunk = overlapText + "\n" + sentence - } else { - // 如果无法提取足够的重叠内容,直接使用当前句子 - currentChunk = sentence - } - } else { - currentChunk = testChunk - } - } - - // 添加最后一个块 - if strings.TrimSpace(currentChunk) != "" { - result = append(result, currentChunk) - } - - // 过滤空块 - filtered := make([]string, 0) - for _, chunk := range result { - if strings.TrimSpace(chunk) != "" { - filtered = append(filtered, chunk) - } - } - - return filtered -} - -// splitBySentencesSimple 按句子分割(简单版本,无重叠) -func (idx *Indexer) splitBySentencesSimple(text string) []string { - sentences := idx.splitBySentences(text) - result := make([]string, 0) - currentChunk := "" - - for _, sentence := range sentences { - testChunk := currentChunk - if testChunk != "" { - testChunk += "\n" - } - testChunk += sentence - - if idx.estimateTokens(testChunk) > idx.chunkSize && currentChunk != "" { - result = append(result, currentChunk) - currentChunk = sentence - } else { - currentChunk = testChunk - } - } - if currentChunk != "" { - result = append(result, currentChunk) - } - - return result -} - -// extractLastTokens 从文本末尾提取指定 token 数量的内容 -func (idx *Indexer) extractLastTokens(text string, tokenCount int) string { - if tokenCount <= 0 || text == "" { - return "" - } - - // 估算字符数(1 token ≈ 4 字符) - charCount := tokenCount * 4 - runes := []rune(text) - - if len(runes) <= charCount { - return text - } - - // 从末尾提取指定数量的字符 - startPos := len(runes) - charCount - extracted := string(runes[startPos:]) - - // 尝试找到第一个句子边界(支持中英文标点) - sentenceBoundary := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`) - matches := sentenceBoundary.FindStringIndex(extracted) - if len(matches) > 0 && matches[0] > 0 { - // 在句子边界处截断,保留完整句子 - extracted = extracted[matches[0]:] - } - - return strings.TrimSpace(extracted) -} - -// estimateTokens 估算 token 数(简单估算:1 token ≈ 4 字符) -func (idx *Indexer) estimateTokens(text string) int { - return len([]rune(text)) / 4 -} - -// IndexItem 索引知识项(分块并向量化) +// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。 func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { - // 获取知识项(包含 category 和 title,用于向量化) - var content, category, title string - err := idx.db.QueryRow("SELECT content, category, title FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title) + if idx.indexChain == nil { + return fmt.Errorf("索引链未初始化") + } + if idx.embedder == nil { + return fmt.Errorf("嵌入器未初始化") + } + + var content, category, title, filePath string + err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath) if err != nil { return fmt.Errorf("获取知识项失败:%w", err) } - // 删除旧的向量(在 RebuildIndex 中已经统一清空,这里保留是为了单独调用 IndexItem 时的兼容性) - _, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID) - if err != nil { + if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil { return fmt.Errorf("删除旧向量失败:%w", err) } - // 分块 - chunks := idx.ChunkText(content) - - // 应用最大块数限制 - if idx.maxChunks > 0 && len(chunks) > idx.maxChunks { - idx.logger.Info("知识项块数量超过限制,已截断", - zap.String("itemId", itemID), - zap.Int("originalChunks", len(chunks)), - zap.Int("maxChunks", idx.maxChunks)) - chunks = chunks[:idx.maxChunks] - } - - idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks))) - - // 跟踪该知识项的错误 - itemErrorCount := 0 - var firstError error - firstErrorChunkIndex := -1 - - // 向量化每个块(包含 category 和 title 信息,以便向量检索时能匹配到风险类型) - for i, chunk := range chunks { - // 将 category 和 title 信息包含到向量化的文本中 - // 格式:"[风险类型:{category}] [标题:{title}]\n{chunk 内容}" - // 这样向量嵌入就会包含风险类型信息,即使 SQL 过滤失败,向量相似度也能帮助匹配 - textForEmbedding := fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunk) - - embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding) - if err != nil { - itemErrorCount++ - if firstError == nil { - firstError = err - firstErrorChunkIndex = i - // 只在第一个块失败时记录详细日志 - chunkPreview := chunk - if len(chunkPreview) > 200 { - chunkPreview = chunkPreview[:200] + "..." + body := strings.TrimSpace(content) + if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil { + docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)}) + if lerr == nil && len(docs) > 0 { + var b strings.Builder + for i, d := range docs { + if d == nil { + continue } - idx.logger.Warn("向量化失败", - zap.String("itemId", itemID), - zap.Int("chunkIndex", i), - zap.Int("totalChunks", len(chunks)), - zap.String("chunkPreview", chunkPreview), - zap.Error(err), - ) - - // 更新全局错误跟踪 - errorMsg := fmt.Sprintf("向量化失败 (知识项:%s): %v", itemID, err) - idx.mu.Lock() - idx.lastError = errorMsg - idx.lastErrorTime = time.Now() - idx.mu.Unlock() + if i > 0 { + b.WriteString("\n\n") + } + b.WriteString(d.Content) } - - // 如果连续失败 5 个块,立即停止处理该知识项 - // 这样可以避免继续浪费 API 调用,同时也能更快地检测到配置问题 - // 对于大文档(超过 10 个块),允许失败比例不超过 50% - maxConsecutiveFailures := 5 - if len(chunks) > 10 && itemErrorCount > len(chunks)/2 { - idx.logger.Error("知识项向量化失败比例过高,停止处理", - zap.String("itemId", itemID), - zap.Int("totalChunks", len(chunks)), - zap.Int("failedChunks", itemErrorCount), - zap.Int("firstErrorChunkIndex", firstErrorChunkIndex), - zap.Error(firstError), - ) - return fmt.Errorf("知识项向量化失败比例过高 (%d/%d个块失败): %v", itemErrorCount, len(chunks), firstError) + if s := strings.TrimSpace(b.String()); s != "" { + body = s } - if itemErrorCount >= maxConsecutiveFailures { - idx.logger.Error("知识项连续向量化失败,停止处理", - zap.String("itemId", itemID), - zap.Int("totalChunks", len(chunks)), - zap.Int("failedChunks", itemErrorCount), - zap.Int("firstErrorChunkIndex", firstErrorChunkIndex), - zap.Error(firstError), - ) - return fmt.Errorf("知识项连续向量化失败 (%d个块失败): %v", itemErrorCount, firstError) - } - continue - } - - // 保存向量 - chunkID := uuid.New().String() - embeddingJSON, _ := json.Marshal(embedding) - - _, err = idx.db.Exec( - "INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, created_at) VALUES (?, ?, ?, ?, ?, datetime('now'))", - chunkID, itemID, i, chunk, string(embeddingJSON), - ) - if err != nil { - idx.logger.Warn("保存向量失败", zap.String("itemId", itemID), zap.Int("chunkIndex", i), zap.Error(err)) - continue + } else if idx.logger != nil { + idx.logger.Warn("优先源文件读取失败,使用数据库正文", + zap.String("itemId", itemID), + zap.String("path", filePath), + zap.Error(lerr)) } } - idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks))) + root := &schema.Document{ + ID: itemID, + Content: body, + MetaData: map[string]any{ + metaKBCategory: category, + metaKBTitle: title, + metaKBItemID: itemID, + }, + } - // 更新重建状态中的最近处理信息 + idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())} + if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 { + idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes)) + } + + ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...)) + if err != nil { + msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err) + idx.mu.Lock() + idx.lastError = msg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + return err + } + + if idx.logger != nil { + idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids))) + } idx.rebuildMu.Lock() idx.rebuildLastItemID = itemID - idx.rebuildLastChunks = len(chunks) + idx.rebuildLastChunks = len(ids) idx.rebuildMu.Unlock() - return nil } @@ -608,7 +215,6 @@ func (idx *Indexer) HasIndex() (bool, error) { // RebuildIndex 重建所有索引 func (idx *Indexer) RebuildIndex(ctx context.Context) error { - // 设置重建状态 idx.rebuildMu.Lock() idx.isRebuilding = true idx.rebuildTotalItems = 0 @@ -619,7 +225,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { idx.rebuildLastChunks = 0 idx.rebuildMu.Unlock() - // 重置错误跟踪 idx.mu.Lock() idx.lastError = "" idx.lastErrorTime = time.Time{} @@ -628,7 +233,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { rows, err := idx.db.Query("SELECT id FROM knowledge_base_items") if err != nil { - // 重置重建状态 idx.rebuildMu.Lock() idx.isRebuilding = false idx.rebuildMu.Unlock() @@ -640,7 +244,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { for rows.Next() { var id string if err := rows.Scan(&id); err != nil { - // 重置重建状态 idx.rebuildMu.Lock() idx.isRebuilding = false idx.rebuildMu.Unlock() @@ -655,13 +258,9 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs))) - // 注意:不再清空所有旧索引,而是按增量方式更新 - // 每个知识项在 IndexItem 中会先删除自己的旧向量,然后插入新向量 - // 这样配置更新后只重新索引变化的知识项,保留其他知识项的索引 - failedCount := 0 consecutiveFailures := 0 - maxConsecutiveFailures := 5 // 连续失败 5 次后立即停止(允许偶尔的临时错误) + maxConsecutiveFailures := 5 firstFailureItemID := "" var firstFailureError error @@ -670,7 +269,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { failedCount++ consecutiveFailures++ - // 只在第一个失败时记录详细日志 if consecutiveFailures == 1 { firstFailureItemID = itemID firstFailureError = err @@ -681,7 +279,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { ) } - // 如果连续失败过多,可能是配置问题,立即停止索引 if consecutiveFailures >= maxConsecutiveFailures { errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError) idx.mu.Lock() @@ -699,7 +296,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError) } - // 如果失败的知识项过多,记录警告但继续处理(降低阈值到 30%) if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 { errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError) idx.mu.Lock() @@ -717,26 +313,22 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { continue } - // 成功时重置连续失败计数和第一个失败信息 if consecutiveFailures > 0 { consecutiveFailures = 0 firstFailureItemID = "" firstFailureError = nil } - // 更新重建进度 idx.rebuildMu.Lock() idx.rebuildCurrent = i + 1 idx.rebuildFailed = failedCount idx.rebuildMu.Unlock() - // 减少进度日志频率(每 10 个或每 10% 记录一次) if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) { idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount)) } } - // 重置重建状态 idx.rebuildMu.Lock() idx.isRebuilding = false idx.rebuildMu.Unlock() diff --git a/internal/knowledge/retrieval_postprocess.go b/internal/knowledge/retrieval_postprocess.go new file mode 100644 index 00000000..eb69e4c3 --- /dev/null +++ b/internal/knowledge/retrieval_postprocess.go @@ -0,0 +1,213 @@ +package knowledge + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync" + "unicode" + "unicode/utf8" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" + "github.com/pkoukk/tiktoken-go" +) + +// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。 +const postRetrieveMaxPrefetchCap = 200 + +// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。 +type DocumentReranker interface { + Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) +} + +// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。 +type NopDocumentReranker struct{} + +// Rerank implements [DocumentReranker] as no-op. +func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) { + return docs, nil +} + +var tiktokenEncMu sync.Mutex +var tiktokenEncCache = map[string]*tiktoken.Tiktoken{} + +func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) { + m := strings.TrimSpace(model) + if m == "" { + m = "gpt-4" + } + tiktokenEncMu.Lock() + defer tiktokenEncMu.Unlock() + if enc, ok := tiktokenEncCache[m]; ok { + return enc, nil + } + enc, err := tiktoken.EncodingForModel(m) + if err != nil { + enc, err = tiktoken.GetEncoding("cl100k_base") + if err != nil { + return nil, err + } + } + tiktokenEncCache[m] = enc + return enc, nil +} + +func countDocTokens(text, model string) (int, error) { + enc, err := encodingForTokenizerModel(model) + if err != nil { + return 0, err + } + toks := enc.Encode(text, nil, nil) + return len(toks), nil +} + +// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。 +func normalizeContentFingerprintKey(s string) string { + s = strings.TrimSpace(s) + var b strings.Builder + b.Grow(len(s)) + prevSpace := false + for _, r := range s { + if unicode.IsSpace(r) { + if !prevSpace { + b.WriteByte(' ') + prevSpace = true + } + continue + } + prevSpace = false + b.WriteRune(r) + } + return b.String() +} + +func contentNormKey(d *schema.Document) string { + if d == nil { + return "" + } + n := normalizeContentFingerprintKey(d.Content) + if n == "" { + return "" + } + sum := sha256.Sum256([]byte(n)) + return hex.EncodeToString(sum[:]) +} + +// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。 +func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document { + if len(docs) < 2 { + return docs + } + seen := make(map[string]struct{}, len(docs)) + out := make([]*schema.Document, 0, len(docs)) + for _, d := range docs { + if d == nil { + continue + } + k := contentNormKey(d) + if k == "" { + out = append(out, d) + continue + } + if _, ok := seen[k]; ok { + continue + } + seen[k] = struct{}{} + out = append(out, d) + } + return out +} + +// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。 +func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) { + if len(docs) == 0 { + return docs, nil + } + unlimitedChars := maxRunes <= 0 + unlimitedTok := maxTokens <= 0 + if unlimitedChars && unlimitedTok { + return docs, nil + } + + remRunes := maxRunes + remTok := maxTokens + out := make([]*schema.Document, 0, len(docs)) + + for _, d := range docs { + if d == nil || strings.TrimSpace(d.Content) == "" { + continue + } + runes := utf8.RuneCountInString(d.Content) + if !unlimitedChars && runes > remRunes { + break + } + var tok int + var err error + if !unlimitedTok { + tok, err = countDocTokens(d.Content, tokenModel) + if err != nil { + return nil, fmt.Errorf("token count: %w", err) + } + if tok > remTok { + break + } + } + out = append(out, d) + if !unlimitedChars { + remRunes -= runes + } + if !unlimitedTok { + remTok -= tok + } + } + return out, nil +} + +// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。 +func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int { + if topK < 1 { + topK = 5 + } + fetch := topK + if po != nil && po.PrefetchTopK > fetch { + fetch = po.PrefetchTopK + } + if fetch > postRetrieveMaxPrefetchCap { + fetch = postRetrieveMaxPrefetchCap + } + return fetch +} + +// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。 +func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) { + if finalTopK < 1 { + finalTopK = 5 + } + if len(docs) == 0 { + return docs, nil + } + + maxChars := 0 + maxTok := 0 + if po != nil { + maxChars = po.MaxContextChars + maxTok = po.MaxContextTokens + } + + out := dedupeByNormalizedContent(docs) + + var err error + out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel) + if err != nil { + return nil, err + } + + if len(out) > finalTopK { + out = out[:finalTopK] + } + return out, nil +} diff --git a/internal/knowledge/retrieval_postprocess_test.go b/internal/knowledge/retrieval_postprocess_test.go new file mode 100644 index 00000000..10c661a8 --- /dev/null +++ b/internal/knowledge/retrieval_postprocess_test.go @@ -0,0 +1,62 @@ +package knowledge + +import ( + "testing" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" +) + +func doc(id, content string, score float64) *schema.Document { + d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}} + d.WithScore(score) + return d +} + +func TestDedupeByNormalizedContent(t *testing.T) { + a := doc("1", "hello world", 0.9) + b := doc("2", "hello world", 0.8) + c := doc("3", "other", 0.7) + out := dedupeByNormalizedContent([]*schema.Document{a, b, c}) + if len(out) != 2 { + t.Fatalf("len=%d want 2", len(out)) + } + if out[0].ID != "1" || out[1].ID != "3" { + t.Fatalf("order/ids wrong: %#v", out) + } +} + +func TestEffectivePrefetchTopK(t *testing.T) { + if g := EffectivePrefetchTopK(5, nil); g != 5 { + t.Fatalf("got %d", g) + } + if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 { + t.Fatalf("got %d", g) + } + if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap { + t.Fatalf("cap: got %d", g) + } +} + +func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) { + d1 := doc("1", "ab", 0.9) + d2 := doc("2", "cd", 0.8) + d3 := doc("3", "ef", 0.7) + po := &config.PostRetrieveConfig{MaxContextChars: 3} + out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5) + if err != nil { + t.Fatal(err) + } + if len(out) != 1 || out[0].ID != "1" { + t.Fatalf("got %#v", out) + } + + out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2) + if err != nil { + t.Fatal(err) + } + if len(out2) != 2 { + t.Fatalf("topk: len=%d", len(out2)) + } +} diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go index 6a6551a1..9145b2c6 100644 --- a/internal/knowledge/retriever.go +++ b/internal/knowledge/retriever.go @@ -8,23 +8,34 @@ import ( "math" "sort" "strings" + "sync" + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" "go.uber.org/zap" ) -// Retriever 检索器 +// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值), +// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。 type Retriever struct { db *sql.DB embedder *Embedder config *RetrievalConfig logger *zap.Logger + + rerankMu sync.RWMutex + reranker DocumentReranker } // RetrievalConfig 检索配置 type RetrievalConfig struct { TopK int SimilarityThreshold float64 - HybridWeight float64 + // SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。 + SubIndexFilter string + PostRetrieve config.PostRetrieveConfig } // NewRetriever 创建新的检索器 @@ -38,18 +49,41 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge } // UpdateConfig 更新检索配置 -func (r *Retriever) UpdateConfig(config *RetrievalConfig) { - if config != nil { - r.config = config - r.logger.Info("检索器配置已更新", - zap.Int("top_k", config.TopK), - zap.Float64("similarity_threshold", config.SimilarityThreshold), - zap.Float64("hybrid_weight", config.HybridWeight), - ) +func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) { + if cfg != nil { + r.config = cfg + if r.logger != nil { + r.logger.Info("检索器配置已更新", + zap.Int("top_k", cfg.TopK), + zap.Float64("similarity_threshold", cfg.SimilarityThreshold), + zap.String("sub_index_filter", cfg.SubIndexFilter), + zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK), + zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars), + zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens), + ) + } } } -// cosineSimilarity 计算余弦相似度 +// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。 +func (r *Retriever) SetDocumentReranker(rr DocumentReranker) { + if r == nil { + return + } + r.rerankMu.Lock() + defer r.rerankMu.Unlock() + r.reranker = rr +} + +func (r *Retriever) documentReranker() DocumentReranker { + if r == nil { + return nil + } + r.rerankMu.RLock() + defer r.rerankMu.RUnlock() + return r.reranker +} + func cosineSimilarity(a, b []float32) float64 { if len(a) != len(b) { return 0.0 @@ -69,608 +103,203 @@ func cosineSimilarity(a, b []float32) float64 { return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) } -// bm25Score 计算 BM25 分数(带缓存的改进版本) -// 注意:由于缺少全局文档统计,使用简化 IDF 计算 -func (r *Retriever) bm25Score(query, text string) float64 { - queryTerms := strings.Fields(strings.ToLower(query)) - if len(queryTerms) == 0 { - return 0.0 +// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。 +func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { + if req == nil { + return nil, fmt.Errorf("请求不能为空") } - - textLower := strings.ToLower(text) - textTerms := strings.Fields(textLower) - if len(textTerms) == 0 { - return 0.0 + q := strings.TrimSpace(req.Query) + if q == "" { + return nil, fmt.Errorf("查询不能为空") } - - // BM25 参数(标准值) - k1 := 1.2 // 词频饱和度参数(标准范围 1.2-2.0) - b := 0.75 // 长度归一化参数(标准值) - avgDocLength := 150.0 // 估算的平均文档长度(基于典型知识块大小) - docLength := float64(len(textTerms)) - - // 计算词频映射 - textTermFreq := make(map[string]int, len(textTerms)) - for _, term := range textTerms { - textTermFreq[term]++ + opts := r.einoRetrieverOptions(req) + docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...) + if err != nil { + return nil, err } - - score := 0.0 - matchedQueryTerms := 0 - - for _, term := range queryTerms { - termFreq, exists := textTermFreq[term] - if !exists || termFreq == 0 { - continue - } - matchedQueryTerms++ - - // BM25 TF 计算公式 - tf := float64(termFreq) - lengthNorm := 1 - b + b*(docLength/avgDocLength) - tfScore := tf / (tf + k1*lengthNorm) - - // 改进的 IDF 计算:使用词长度和出现频率估算 - // 短词(2-3 字符)通常更重要,长词 IDF 略低 - idfWeight := 1.0 - termLen := len(term) - if termLen <= 2 { - // 极短词(如 go, js)给予更高权重 - idfWeight = 1.2 + math.Log(1.0+float64(termFreq)/20.0) - } else if termLen <= 4 { - // 短词(4 字符)标准权重 - idfWeight = 1.0 + math.Log(1.0+float64(termFreq)/15.0) - } else { - // 长词稍微降低权重 - idfWeight = 0.9 + math.Log(1.0+float64(termFreq)/10.0) - } - - score += tfScore * idfWeight - } - - // 归一化:考虑匹配的查询词比例 - if len(queryTerms) > 0 { - // 使用匹配比例作为额外因子 - matchRatio := float64(matchedQueryTerms) / float64(len(queryTerms)) - score = (score / float64(len(queryTerms))) * (1 + matchRatio) / 2 - } - - return math.Min(score, 1.0) + return documentsToRetrievalResults(docs) } -// Search 搜索知识库 -func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { +func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option { + var opts []retriever.Option + if req.TopK > 0 { + opts = append(opts, retriever.WithTopK(req.TopK)) + } + dsl := map[string]any{} + if strings.TrimSpace(req.RiskType) != "" { + dsl[DSLRiskType] = strings.TrimSpace(req.RiskType) + } + if req.Threshold > 0 { + dsl[DSLSimilarityThreshold] = req.Threshold + } + if strings.TrimSpace(req.SubIndexFilter) != "" { + dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter) + } + if len(dsl) > 0 { + opts = append(opts, retriever.WithDSLInfo(dsl)) + } + return opts +} + +// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。 +func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { + return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...) +} + +func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) { + q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title +FROM knowledge_embeddings e +JOIN knowledge_base_items i ON e.item_id = i.id +WHERE 1=1` + var args []interface{} + if strings.TrimSpace(riskType) != "" { + q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE` + args = append(args, riskType) + } + if tag := strings.TrimSpace(subIndexFilter); tag != "" { + tag = strings.ToLower(strings.ReplaceAll(tag, " ", "")) + q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)` + args = append(args, tag) + } + return q, args +} + +// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。 +func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { if req.Query == "" { return nil, fmt.Errorf("查询不能为空") } topK := req.TopK - if topK <= 0 { + if topK <= 0 && r.config != nil { topK = r.config.TopK } - if topK == 0 { + if topK <= 0 { topK = 5 } threshold := req.Threshold - if threshold <= 0 { + if threshold <= 0 && r.config != nil { threshold = r.config.SimilarityThreshold } - if threshold == 0 { + if threshold <= 0 { threshold = 0.7 } - // 向量化查询(如果提供了risk_type,也包含在查询文本中,以便更好地匹配) - queryText := req.Query - if req.RiskType != "" { - // 将risk_type信息包含到查询中,格式与索引时保持一致 - queryText = fmt.Sprintf("[风险类型: %s] %s", req.RiskType, req.Query) + subIdxFilter := strings.TrimSpace(req.SubIndexFilter) + if subIdxFilter == "" && r.config != nil { + subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter) } + + queryText := FormatQueryEmbeddingText(req.RiskType, req.Query) queryEmbedding, err := r.embedder.EmbedText(ctx, queryText) if err != nil { return nil, fmt.Errorf("向量化查询失败: %w", err) } - - // 查询所有向量(或按风险类型过滤) - // 使用精确匹配(=)以提高性能和准确性 - // 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称 - // 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配 - var rows *sql.Rows - if req.RiskType != "" { - // 使用精确匹配(=),性能更好且更准确 - // 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性 - // 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到 - // 建议用户先调用相应的内置工具获取准确的category名称 - rows, err = r.db.Query(` - SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title - FROM knowledge_embeddings e - JOIN knowledge_base_items i ON e.item_id = i.id - WHERE TRIM(i.category) = TRIM(?) COLLATE NOCASE - `, req.RiskType) - } else { - rows, err = r.db.Query(` - SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title - FROM knowledge_embeddings e - JOIN knowledge_base_items i ON e.item_id = i.id - `) + queryDim := len(queryEmbedding) + expectedModel := "" + if r.embedder != nil { + expectedModel = r.embedder.EmbeddingModelName() } + + sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter) + rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...) if err != nil { return nil, fmt.Errorf("查询向量失败: %w", err) } defer rows.Close() - // 计算相似度 type candidate struct { - chunk *KnowledgeChunk - item *KnowledgeItem - similarity float64 - bm25Score float64 - hasStrongKeywordMatch bool - hybridScore float64 // 混合分数,用于最终排序 + chunk *KnowledgeChunk + item *KnowledgeItem + similarity float64 } candidates := make([]candidate, 0) - + rowNum := 0 for rows.Next() { - var chunkID, itemID, chunkText, embeddingJSON, category, title string - var chunkIndex int + rowNum++ + if rowNum%48 == 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } - if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &category, &title); err != nil { + var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string + var chunkIndex, rowDim int + + if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil { r.logger.Warn("扫描向量失败", zap.Error(err)) continue } - // 解析向量 var embedding []float32 if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil { r.logger.Warn("解析向量失败", zap.Error(err)) continue } - // 计算余弦相似度 - similarity := cosineSimilarity(queryEmbedding, embedding) - - // 计算BM25分数(考虑chunk文本、category和title) - // category和title是结构化字段,完全匹配时应该被优先考虑 - chunkBM25 := r.bm25Score(req.Query, chunkText) - categoryBM25 := r.bm25Score(req.Query, category) - titleBM25 := r.bm25Score(req.Query, title) - - // 检查category或title是否有显著匹配(这对于结构化字段很重要) - hasStrongKeywordMatch := categoryBM25 > 0.3 || titleBM25 > 0.3 - - // 综合BM25分数(用于后续排序) - bm25Score := math.Max(math.Max(chunkBM25, categoryBM25), titleBM25) - - // 收集所有候选(先不严格过滤,以便后续智能处理跨语言情况) - // 只过滤掉相似度极低的结果(< 0.1),避免噪音 - if similarity < 0.1 { + if rowDim > 0 && len(embedding) != rowDim { + r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding))) + continue + } + if queryDim > 0 && len(embedding) != queryDim { + r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding))) + continue + } + if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel { + r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel)) continue } - chunk := &KnowledgeChunk{ - ID: chunkID, - ItemID: itemID, - ChunkIndex: chunkIndex, - ChunkText: chunkText, - Embedding: embedding, - } - - item := &KnowledgeItem{ - ID: itemID, - Category: category, - Title: title, - } - + similarity := cosineSimilarity(queryEmbedding, embedding) candidates = append(candidates, candidate{ - chunk: chunk, - item: item, - similarity: similarity, - bm25Score: bm25Score, - hasStrongKeywordMatch: hasStrongKeywordMatch, + chunk: &KnowledgeChunk{ + ID: chunkID, + ItemID: itemID, + ChunkIndex: chunkIndex, + ChunkText: chunkText, + Embedding: embedding, + }, + item: &KnowledgeItem{ + ID: itemID, + Category: category, + Title: title, + }, + similarity: similarity, }) } - // 先按相似度排序(使用更高效的排序) sort.Slice(candidates, func(i, j int) bool { return candidates[i].similarity > candidates[j].similarity }) - // 智能过滤策略:优先保留关键词匹配的结果,对跨语言查询使用更宽松的阈值 - filteredCandidates := make([]candidate, 0) - - // 检查是否有任何关键词匹配(用于判断是否是跨语言查询) - hasAnyKeywordMatch := false - for _, cand := range candidates { - if cand.hasStrongKeywordMatch { - hasAnyKeywordMatch = true - break + filtered := make([]candidate, 0, len(candidates)) + for _, c := range candidates { + if c.similarity >= threshold { + filtered = append(filtered, c) } } - // 检查最高相似度,用于判断是否确实有相关内容 - maxSimilarity := 0.0 - if len(candidates) > 0 { - maxSimilarity = candidates[0].similarity + if len(filtered) > topK { + filtered = filtered[:topK] } - // 应用智能过滤 - // 如果用户设置了高阈值(>=0.8),更严格地遵守阈值,减少自动放宽 - strictMode := threshold >= 0.8 - - // 根据是否有关键词匹配,采用不同的阈值策略 - // 严格模式下,禁用跨语言放宽策略,严格遵守用户设置的阈值 - effectiveThreshold := threshold - if !strictMode && !hasAnyKeywordMatch { - // 非严格模式下,没有关键词匹配,可能是跨语言查询,适度放宽阈值 - // 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性 - // 跨语言阈值设为0.6,确保返回的结果至少有一定相关性 - effectiveThreshold = math.Max(threshold*0.85, 0.6) - r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值", - zap.Float64("originalThreshold", threshold), - zap.Float64("effectiveThreshold", effectiveThreshold), - ) - } else if strictMode { - // 严格模式下,即使没有关键词匹配,也严格遵守阈值 - r.logger.Debug("严格模式:严格遵守用户设置的阈值", - zap.Float64("threshold", threshold), - zap.Bool("hasKeywordMatch", hasAnyKeywordMatch), - ) - } - for _, cand := range candidates { - if cand.similarity >= effectiveThreshold { - // 达到阈值,直接通过 - filteredCandidates = append(filteredCandidates, cand) - } else if !strictMode && cand.hasStrongKeywordMatch { - // 非严格模式下,有关键词匹配但相似度略低于阈值,适当放宽 - // 严格模式下,即使有关键词匹配,也严格遵守阈值 - relaxedThreshold := math.Max(effectiveThreshold*0.85, 0.55) - if cand.similarity >= relaxedThreshold { - filteredCandidates = append(filteredCandidates, cand) - } - } - // 如果既没有关键词匹配,相似度又低于阈值,则过滤掉 - } - - // 智能兜底策略:只有在最高相似度达到合理水平时,才考虑返回结果 - // 如果最高相似度都很低(<0.55),说明确实没有相关内容,应该返回空 - // 严格模式下(阈值>=0.8),禁用兜底策略,严格遵守用户设置的阈值 - if len(filteredCandidates) == 0 && len(candidates) > 0 && !strictMode { - // 即使没有通过阈值过滤,如果最高相似度还可以(>=0.55),可以考虑返回Top-K - // 但这是最后的兜底,只在确实有一定相关性时才使用 - // 严格模式下不使用兜底策略 - minAcceptableSimilarity := 0.55 - if maxSimilarity >= minAcceptableSimilarity { - r.logger.Debug("过滤后无结果,但最高相似度可接受,返回Top-K结果", - zap.Int("totalCandidates", len(candidates)), - zap.Float64("maxSimilarity", maxSimilarity), - zap.Float64("effectiveThreshold", effectiveThreshold), - ) - maxResults := topK - if len(candidates) < maxResults { - maxResults = len(candidates) - } - // 只返回相似度 >= 0.55 的结果 - for _, cand := range candidates { - if cand.similarity >= minAcceptableSimilarity && len(filteredCandidates) < maxResults { - filteredCandidates = append(filteredCandidates, cand) - } - } - } else { - r.logger.Debug("过滤后无结果,且最高相似度过低,返回空结果", - zap.Int("totalCandidates", len(candidates)), - zap.Float64("maxSimilarity", maxSimilarity), - zap.Float64("minAcceptableSimilarity", minAcceptableSimilarity), - ) - } - } else if len(filteredCandidates) == 0 && strictMode { - // 严格模式下,如果过滤后无结果,直接返回空,不使用兜底策略 - r.logger.Debug("严格模式:过滤后无结果,严格遵守阈值,返回空结果", - zap.Float64("threshold", threshold), - zap.Float64("maxSimilarity", maxSimilarity), - ) - } - - // 统一在最终返回前严格限制 Top-K 数量 - if len(filteredCandidates) > topK { - // 如果过滤后结果太多,只取Top-K - filteredCandidates = filteredCandidates[:topK] - } - - candidates = filteredCandidates - - // 混合排序(向量相似度 + BM25) - // 注意:hybridWeight可以是0.0(纯关键词检索),所以不设置默认值 - // 如果配置文件中未设置,应该在配置加载时使用默认值 - hybridWeight := r.config.HybridWeight - // 如果未设置,使用默认值0.7(偏重向量检索) - if hybridWeight < 0 || hybridWeight > 1 { - r.logger.Warn("混合权重超出范围,使用默认值0.7", - zap.Float64("provided", hybridWeight)) - hybridWeight = 0.7 - } - - // 先计算混合分数并存储在candidate中,用于排序 - for i := range candidates { - normalizedBM25 := math.Min(candidates[i].bm25Score, 1.0) - candidates[i].hybridScore = hybridWeight*candidates[i].similarity + (1-hybridWeight)*normalizedBM25 - - // 调试日志:记录前几个候选的分数计算(仅在debug级别) - if i < 3 { - r.logger.Debug("混合分数计算", - zap.Int("index", i), - zap.Float64("similarity", candidates[i].similarity), - zap.Float64("bm25Score", candidates[i].bm25Score), - zap.Float64("normalizedBM25", normalizedBM25), - zap.Float64("hybridWeight", hybridWeight), - zap.Float64("hybridScore", candidates[i].hybridScore)) - } - } - - // 根据混合分数重新排序(这才是真正的混合检索) - sort.Slice(candidates, func(i, j int) bool { - return candidates[i].hybridScore > candidates[j].hybridScore - }) - - // 转换为结果 - results := make([]*RetrievalResult, len(candidates)) - for i, cand := range candidates { + results := make([]*RetrievalResult, len(filtered)) + for i, c := range filtered { results[i] = &RetrievalResult{ - Chunk: cand.chunk, - Item: cand.item, - Similarity: cand.similarity, - Score: cand.hybridScore, + Chunk: c.chunk, + Item: c.item, + Similarity: c.similarity, + Score: c.similarity, } } - - // 上下文扩展:为每个匹配的chunk添加同一文档中的相关chunk - // 这可以防止文本描述和payload被分开切分时,只返回描述而丢失payload的问题 - results = r.expandContext(ctx, results) - return results, nil } -// expandContext 扩展检索结果的上下文 -// 对于每个匹配的chunk,自动包含同一文档中的相关chunk(特别是包含代码块、payload的chunk) -func (r *Retriever) expandContext(ctx context.Context, results []*RetrievalResult) []*RetrievalResult { - if len(results) == 0 { - return results - } - - // 收集所有匹配到的文档ID - itemIDs := make(map[string]bool) - for _, result := range results { - itemIDs[result.Item.ID] = true - } - - // 为每个文档加载所有chunk - itemChunksMap := make(map[string][]*KnowledgeChunk) - for itemID := range itemIDs { - chunks, err := r.loadAllChunksForItem(itemID) - if err != nil { - r.logger.Warn("加载文档chunk失败", zap.String("itemId", itemID), zap.Error(err)) - continue - } - itemChunksMap[itemID] = chunks - } - - // 按文档分组结果,每个文档只扩展一次 - resultsByItem := make(map[string][]*RetrievalResult) - for _, result := range results { - itemID := result.Item.ID - resultsByItem[itemID] = append(resultsByItem[itemID], result) - } - - // 扩展每个文档的结果 - expandedResults := make([]*RetrievalResult, 0, len(results)) - processedChunkIDs := make(map[string]bool) // 避免重复添加 - - for itemID, itemResults := range resultsByItem { - // 获取该文档的所有chunk - allChunks, exists := itemChunksMap[itemID] - if !exists { - // 如果无法加载chunk,直接添加原始结果 - for _, result := range itemResults { - if !processedChunkIDs[result.Chunk.ID] { - expandedResults = append(expandedResults, result) - processedChunkIDs[result.Chunk.ID] = true - } - } - continue - } - - // 添加原始结果 - for _, result := range itemResults { - if !processedChunkIDs[result.Chunk.ID] { - expandedResults = append(expandedResults, result) - processedChunkIDs[result.Chunk.ID] = true - } - } - - // 为该文档的匹配chunk收集需要扩展的相邻chunk - // 策略:只对混合分数最高的前3个匹配chunk进行扩展,避免扩展过多 - // 先按混合分数排序,只扩展前3个(使用混合分数而不是相似度) - sortedItemResults := make([]*RetrievalResult, len(itemResults)) - copy(sortedItemResults, itemResults) - sort.Slice(sortedItemResults, func(i, j int) bool { - return sortedItemResults[i].Score > sortedItemResults[j].Score - }) - - // 只扩展前3个(或所有,如果少于3个) - maxExpandFrom := 3 - if len(sortedItemResults) < maxExpandFrom { - maxExpandFrom = len(sortedItemResults) - } - - // 使用map去重,避免同一个chunk被多次添加 - relatedChunksMap := make(map[string]*KnowledgeChunk) - - for i := 0; i < maxExpandFrom; i++ { - result := sortedItemResults[i] - // 查找相关chunk(上下各2个,排除已处理的chunk) - relatedChunks := r.findRelatedChunks(result.Chunk, allChunks, processedChunkIDs) - for _, relatedChunk := range relatedChunks { - // 使用chunk ID作为key去重 - if !processedChunkIDs[relatedChunk.ID] { - relatedChunksMap[relatedChunk.ID] = relatedChunk - } - } - } - - // 限制每个文档最多扩展的chunk数量(避免扩展过多) - // 策略:最多扩展8个chunk,无论匹配了多少个chunk - // 这样可以避免当多个匹配chunk分散在文档不同位置时,扩展出过多chunk - maxExpandPerItem := 8 - - // 将相关chunk转换为切片并按索引排序,优先选择距离匹配chunk最近的 - relatedChunksList := make([]*KnowledgeChunk, 0, len(relatedChunksMap)) - for _, chunk := range relatedChunksMap { - relatedChunksList = append(relatedChunksList, chunk) - } - - // 计算每个相关chunk到最近匹配chunk的距离,按距离排序 - sort.Slice(relatedChunksList, func(i, j int) bool { - // 计算到最近匹配chunk的距离 - minDistI := len(allChunks) - minDistJ := len(allChunks) - for _, result := range itemResults { - distI := abs(relatedChunksList[i].ChunkIndex - result.Chunk.ChunkIndex) - distJ := abs(relatedChunksList[j].ChunkIndex - result.Chunk.ChunkIndex) - if distI < minDistI { - minDistI = distI - } - if distJ < minDistJ { - minDistJ = distJ - } - } - return minDistI < minDistJ - }) - - // 限制数量 - if len(relatedChunksList) > maxExpandPerItem { - relatedChunksList = relatedChunksList[:maxExpandPerItem] - } - - // 添加去重后的相关chunk - // 使用该文档中混合分数最高的结果作为参考 - maxScore := 0.0 - maxSimilarity := 0.0 - for _, result := range itemResults { - if result.Score > maxScore { - maxScore = result.Score - } - if result.Similarity > maxSimilarity { - maxSimilarity = result.Similarity - } - } - - // 计算扩展chunk的混合分数(使用相同的混合权重) - hybridWeight := r.config.HybridWeight - expandedSimilarity := maxSimilarity * 0.8 // 相关chunk的相似度略低 - // 对于扩展的chunk,BM25分数设为0(因为它们是上下文扩展,不是直接匹配) - expandedBM25 := 0.0 - expandedScore := hybridWeight*expandedSimilarity + (1-hybridWeight)*expandedBM25 - - for _, relatedChunk := range relatedChunksList { - expandedResult := &RetrievalResult{ - Chunk: relatedChunk, - Item: itemResults[0].Item, // 使用第一个结果的Item信息 - Similarity: expandedSimilarity, - Score: expandedScore, // 使用正确的混合分数 - } - expandedResults = append(expandedResults, expandedResult) - processedChunkIDs[relatedChunk.ID] = true - } - } - - return expandedResults -} - -// loadAllChunksForItem 加载文档的所有chunk -func (r *Retriever) loadAllChunksForItem(itemID string) ([]*KnowledgeChunk, error) { - rows, err := r.db.Query(` - SELECT id, item_id, chunk_index, chunk_text, embedding - FROM knowledge_embeddings - WHERE item_id = ? - ORDER BY chunk_index - `, itemID) - if err != nil { - return nil, fmt.Errorf("查询chunk失败: %w", err) - } - defer rows.Close() - - var chunks []*KnowledgeChunk - for rows.Next() { - var chunkID, itemID, chunkText, embeddingJSON string - var chunkIndex int - - if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON); err != nil { - r.logger.Warn("扫描chunk失败", zap.Error(err)) - continue - } - - // 解析向量(可选,这里不需要) - var embedding []float32 - if embeddingJSON != "" { - json.Unmarshal([]byte(embeddingJSON), &embedding) - } - - chunk := &KnowledgeChunk{ - ID: chunkID, - ItemID: itemID, - ChunkIndex: chunkIndex, - ChunkText: chunkText, - Embedding: embedding, - } - chunks = append(chunks, chunk) - } - - return chunks, nil -} - -// findRelatedChunks 查找与给定chunk相关的其他chunk -// 策略:只返回上下各2个相邻的chunk(共最多4个) -// 排除已处理的chunk,避免重复添加 -func (r *Retriever) findRelatedChunks(targetChunk *KnowledgeChunk, allChunks []*KnowledgeChunk, processedChunkIDs map[string]bool) []*KnowledgeChunk { - related := make([]*KnowledgeChunk, 0) - - // 查找上下各2个相邻chunk - for _, chunk := range allChunks { - if chunk.ID == targetChunk.ID { - continue - } - - // 检查是否已经被处理过(可能已经在检索结果中) - if processedChunkIDs[chunk.ID] { - continue - } - - // 检查是否是相邻chunk(索引相差不超过2,且不为0) - indexDiff := chunk.ChunkIndex - targetChunk.ChunkIndex - if indexDiff >= -2 && indexDiff <= 2 && indexDiff != 0 { - related = append(related, chunk) - } - } - - // 按索引距离排序,优先选择最近的 - sort.Slice(related, func(i, j int) bool { - diffI := abs(related[i].ChunkIndex - targetChunk.ChunkIndex) - diffJ := abs(related[j].ChunkIndex - targetChunk.ChunkIndex) - return diffI < diffJ - }) - - // 限制最多返回4个(上下各2个) - if len(related) > 4 { - related = related[:4] - } - - return related -} - -// abs 返回整数的绝对值 -func abs(x int) int { - if x < 0 { - return -x - } - return x +// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。 +func (r *Retriever) AsEinoRetriever() retriever.Retriever { + return NewVectorEinoRetriever(r) } diff --git a/internal/knowledge/schema_migrate.go b/internal/knowledge/schema_migrate.go new file mode 100644 index 00000000..85fd26e2 --- /dev/null +++ b/internal/knowledge/schema_migrate.go @@ -0,0 +1,51 @@ +package knowledge + +import ( + "database/sql" + "fmt" +) + +// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata. +func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error { + if db == nil { + return fmt.Errorf("db is nil") + } + var n int + if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil { + return err + } + if n == 0 { + return nil + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes", + `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil { + return err + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model", + `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil { + return err + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim", + `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil { + return err + } + return nil +} + +func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error { + var colCount int + q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?` + if err := db.QueryRow(q, column).Scan(&colCount); err != nil { + return err + } + if colCount > 0 { + return nil + } + _, err := db.Exec(alterSQL) + return err +} + +// ensureKnowledgeEmbeddingsSubIndexesColumn 向后兼容;请使用 [EnsureKnowledgeEmbeddingsSchema]。 +func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error { + return EnsureKnowledgeEmbeddingsSchema(db) +} diff --git a/internal/knowledge/tool.go b/internal/knowledge/tool.go index 31e52554..c7aa3f68 100644 --- a/internal/knowledge/tool.go +++ b/internal/knowledge/tool.go @@ -81,8 +81,8 @@ func RegisterKnowledgeTool( // 注册第二个工具:搜索知识库(保持原有功能) searchTool := mcp.Tool{ Name: builtin.ToolSearchKnowledgeBase, - Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。", - ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)", + Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。", + ShortDescription: "搜索知识库中的安全知识(向量语义检索)", InputSchema: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ @@ -123,7 +123,7 @@ func RegisterKnowledgeTool( zap.String("riskType", riskType), ) - // 执行检索 + // 检索统一走 Retriever.Search → VectorEinoRetriever(Eino retriever 语义)。 searchReq := &SearchRequest{ Query: query, RiskType: riskType, @@ -158,17 +158,16 @@ func RegisterKnowledgeTool( // 格式化结果 var resultText strings.Builder - // 先按混合分数排序,确保文档顺序是按混合分数的(混合检索的核心) + // 按余弦相似度(Score)降序 sort.Slice(results, func(i, j int) bool { return results[i].Score > results[j].Score }) // 按文档分组结果,以便更好地展示上下文 - // 使用有序的slice来保持文档顺序(按最高混合分数) type itemGroup struct { itemID string results []*RetrievalResult - maxScore float64 // 该文档的最高混合分数 + maxScore float64 // 该文档块的最高相似度 } itemGroups := make([]*itemGroup, 0) itemMap := make(map[string]*itemGroup) @@ -191,7 +190,7 @@ func RegisterKnowledgeTool( } } - // 按最高混合分数排序文档组 + // 按文档内最高相似度排序 sort.Slice(itemGroups, func(i, j int) bool { return itemGroups[i].maxScore > itemGroups[j].maxScore }) @@ -199,12 +198,11 @@ func RegisterKnowledgeTool( // 收集检索到的知识项ID(用于日志) retrievedItemIDs := make([]string, 0, len(itemGroups)) - resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results))) + resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results))) resultIndex := 1 for _, group := range itemGroups { itemResults := group.results - // 找到混合分数最高的作为主结果(使用混合分数,而不是相似度) mainResult := itemResults[0] maxScore := mainResult.Score for _, result := range itemResults { @@ -219,9 +217,8 @@ func RegisterKnowledgeTool( return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex }) - // 显示主结果(混合分数最高的,同时显示相似度和混合分数) - resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n", - resultIndex, mainResult.Similarity*100, mainResult.Score*100)) + resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", + resultIndex, mainResult.Similarity*100)) resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID)) // 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk) diff --git a/internal/knowledge/types.go b/internal/knowledge/types.go index bccd3a93..80d0eb5f 100644 --- a/internal/knowledge/types.go +++ b/internal/knowledge/types.go @@ -80,7 +80,7 @@ type RetrievalResult struct { Chunk *KnowledgeChunk `json:"chunk"` Item *KnowledgeItem `json:"item"` Similarity float64 `json:"similarity"` // 相似度分数 - Score float64 `json:"score"` // 综合分数(混合检索) + Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度 } // RetrievalLog 检索日志 @@ -115,8 +115,9 @@ type CategoryWithItems struct { // SearchRequest 搜索请求 type SearchRequest struct { - Query string `json:"query"` - RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型 - TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5 - Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7 + Query string `json:"query"` + RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型 + SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据) + TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5 + Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7 }