mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-29 22:38:03 +02:00
Add files via upload
+20 -15
@@ -19,7 +19,6 @@ import (
 	"cyberstrike-ai/internal/logger"
 	"cyberstrike-ai/internal/mcp"
 	"cyberstrike-ai/internal/mcp/builtin"
-	"cyberstrike-ai/internal/openai"
 	"cyberstrike-ai/internal/robot"
 	"cyberstrike-ai/internal/security"
 	"cyberstrike-ai/internal/skills"
@@ -185,22 +184,25 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
 			cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
 		}
 
-		httpClient := &http.Client{
-			Timeout: 30 * time.Minute,
-		}
-		openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, log.Logger)
-		embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, log.Logger)
+		embedder, err := knowledge.NewEmbedder(context.Background(), &cfg.Knowledge, &cfg.OpenAI, log.Logger)
+		if err != nil {
+			return nil, fmt.Errorf("初始化知识库嵌入器失败: %w", err)
+		}
 
 		// Create the retriever
 		retrievalConfig := &knowledge.RetrievalConfig{
 			TopK:                cfg.Knowledge.Retrieval.TopK,
 			SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
-			HybridWeight:        cfg.Knowledge.Retrieval.HybridWeight,
+			SubIndexFilter:      cfg.Knowledge.Retrieval.SubIndexFilter,
+			PostRetrieve:        cfg.Knowledge.Retrieval.PostRetrieve,
 		}
 		knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger)
 
-		// Create the indexer
-		knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger, &cfg.Knowledge.Indexing)
+		// Create the indexer (Eino Compose chain)
+		knowledgeIndexer, err = knowledge.NewIndexer(context.Background(), knowledgeDB, embedder, log.Logger, &cfg.Knowledge)
+		if err != nil {
+			return nil, fmt.Errorf("初始化知识库索引器失败: %w", err)
+		}
 
 		// Register the knowledge retrieval tool with the MCP server
 		knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
@@ -1697,22 +1699,25 @@ func initializeKnowledge(
 	cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
 	}
 
-	httpClient := &http.Client{
-		Timeout: 30 * time.Minute,
-	}
-	openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, logger)
-	embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, logger)
+	embedder, err := knowledge.NewEmbedder(context.Background(), &cfg.Knowledge, &cfg.OpenAI, logger)
+	if err != nil {
+		return nil, fmt.Errorf("初始化知识库嵌入器失败: %w", err)
+	}
 
 	// Create the retriever
 	retrievalConfig := &knowledge.RetrievalConfig{
 		TopK:                cfg.Knowledge.Retrieval.TopK,
 		SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
-		HybridWeight:        cfg.Knowledge.Retrieval.HybridWeight,
+		SubIndexFilter:      cfg.Knowledge.Retrieval.SubIndexFilter,
+		PostRetrieve:        cfg.Knowledge.Retrieval.PostRetrieve,
 	}
 	knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
 
-	// Create the indexer
-	knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger, &cfg.Knowledge.Indexing)
+	// Create the indexer (Eino Compose chain)
+	knowledgeIndexer, err := knowledge.NewIndexer(context.Background(), knowledgeDB, embedder, logger, &cfg.Knowledge)
+	if err != nil {
+		return nil, fmt.Errorf("初始化知识库索引器失败: %w", err)
+	}
 
 	// Register the knowledge retrieval tool with the MCP server
 	knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)
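
Both hunks converge on the same wiring order. A condensed sketch of the new call sequence, using only constructors shown in this commit — names such as knowledgeDB, mcpServer, and knowledgeManager come from the surrounding context, and error handling is trimmed for brevity:

	// Sketch only: the old hand-built *http.Client and openai.NewClient are gone;
	// NewEmbedder and NewIndexer now return errors that must be checked.
	embedder, err := knowledge.NewEmbedder(context.Background(), &cfg.Knowledge, &cfg.OpenAI, logger)
	// ... handle err ...
	retrieverCfg := &knowledge.RetrievalConfig{
		TopK:                cfg.Knowledge.Retrieval.TopK,
		SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
		SubIndexFilter:      cfg.Knowledge.Retrieval.SubIndexFilter,
		PostRetrieve:        cfg.Knowledge.Retrieval.PostRetrieve,
	}
	knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrieverCfg, logger)
	knowledgeIndexer, err := knowledge.NewIndexer(context.Background(), knowledgeDB, embedder, logger, &cfg.Knowledge)
	// ... handle err ...
	knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)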
+38 -12
@@ -754,16 +754,20 @@ func Default() *Config {
 		Retrieval: RetrievalConfig{
 			TopK:                5,
 			SimilarityThreshold: 0.65, // lowered to 0.65 to reduce missed hits
-			HybridWeight:        0.7,
 		},
 		Indexing: IndexingConfig{
-			ChunkSize:        768, // raised to 768 for better context retention
-			ChunkOverlap:     50,
-			MaxChunksPerItem: 20,  // cap a single knowledge item at 20 chunks to avoid burning quota
-			MaxRPM:           100, // default 100 RPM to avoid 429 errors
-			RateLimitDelayMs: 600, // 600ms spacing, matching 100 RPM
-			MaxRetries:       3,
-			RetryDelayMs:     1000,
+			ChunkStrategy:         "markdown_then_recursive",
+			RequestTimeoutSeconds: 120,
+			ChunkSize:             768, // raised to 768 for better context retention
+			ChunkOverlap:          50,
+			MaxChunksPerItem:      20,  // cap a single knowledge item at 20 chunks to avoid burning quota
+			BatchSize:             64,
+			PreferSourceFile:      false,
+			MaxRPM:                100, // default 100 RPM to avoid 429 errors
+			RateLimitDelayMs:      600, // 600ms spacing, matching 100 RPM
+			MaxRetries:            3,
+			RetryDelayMs:          1000,
+			SubIndexes:            nil,
 		},
 	},
 }
@@ -780,11 +784,18 @@ type KnowledgeConfig struct {
 
 // IndexingConfig controls how knowledge-base indexes are built
 type IndexingConfig struct {
+	// ChunkStrategy: "markdown_then_recursive" (default: Eino Markdown header split, then recursive split) or "recursive" (recursive split only)
+	ChunkStrategy string `yaml:"chunk_strategy,omitempty" json:"chunk_strategy,omitempty"`
+	// RequestTimeoutSeconds is the embedding HTTP client timeout in seconds; 0 means the default of 120
+	RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"`
 	// Chunking
 	ChunkSize        int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"`                   // max (estimated) tokens per chunk, default 512
 	ChunkOverlap     int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"`             // overlapping tokens between chunks, default 50
 	MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // max chunks per knowledge item, 0 = unlimited
 
+	// PreferSourceFile, when true, prefers reading the original file from file_path via the Eino FileLoader before indexing (disk wins when it disagrees with the stored content)
+	PreferSourceFile bool `yaml:"prefer_source_file,omitempty" json:"prefer_source_file,omitempty"`
+
 	// Rate limiting (to avoid API rate limits)
 	RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // delay between requests in ms, 0 = no fixed delay
 	MaxRPM           int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"`                         // max requests per minute, 0 = unlimited
@@ -793,8 +804,10 @@ type IndexingConfig struct {
 	MaxRetries   int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"`       // max retry attempts, default 3
 	RetryDelayMs int `yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // retry interval in ms, default 1000
 
-	// Batching (for batched embedding; currently unused, kept for future use)
-	BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"` // batch size, 0 = process one by one
+	// BatchSize is the embedding batch size (SQLite index writes); 0 means the default of 64
+	BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"`
+	// SubIndexes is passed to Eino indexer.WithSubIndexes (logical partition tags carried in Document metadata)
+	SubIndexes []string `yaml:"sub_indexes,omitempty" json:"sub_indexes,omitempty"`
 }
 
 // EmbeddingConfig configures embedding
@@ -805,11 +818,24 @@ type EmbeddingConfig struct {
 	APIKey string `yaml:"api_key" json:"api_key"` // API key (inherited from the OpenAI config)
 }
 
+// PostRetrieveConfig is post-retrieval processing: it always normalizes and dedupes chunk text (best practice) and truncates to a context budget; PrefetchTopK over-fetches candidates before converging to top_k.
+type PostRetrieveConfig struct {
+	// PrefetchTopK is the max number of candidates kept in the vector-search stage (cosine order); should be >= top_k, 0 means same as top_k; the upper bound is a constant in the knowledge package.
+	PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"`
+	// MaxContextChars caps the total Unicode characters of returned document content (whole chunks only, never half a chunk); 0 = unlimited.
+	MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"`
+	// MaxContextTokens caps the total tokens of returned document content (tiktoken, mapped from the embedding model name, falling back to cl100k_base); 0 = unlimited.
+	MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"`
+}
+
 // RetrievalConfig configures retrieval
 type RetrievalConfig struct {
 	TopK int `yaml:"top_k" json:"top_k"` // retrieval top-K
-	SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // similarity threshold
-	HybridWeight        float64 `yaml:"hybrid_weight" json:"hybrid_weight"`               // vector-search weight (0-1)
+	SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // cosine similarity threshold
+	// SubIndexFilter, when non-empty, keeps only rows whose sub_indexes contains the tag (one of the comma-separated values); legacy rows with empty sub_indexes are still returned.
+	SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"`
+	// PostRetrieve is post-retrieval processing (dedupe, budget truncation); reranking is injected in code via [knowledge.DocumentReranker].
+	PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"`
 }
 
 // RolesConfig (deprecated; use map[string]RoleConfig instead)
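
A sketch of a fully populated retrieval config using the new knobs (the numbers are illustrative, not recommended defaults):

	cfg := config.RetrievalConfig{
		TopK:                5,
		SimilarityThreshold: 0.65,
		SubIndexFilter:      "web",
		PostRetrieve: config.PostRetrieveConfig{
			PrefetchTopK:     20, // over-fetch, then dedupe/budget down to TopK
			MaxContextChars:  8000,
			MaxContextTokens: 0, // 0 = unlimited
		},
	}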
@@ -718,6 +718,9 @@ func (db *DB) initKnowledgeTables() error {
 		chunk_index INTEGER NOT NULL,
 		chunk_text TEXT NOT NULL,
 		embedding TEXT NOT NULL,
+		sub_indexes TEXT NOT NULL DEFAULT '',
+		embedding_model TEXT NOT NULL DEFAULT '',
+		embedding_dim INTEGER NOT NULL DEFAULT 0,
 		created_at DATETIME NOT NULL,
 		FOREIGN KEY (item_id) REFERENCES knowledge_base_items(id) ON DELETE CASCADE
 	);`
@@ -759,10 +762,47 @@ func (db *DB) initKnowledgeTables() error {
 		return fmt.Errorf("创建索引失败: %w", err)
 	}
 
+	if err := db.migrateKnowledgeEmbeddingsColumns(); err != nil {
+		return fmt.Errorf("迁移 knowledge_embeddings 列失败: %w", err)
+	}
+
 	db.logger.Info("知识库数据库表初始化完成")
 	return nil
 }
 
+// migrateKnowledgeEmbeddingsColumns backfills sub_indexes, embedding_model, and embedding_dim for existing databases.
+func (db *DB) migrateKnowledgeEmbeddingsColumns() error {
+	var n int
+	if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
+		return err
+	}
+	if n == 0 {
+		return nil
+	}
+	migrations := []struct {
+		col  string
+		stmt string
+	}{
+		{"sub_indexes", `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`},
+		{"embedding_model", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`},
+		{"embedding_dim", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`},
+	}
+	for _, m := range migrations {
+		var colCount int
+		q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
+		if err := db.QueryRow(q, m.col).Scan(&colCount); err != nil {
+			return err
+		}
+		if colCount > 0 {
+			continue
+		}
+		if _, err := db.Exec(m.stmt); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // Close closes the database connection
 func (db *DB) Close() error {
 	return db.DB.Close()
@@ -642,7 +642,6 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
 		zap.String("embedding_model", h.config.Knowledge.Embedding.Model),
 		zap.Int("retrieval_top_k", h.config.Knowledge.Retrieval.TopK),
 		zap.Float64("similarity_threshold", h.config.Knowledge.Retrieval.SimilarityThreshold),
-		zap.Float64("hybrid_weight", h.config.Knowledge.Retrieval.HybridWeight),
 	)
 }
@@ -1051,13 +1050,13 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
 		retrievalConfig := &knowledge.RetrievalConfig{
 			TopK:                h.config.Knowledge.Retrieval.TopK,
 			SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
-			HybridWeight:        h.config.Knowledge.Retrieval.HybridWeight,
+			SubIndexFilter:      h.config.Knowledge.Retrieval.SubIndexFilter,
+			PostRetrieve:        h.config.Knowledge.Retrieval.PostRetrieve,
 		}
 		h.retrieverUpdater.UpdateConfig(retrievalConfig)
 		h.logger.Info("检索器配置已更新",
 			zap.Int("top_k", retrievalConfig.TopK),
 			zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
-			zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
 		)
 	}
@@ -1289,13 +1288,22 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
 	retrievalNode := ensureMap(knowledgeNode, "retrieval")
 	setIntInMap(retrievalNode, "top_k", cfg.Retrieval.TopK)
 	setFloatInMap(retrievalNode, "similarity_threshold", cfg.Retrieval.SimilarityThreshold)
-	setFloatInMap(retrievalNode, "hybrid_weight", cfg.Retrieval.HybridWeight)
+	setStringInMap(retrievalNode, "sub_index_filter", cfg.Retrieval.SubIndexFilter)
+	postNode := ensureMap(retrievalNode, "post_retrieve")
+	setIntInMap(postNode, "prefetch_top_k", cfg.Retrieval.PostRetrieve.PrefetchTopK)
+	setIntInMap(postNode, "max_context_chars", cfg.Retrieval.PostRetrieve.MaxContextChars)
+	setIntInMap(postNode, "max_context_tokens", cfg.Retrieval.PostRetrieve.MaxContextTokens)
 
 	// Update the indexing config
 	indexingNode := ensureMap(knowledgeNode, "indexing")
+	setStringInMap(indexingNode, "chunk_strategy", cfg.Indexing.ChunkStrategy)
+	setIntInMap(indexingNode, "request_timeout_seconds", cfg.Indexing.RequestTimeoutSeconds)
 	setIntInMap(indexingNode, "chunk_size", cfg.Indexing.ChunkSize)
 	setIntInMap(indexingNode, "chunk_overlap", cfg.Indexing.ChunkOverlap)
 	setIntInMap(indexingNode, "max_chunks_per_item", cfg.Indexing.MaxChunksPerItem)
+	setBoolInMap(indexingNode, "prefer_source_file", cfg.Indexing.PreferSourceFile)
+	setIntInMap(indexingNode, "batch_size", cfg.Indexing.BatchSize)
+	setStringSliceInMap(indexingNode, "sub_indexes", cfg.Indexing.SubIndexes)
 	setIntInMap(indexingNode, "max_rpm", cfg.Indexing.MaxRPM)
 	setIntInMap(indexingNode, "rate_limit_delay_ms", cfg.Indexing.RateLimitDelayMs)
 	setIntInMap(indexingNode, "max_retries", cfg.Indexing.MaxRetries)
@@ -1397,6 +1405,21 @@ func setStringInMap(mapNode *yaml.Node, key, value string) {
 	valueNode.Value = value
 }
 
+func setStringSliceInMap(mapNode *yaml.Node, key string, values []string) {
+	_, valueNode := ensureKeyValue(mapNode, key)
+	valueNode.Kind = yaml.SequenceNode
+	valueNode.Tag = "!!seq"
+	valueNode.Style = 0
+	valueNode.Content = nil
+	for _, v := range values {
+		valueNode.Content = append(valueNode.Content, &yaml.Node{
+			Kind:  yaml.ScalarNode,
+			Tag:   "!!str",
+			Value: v,
+		})
+	}
+}
+
 func setIntInMap(mapNode *yaml.Node, key string, value int) {
 	_, valueNode := ensureKeyValue(mapNode, key)
 	valueNode.Kind = yaml.ScalarNode

@@ -1450,7 +1473,7 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
 	valueNode.Kind = yaml.ScalarNode
 	valueNode.Tag = "!!float"
 	valueNode.Style = 0
-	// For values between 0.0 and 1.0 (e.g. hybrid_weight), use %.1f so 0.0 serializes explicitly as "0.0"
+	// For values between 0.0 and 1.0 (e.g. similarity_threshold), use %.1f so 0.0 serializes explicitly as "0.0"
 	// For other values, use %g to pick the most suitable format
 	if value >= 0.0 && value <= 1.0 {
 		valueNode.Value = fmt.Sprintf("%.1f", value)
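
For reference, the sequence node built by setStringSliceInMap serializes the way you'd expect from yaml.v3. A self-contained sketch with illustrative tag values ("web", "api" are not from the commit):

	package main

	import (
		"fmt"

		"gopkg.in/yaml.v3"
	)

	func main() {
		// Same node shape setStringSliceInMap writes into the config document.
		seq := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"}
		for _, v := range []string{"web", "api"} {
			seq.Content = append(seq.Content, &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v})
		}
		doc := &yaml.Node{Kind: yaml.MappingNode, Content: []*yaml.Node{
			{Kind: yaml.ScalarNode, Tag: "!!str", Value: "sub_indexes"},
			seq,
		}}
		out, _ := yaml.Marshal(doc)
		fmt.Print(string(out))
		// sub_indexes:
		//     - web
		//     - api
	}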
@@ -482,6 +482,7 @@ func (h *KnowledgeHandler) Search(c *gin.Context) {
 		return
 	}
 
+	// Retriever.Search goes through the Eino VectorEinoRetriever, the same path as the MCP toolchain.
 	results, err := h.retriever.Search(c.Request.Context(), &req)
 	if err != nil {
 		h.logger.Error("搜索知识库失败", zap.Error(err))
@@ -4181,7 +4181,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
 			"post": map[string]interface{}{
 				"tags":    []string{"知识库"},
 				"summary": "搜索知识库",
-				"description": "在知识库中搜索相关内容。使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。\n**搜索说明**:\n- 支持语义相似度搜索(向量检索)\n- 支持关键词匹配(BM25)\n- 支持混合搜索(结合向量和关键词)\n- 可以按风险类型过滤(如:SQL注入、XSS、文件上传等)\n- 建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表\n**使用示例**:\n```json\n{\n  \"query\": \"SQL注入漏洞的检测方法\",\n  \"riskType\": \"SQL注入\",\n  \"topK\": 5,\n  \"threshold\": 0.7\n}\n```",
+				"description": "在知识库中搜索相关内容。基于向量检索,按查询与知识片段的语义相似度(余弦)返回最相关结果。\n**搜索说明**:\n- 语义相似度搜索:嵌入向量 + 余弦相似度,可配置相似度阈值与 TopK\n- 可按风险类型等元数据过滤(如:SQL注入、XSS、文件上传等)\n- 建议先调用 `/api/knowledge/categories` 获取可用的风险类型列表\n**使用示例**:\n```json\n{\n  \"query\": \"SQL注入漏洞的检测方法\",\n  \"riskType\": \"SQL注入\",\n  \"topK\": 5,\n  \"threshold\": 0.7\n}\n```",
 				"operationId": "searchKnowledge",
 				"requestBody": map[string]interface{}{
 					"required": true,
@@ -0,0 +1,67 @@
package knowledge

import (
	"context"
	"fmt"
	"strings"

	"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
	"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
	"github.com/cloudwego/eino/components/document"
	"github.com/pkoukk/tiktoken-go"
)

func tokenizerLenFunc(embeddingModel string) func(string) int {
	fallback := func(s string) int {
		r := []rune(s)
		if len(r) == 0 {
			return 0
		}
		return (len(r) + 3) / 4
	}
	m := strings.TrimSpace(embeddingModel)
	if m == "" {
		return fallback
	}
	tok, err := tiktoken.EncodingForModel(m)
	if err != nil {
		return fallback
	}
	return func(s string) int {
		return len(tok.Encode(s, nil, nil))
	}
}

// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for
// embeddingModel when available, else a rune/4 approximation.
func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) {
	if chunkSize <= 0 {
		return nil, fmt.Errorf("chunk size must be positive")
	}
	if overlap < 0 {
		overlap = 0
	}
	return recursive.NewSplitter(context.Background(), &recursive.Config{
		ChunkSize:   chunkSize,
		OverlapSize: overlap,
		LenFunc:     tokenizerLenFunc(embeddingModel),
		Separators: []string{
			"\n\n", "\n## ", "\n### ", "\n#### ", "\n",
			"。", "!", "?", ". ", "? ", "! ",
			" ",
		},
	})
}

// newMarkdownHeaderSplitter is the eino-ext Markdown header splitter (# through ####), suited to technical/Markdown knowledge bases.
func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) {
	return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
		Headers: map[string]string{
			"#":    "h1",
			"##":   "h2",
			"###":  "h3",
			"####": "h4",
		},
		TrimHeaders: false,
	})
}
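
A small usage sketch for the splitter above, assuming Eino's document.Transformer exposes Transform(ctx, docs) in the usual way (the sample text and model name are illustrative):

	sp, err := newKnowledgeSplitter(512, 50, "text-embedding-3-small")
	if err != nil {
		return err
	}
	chunks, err := sp.Transform(context.Background(), []*schema.Document{
		{ID: "demo", Content: "## Heading\n\nA long Markdown body to be chunked..."},
	})
	if err != nil {
		return err
	}
	fmt.Println("chunks:", len(chunks)) // each chunk stays near chunkSize per LenFunc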
@@ -0,0 +1,129 @@
package knowledge

import (
	"fmt"
	"strings"
)

// Document metadata keys for Eino schema.Document flowing through the RAG pipeline.
const (
	metaKBCategory   = "kb_category"
	metaKBTitle      = "kb_title"
	metaKBItemID     = "kb_item_id"
	metaKBChunkIndex = "kb_chunk_index"
	metaSimilarity   = "similarity"
)

// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const (
	DSLRiskType            = "risk_type"
	DSLSimilarityThreshold = "similarity_threshold"
	DSLSubIndexFilter      = "sub_index_filter"
)

// FormatEmbeddingInput matches the historical indexing format so existing embeddings
// stay comparable if users skip reindexing; new indexes use the same string shape.
func FormatEmbeddingInput(category, title, chunkText string) string {
	return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText)
}

// FormatQueryEmbeddingText builds the string embedded at query time so it matches
// [FormatEmbeddingInput] for the same risk category (title left empty for queries).
func FormatQueryEmbeddingText(riskType, query string) string {
	q := strings.TrimSpace(query)
	rt := strings.TrimSpace(riskType)
	if rt != "" {
		return FormatEmbeddingInput(rt, "", q)
	}
	return q
}

// MetaLookupString returns the metadata string value, or "" if absent.
func MetaLookupString(md map[string]any, key string) string {
	if md == nil {
		return ""
	}
	v, ok := md[key]
	if !ok || v == nil {
		return ""
	}
	switch t := v.(type) {
	case string:
		return t
	default:
		return strings.TrimSpace(fmt.Sprint(t))
	}
}

// MetaStringOK returns the trimmed non-empty string and true when present and non-empty.
func MetaStringOK(md map[string]any, key string) (string, bool) {
	s := strings.TrimSpace(MetaLookupString(md, key))
	if s == "" {
		return "", false
	}
	return s, true
}

// RequireMetaString requires a non-empty string metadata field.
func RequireMetaString(md map[string]any, key string) (string, error) {
	s, ok := MetaStringOK(md, key)
	if !ok {
		return "", fmt.Errorf("missing or empty metadata %q", key)
	}
	return s, nil
}

// RequireMetaInt requires an integer metadata field.
func RequireMetaInt(md map[string]any, key string) (int, error) {
	if md == nil {
		return 0, fmt.Errorf("missing metadata key %q", key)
	}
	v, ok := md[key]
	if !ok {
		return 0, fmt.Errorf("missing metadata key %q", key)
	}
	switch t := v.(type) {
	case int:
		return t, nil
	case int32:
		return int(t), nil
	case int64:
		return int(t), nil
	case float64:
		return int(t), nil
	default:
		return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v)
	}
}

// DSLNumeric coerces DSL map values (e.g. from JSON) to float64.
func DSLNumeric(v any) (float64, bool) {
	switch t := v.(type) {
	case float64:
		return t, true
	case float32:
		return float64(t), true
	case int:
		return float64(t), true
	case int64:
		return float64(t), true
	case uint32:
		return float64(t), true
	case uint64:
		return float64(t), true
	default:
		return 0, false
	}
}

// MetaFloat64OK reads a float metadata value.
func MetaFloat64OK(md map[string]any, key string) (float64, bool) {
	if md == nil {
		return 0, false
	}
	v, ok := md[key]
	if !ok {
		return 0, false
	}
	return DSLNumeric(v)
}

@@ -0,0 +1,14 @@
package knowledge

import "testing"

func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) {
	q := FormatQueryEmbeddingText("XSS", "payload")
	want := FormatEmbeddingInput("XSS", "", "payload")
	if q != want {
		t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want)
	}
	if FormatQueryEmbeddingText("", "hello") != "hello" {
		t.Fatalf("expected bare query without risk type")
	}
}
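
The metadata helpers above are the glue between indexing and retrieval. A short round-trip sketch with illustrative values:

	md := map[string]any{
		metaKBItemID:     "item-1",
		metaKBChunkIndex: float64(3), // JSON decoding often yields float64; RequireMetaInt coerces it
		metaSimilarity:   0.82,
	}
	id, _ := RequireMetaString(md, metaKBItemID)   // "item-1"
	idx, _ := RequireMetaInt(md, metaKBChunkIndex) // 3
	sim, _ := MetaFloat64OK(md, metaSimilarity)    // 0.82, true
	fmt.Println(id, idx, sim)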
@@ -0,0 +1,25 @@
package knowledge

import (
	"context"
	"fmt"

	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/schema"
)

// BuildKnowledgeRetrieveChain compiles a "query string → document list" Eino Chain backed by
// SQLite vector search ([VectorEinoRetriever]). Dedupe, context-budget truncation, and the final
// Top-K all happen inside [VectorEinoRetriever.Retrieve], matching the HTTP/MCP retrieval paths.
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
	if r == nil {
		return nil, fmt.Errorf("retriever is nil")
	}
	ch := compose.NewChain[string, []*schema.Document]()
	ch.AppendRetriever(r.AsEinoRetriever())
	return ch.Compile(ctx)
}

// CompileRetrieveChain is equivalent to [BuildKnowledgeRetrieveChain](ctx, r).
func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) {
	return BuildKnowledgeRetrieveChain(ctx, r)
}
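
A sketch of invoking the compiled chain; Invoke is the standard entry point of an Eino compose.Runnable, and the query string is illustrative:

	runner, err := BuildKnowledgeRetrieveChain(ctx, retriever)
	if err != nil {
		return err
	}
	docs, err := runner.Invoke(ctx, "SQL注入 检测方法")
	if err != nil {
		return err
	}
	for _, d := range docs {
		fmt.Println(d.ID, d.Score())
	}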
@@ -0,0 +1,23 @@
package knowledge

import (
	"context"
	"testing"

	"go.uber.org/zap"
)

func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) {
	r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop())
	_, err := BuildKnowledgeRetrieveChain(context.Background(), r)
	if err != nil {
		t.Fatal(err)
	}
}

func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) {
	_, err := BuildKnowledgeRetrieveChain(context.Background(), nil)
	if err == nil {
		t.Fatal("expected error for nil retriever")
	}
}
@@ -0,0 +1,202 @@
package knowledge

import (
	"context"
	"fmt"
	"strings"

	"cyberstrike-ai/internal/config"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"go.uber.org/zap"
)

// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
//
// Options:
//   - [retriever.WithTopK]
//   - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0-1), [DSLSubIndexFilter] (string)
//
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
//
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
type VectorEinoRetriever struct {
	inner *Retriever
}

// NewVectorEinoRetriever wraps r for Eino compose / tooling.
func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever {
	if r == nil {
		return nil
	}
	return &VectorEinoRetriever{inner: r}
}

// GetType identifies this retriever for Eino callbacks.
func (h *VectorEinoRetriever) GetType() string {
	return "SQLiteVectorKnowledgeRetriever"
}

// Retrieve runs vector search and returns [schema.Document] rows.
func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
	if h == nil || h.inner == nil {
		return nil, fmt.Errorf("VectorEinoRetriever: nil retriever")
	}
	q := strings.TrimSpace(query)
	if q == "" {
		return nil, fmt.Errorf("查询不能为空")
	}

	ro := retriever.GetCommonOptions(nil, opts...)
	cfg := h.inner.config

	req := &SearchRequest{Query: q}

	if ro.TopK != nil && *ro.TopK > 0 {
		req.TopK = *ro.TopK
	} else if cfg != nil && cfg.TopK > 0 {
		req.TopK = cfg.TopK
	} else {
		req.TopK = 5
	}

	req.Threshold = 0
	if ro.DSLInfo != nil {
		if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok {
			req.RiskType = strings.TrimSpace(rt)
		}
		if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok {
			if f, ok2 := DSLNumeric(v); ok2 && f > 0 {
				req.Threshold = f
			}
		}
		if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok {
			req.SubIndexFilter = strings.TrimSpace(sf)
		}
	}
	if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" {
		req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter)
	}
	if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 {
		req.Threshold = cfg.SimilarityThreshold
	}
	if req.Threshold <= 0 {
		req.Threshold = 0.7
	}

	finalTopK := req.TopK
	var postPO *config.PostRetrieveConfig
	if cfg != nil {
		postPO = &cfg.PostRetrieve
	}
	fetchK := EffectivePrefetchTopK(finalTopK, postPO)
	searchReq := *req
	searchReq.TopK = fetchK

	ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever)
	th := req.Threshold
	st := &th
	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query:          q,
		TopK:           finalTopK,
		ScoreThreshold: st,
		Extra:          ro.DSLInfo,
	})
	defer func() {
		if err != nil {
			_ = callbacks.OnError(ctx, err)
			return
		}
		_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
	}()

	results, err := h.inner.vectorSearch(ctx, &searchReq)
	if err != nil {
		return nil, err
	}
	out = retrievalResultsToDocuments(results)

	if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
		reranked, rerr := rr.Rerank(ctx, q, out)
		if rerr != nil {
			if h.inner.logger != nil {
				h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
			}
		} else if len(reranked) > 0 {
			out = reranked
		}
	}

	tokenModel := ""
	if h.inner.embedder != nil {
		tokenModel = h.inner.embedder.EmbeddingModelName()
	}
	out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
	if err != nil {
		return nil, err
	}
	return out, nil
}

func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document {
	out := make([]*schema.Document, 0, len(results))
	for _, res := range results {
		if res == nil || res.Chunk == nil || res.Item == nil {
			continue
		}
		d := &schema.Document{
			ID:      res.Chunk.ID,
			Content: res.Chunk.ChunkText,
			MetaData: map[string]any{
				metaKBItemID:     res.Item.ID,
				metaKBCategory:   res.Item.Category,
				metaKBTitle:      res.Item.Title,
				metaKBChunkIndex: res.Chunk.ChunkIndex,
				metaSimilarity:   res.Similarity,
			},
		}
		d.WithScore(res.Score)
		out = append(out, d)
	}
	return out
}

func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) {
	out := make([]*RetrievalResult, 0, len(docs))
	for i, d := range docs {
		if d == nil {
			continue
		}
		itemID, err := RequireMetaString(d.MetaData, metaKBItemID)
		if err != nil {
			return nil, fmt.Errorf("document %d: %w", i, err)
		}
		cat := MetaLookupString(d.MetaData, metaKBCategory)
		title := MetaLookupString(d.MetaData, metaKBTitle)
		chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex)
		if err != nil {
			return nil, fmt.Errorf("document %d: %w", i, err)
		}
		sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity)
		item := &KnowledgeItem{ID: itemID, Category: cat, Title: title}
		chunk := &KnowledgeChunk{
			ID:         d.ID,
			ItemID:     itemID,
			ChunkIndex: chunkIdx,
			ChunkText:  d.Content,
		}
		out = append(out, &RetrievalResult{
			Chunk:      chunk,
			Item:       item,
			Similarity: sim,
			Score:      d.Score(),
		})
	}
	return out, nil
}

var _ retriever.Retriever = (*VectorEinoRetriever)(nil)
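
A call-site sketch for the per-request options documented above; inner is a hypothetical *Retriever, and the risk type and tag values are illustrative:

	r := NewVectorEinoRetriever(inner)
	docs, err := r.Retrieve(ctx, "文件上传绕过",
		retriever.WithTopK(8),
		retriever.WithDSLInfo(map[string]any{
			DSLRiskType:            "文件上传",
			DSLSimilarityThreshold: 0.6,
			DSLSubIndexFilter:      "web",
		}),
	)

With no options at all, the branches above fall back to config values and finally to TopK 5 / threshold 0.7.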
@@ -0,0 +1,142 @@
package knowledge

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components"
	"github.com/cloudwego/eino/components/indexer"
	"github.com/cloudwego/eino/schema"
	"github.com/google/uuid"
)

// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + the existing schema.
type SQLiteIndexer struct {
	db             *sql.DB
	batchSize      int
	embeddingModel string
}

// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call.
// batchSize is the embedding batch size; if <= 0, the default of 64 is used.
// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty).
func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer {
	return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)}
}

// GetType implements Eino callback run info.
func (s *SQLiteIndexer) GetType() string {
	return "SQLiteKnowledgeIndexer"
}

// Store embeds documents and inserts rows. Each doc must carry MetaData:
// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only.
func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	options := indexer.GetCommonOptions(nil, opts...)
	if options.Embedding == nil {
		return nil, fmt.Errorf("sqlite indexer: embedding is required")
	}
	if len(docs) == 0 {
		return nil, nil
	}

	ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer)
	ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
	defer func() {
		if err != nil {
			_ = callbacks.OnError(ctx, err)
			return
		}
		_ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})
	}()

	subIdxStr := strings.Join(options.SubIndexes, ",")

	texts := make([]string, len(docs))
	for i, d := range docs {
		if d == nil {
			return nil, fmt.Errorf("sqlite indexer: nil document at %d", i)
		}
		cat := MetaLookupString(d.MetaData, metaKBCategory)
		title := MetaLookupString(d.MetaData, metaKBTitle)
		texts[i] = FormatEmbeddingInput(cat, title, d.Content)
	}

	bs := s.batchSize
	if bs <= 0 {
		bs = 64
	}

	var allVecs [][]float64
	for start := 0; start < len(texts); start += bs {
		end := start + bs
		if end > len(texts) {
			end = len(texts)
		}
		batch := texts[start:end]
		vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch)
		if embedErr != nil {
			return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr)
		}
		if len(vecs) != len(batch) {
			return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch))
		}
		allVecs = append(allVecs, vecs...)
	}

	embedDim := 0
	if len(allVecs) > 0 {
		embedDim = len(allVecs[0])
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err)
	}
	defer tx.Rollback()

	ids = make([]string, 0, len(docs))
	for i, d := range docs {
		chunkID := uuid.New().String()
		itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID)
		if metaErr != nil {
			return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
		}
		chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex)
		if metaErr != nil {
			return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
		}
		vec := allVecs[i]
		if embedDim > 0 && len(vec) != embedDim {
			return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim)
		}
		vec32 := make([]float32, len(vec))
		for j, v := range vec {
			vec32[j] = float32(v)
		}
		embeddingJSON, jsonErr := json.Marshal(vec32)
		if jsonErr != nil {
			return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr)
		}
		_, err = tx.ExecContext(ctx,
			`INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at)
			 VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`,
			chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim,
		)
		if err != nil {
			return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err)
		}
		ids = append(ids, chunkID)
	}

	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("sqlite indexer: commit: %w", err)
	}
	return ids, nil
}

var _ indexer.Indexer = (*SQLiteIndexer)(nil)
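
A Store-call sketch following the metadata contract in the doc comment. indexer.WithEmbedding and indexer.WithSubIndexes are the standard Eino common options this implementation reads back via GetCommonOptions; db, ctx, and the strings are illustrative, and emb is the embedder from the embedder file later in this commit:

	si := NewSQLiteIndexer(db, 64, "text-embedding-3-small")
	ids, err := si.Store(ctx, []*schema.Document{{
		Content: "chunk text",
		MetaData: map[string]any{
			metaKBItemID:     "item-1",
			metaKBCategory:   "SQL注入",
			metaKBTitle:      "检测方法",
			metaKBChunkIndex: 0,
		},
	}},
		indexer.WithEmbedding(emb.EinoEmbeddingComponent()), // float64 adapter
		indexer.WithSubIndexes([]string{"web"}),
	)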
+184 -256
@@ -2,7 +2,6 @@ package knowledge
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"net/http"
 	"strings"

@@ -10,43 +9,47 @@ import (
 	"time"
 
 	"cyberstrike-ai/internal/config"
-	"cyberstrike-ai/internal/openai"
 
+	einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai"
+	"github.com/cloudwego/eino/components/embedding"
 	"go.uber.org/zap"
 	"golang.org/x/time/rate"
 )
 
-// Embedder embeds text
+// Embedder uses CloudWeGo Eino's OpenAI embedding component, keeping rate limiting and retries.
 type Embedder struct {
-	openAIClient   *openai.Client
-	config         *config.KnowledgeConfig
-	openAIConfig   *config.OpenAIConfig // used to obtain the API key
-	logger         *zap.Logger
-	rateLimiter    *rate.Limiter // rate limiter
-	rateLimitDelay time.Duration // delay between requests
-	maxRetries     int           // max retry attempts
-	retryDelay     time.Duration // retry interval
-	mu             sync.Mutex    // guards rateLimiter
+	eino   embedding.Embedder
+	config *config.KnowledgeConfig
+	logger *zap.Logger
+
+	rateLimiter    *rate.Limiter
+	rateLimitDelay time.Duration
+	maxRetries     int
+	retryDelay     time.Duration
+	mu             sync.Mutex
 }
 
-// NewEmbedder creates a new embedder
-func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder {
-	// initialize the rate limiter
+// NewEmbedder is built on the eino-ext OpenAI embedder; openAIConfig supplies the fallback API key when the knowledge config has none of its own.
+func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) {
+	if cfg == nil {
+		return nil, fmt.Errorf("knowledge config is nil")
+	}
 
 	var rateLimiter *rate.Limiter
 	var rateLimitDelay time.Duration
 
 	// If MaxRPM is configured, derive the rate limit from it
 	if cfg.Indexing.MaxRPM > 0 {
 		rpm := cfg.Indexing.MaxRPM
 		rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
-		logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
+		if logger != nil {
+			logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
+		}
 	} else if cfg.Indexing.RateLimitDelayMs > 0 {
 		// No MaxRPM, but a fixed delay is configured: use fixed-delay mode
 		rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
-		logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
+		if logger != nil {
+			logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
+		}
 	}
 
 	// Retry configuration
 	maxRetries := 3
 	retryDelay := 1000 * time.Millisecond
 	if cfg.Indexing.MaxRetries > 0 {

@@ -56,268 +59,193 @@ func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig,
 		retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
 	}
 
-	return &Embedder{
-		openAIClient:   openAIClient,
-		config:         cfg,
-		openAIConfig:   openAIConfig,
-		logger:         logger,
-		rateLimiter:    rateLimiter,
-		rateLimitDelay: rateLimitDelay,
-		maxRetries:     maxRetries,
-		retryDelay:     retryDelay,
-	}
-}
-
-// EmbeddingRequest is an OpenAI embeddings request
-type EmbeddingRequest struct {
-	Model string   `json:"model"`
-	Input []string `json:"input"`
-}
-
-// EmbeddingResponse is an OpenAI embeddings response
-type EmbeddingResponse struct {
-	Data  []EmbeddingData `json:"data"`
-	Error *EmbeddingError `json:"error,omitempty"`
-}
-
-// EmbeddingData is one embedding payload
-type EmbeddingData struct {
-	Embedding []float64 `json:"embedding"`
-	Index     int       `json:"index"`
-}
-
-// EmbeddingError is an embeddings API error
-type EmbeddingError struct {
-	Message string `json:"message"`
-	Type    string `json:"type"`
-}
-
-// waitRateLimiter waits on the rate limiter
-func (e *Embedder) waitRateLimiter() {
-	e.mu.Lock()
-	defer e.mu.Unlock()
-
-	if e.rateLimiter != nil {
-		// wait for a token
-		ctx := context.Background()
-		if err := e.rateLimiter.Wait(ctx); err != nil {
-			e.logger.Warn("速率限制器等待失败", zap.Error(err))
-		}
-	}
-
-	if e.rateLimitDelay > 0 {
-		time.Sleep(e.rateLimitDelay)
-	}
-}
-
-// EmbedText embeds a single text (with retry and rate limiting)
-func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
-	if e.openAIClient == nil {
-		return nil, fmt.Errorf("OpenAI 客户端未初始化")
-	}
-
-	var lastErr error
-	for attempt := 0; attempt < e.maxRetries; attempt++ {
-		// rate limiting
-		if attempt > 0 {
-			// wait longer before each retry
-			waitTime := e.retryDelay * time.Duration(attempt)
-			e.logger.Debug("重试前等待", zap.Int("attempt", attempt+1), zap.Duration("waitTime", waitTime))
-			select {
-			case <-ctx.Done():
-				return nil, ctx.Err()
-			case <-time.After(waitTime):
-			}
-		} else {
-			e.waitRateLimiter()
-		}
-
-		result, err := e.doEmbedText(ctx, text)
-		if err == nil {
-			return result, nil
-		}
-
-		lastErr = err
-
-		// Check whether the error is retryable (429 rate limits, 5xx server errors, network errors)
-		if !e.isRetryableError(err) {
-			return nil, err
-		}
-
-		e.logger.Debug("嵌入请求失败,准备重试",
-			zap.Int("attempt", attempt+1),
-			zap.Int("maxRetries", e.maxRetries),
-			zap.Error(err))
-	}
-
-	return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
-}
-
-// doEmbedText performs the actual embedding request (internal)
-func (e *Embedder) doEmbedText(ctx context.Context, text string) ([]float32, error) {
-	// use the configured embedding model
-	model := e.config.Embedding.Model
+	model := strings.TrimSpace(cfg.Embedding.Model)
 	if model == "" {
 		model = "text-embedding-3-small"
 	}
 
-	req := EmbeddingRequest{
-		Model: model,
-		Input: []string{text},
-	}
-
-	// Clean the baseURL: trim surrounding whitespace and trailing slashes
-	baseURL := strings.TrimSpace(e.config.Embedding.BaseURL)
+	baseURL := strings.TrimSpace(cfg.Embedding.BaseURL)
 	baseURL = strings.TrimSuffix(baseURL, "/")
 	if baseURL == "" {
 		baseURL = "https://api.openai.com/v1"
 	}
 
-	// build the request
-	body, err := json.Marshal(req)
-	if err != nil {
-		return nil, fmt.Errorf("序列化请求失败:%w", err)
-	}
-
-	requestURL := baseURL + "/embeddings"
-	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(string(body)))
-	if err != nil {
-		return nil, fmt.Errorf("创建请求失败:%w", err)
-	}
-
-	httpReq.Header.Set("Content-Type", "application/json")
-
-	// Use the configured API key, falling back to the OpenAI config
-	apiKey := strings.TrimSpace(e.config.Embedding.APIKey)
-	if apiKey == "" && e.openAIConfig != nil {
-		apiKey = e.openAIConfig.APIKey
+	apiKey := strings.TrimSpace(cfg.Embedding.APIKey)
+	if apiKey == "" && openAIConfig != nil {
+		apiKey = strings.TrimSpace(openAIConfig.APIKey)
 	}
 	if apiKey == "" {
-		return nil, fmt.Errorf("API Key 未配置")
+		return nil, fmt.Errorf("embedding API key 未配置")
 	}
-	httpReq.Header.Set("Authorization", "Bearer "+apiKey)
 
-	// send the request
-	httpClient := &http.Client{
-		Timeout: 30 * time.Second,
+	timeout := 120 * time.Second
+	if cfg.Indexing.RequestTimeoutSeconds > 0 {
+		timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second
 	}
-	resp, err := httpClient.Do(httpReq)
+	httpClient := &http.Client{Timeout: timeout}
+
+	inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{
+		APIKey:     apiKey,
+		BaseURL:    baseURL,
+		ByAzure:    false,
+		Model:      model,
+		HTTPClient: httpClient,
+	})
 	if err != nil {
-		return nil, fmt.Errorf("发送请求失败:%w", err)
-	}
-	defer resp.Body.Close()
-
-	// Read the response body so errors can include details
-	bodyBytes := make([]byte, 0)
-	buf := make([]byte, 4096)
-	for {
-		n, err := resp.Body.Read(buf)
-		if n > 0 {
-			bodyBytes = append(bodyBytes, buf[:n]...)
-		}
-		if err != nil {
-			break
-		}
+		return nil, fmt.Errorf("eino OpenAI embedder: %w", err)
 	}
 
-	// Log request and response details (for debugging)
-	requestBodyPreview := string(body)
-	if len(requestBodyPreview) > 200 {
-		requestBodyPreview = requestBodyPreview[:200] + "..."
-	}
-	e.logger.Debug("嵌入 API 请求",
-		zap.String("url", httpReq.URL.String()),
-		zap.String("model", model),
-		zap.String("requestBody", requestBodyPreview),
-		zap.Int("status", resp.StatusCode),
-		zap.Int("bodySize", len(bodyBytes)),
-		zap.String("contentType", resp.Header.Get("Content-Type")),
-	)
-
-	var embeddingResp EmbeddingResponse
-	if err := json.Unmarshal(bodyBytes, &embeddingResp); err != nil {
-		// include detailed error context
-		bodyPreview := string(bodyBytes)
-		if len(bodyPreview) > 500 {
-			bodyPreview = bodyPreview[:500] + "..."
-		}
-		return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码:%d, 响应长度:%d字节): %w\n请求体:%s\n响应内容预览:%s",
-			requestURL, resp.StatusCode, len(bodyBytes), err, requestBodyPreview, bodyPreview)
-	}
-
-	if embeddingResp.Error != nil {
-		return nil, fmt.Errorf("OpenAI API 错误 (状态码:%d): 类型=%s, 消息=%s",
-			resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message)
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		bodyPreview := string(bodyBytes)
-		if len(bodyPreview) > 500 {
-			bodyPreview = bodyPreview[:500] + "..."
-		}
-		return nil, fmt.Errorf("HTTP 请求失败 (URL: %s, 状态码:%d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview)
-	}
-
-	if len(embeddingResp.Data) == 0 {
-		bodyPreview := string(bodyBytes)
-		if len(bodyPreview) > 500 {
-			bodyPreview = bodyPreview[:500] + "..."
-		}
-		return nil, fmt.Errorf("未收到嵌入数据 (状态码:%d, 响应长度:%d字节)\n响应内容:%s",
-			resp.StatusCode, len(bodyBytes), bodyPreview)
-	}
-
-	// convert to float32
-	embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
-	for i, v := range embeddingResp.Data[0].Embedding {
-		embedding[i] = float32(v)
-	}
-
-	return embedding, nil
+	return &Embedder{
+		eino:           inner,
+		config:         cfg,
+		logger:         logger,
+		rateLimiter:    rateLimiter,
+		rateLimitDelay: rateLimitDelay,
+		maxRetries:     maxRetries,
+		retryDelay:     retryDelay,
+	}, nil
 }
 
-// isRetryableError reports whether the error is retryable
-func (e *Embedder) isRetryableError(err error) bool {
-	if err == nil {
-		return false
-	}
-
-	errStr := err.Error()
-
-	// 429 rate-limit errors
-	if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
-		return true
-	}
-
-	// 5xx server errors
-	if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
-		strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
-		return true
-	}
-
-	// network errors
-	if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
-		strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
-		return true
-	}
-
-	return false
-}
+// EmbeddingModelName returns the configured embedding model name (used for tiktoken chunking and vector-row metadata).
+func (e *Embedder) EmbeddingModelName() string {
+	if e == nil || e.config == nil {
+		return ""
+	}
+	s := strings.TrimSpace(e.config.Embedding.Model)
+	if s != "" {
+		return s
+	}
+	return "text-embedding-3-small"
+}
 
-// EmbedTexts embeds texts in bulk
-func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
-	embeddings := make([][]float32, len(texts))
-	for i, text := range texts {
-		embedding, err := e.EmbedText(ctx, text)
-		if err != nil {
-			return nil, fmt.Errorf("嵌入文本 [%d] 失败:%w", i, err)
-		}
-		embeddings[i] = embedding
-	}
-
-	return embeddings, nil
-}
+func (e *Embedder) waitRateLimiter() {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	if e.rateLimiter != nil {
+		ctx := context.Background()
+		if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil {
+			e.logger.Warn("速率限制器等待失败", zap.Error(err))
+		}
+	}
+	if e.rateLimitDelay > 0 {
+		time.Sleep(e.rateLimitDelay)
+	}
+}
+
+// EmbedText embeds a single string (float32, matching the historical storage format).
+func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
+	vecs, err := e.EmbedStrings(ctx, []string{text})
+	if err != nil {
+		return nil, err
+	}
+	if len(vecs) != 1 {
+		return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs))
+	}
+	return vecs[0], nil
+}
+
+// EmbedStrings embeds in batches with retry; [EinoEmbeddingComponent] adapts it to Eino's [embedding.Embedder] for the Indexer.
+func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) {
+	if e == nil || e.eino == nil {
+		return nil, fmt.Errorf("embedder not initialized")
+	}
+	if len(texts) == 0 {
+		return nil, nil
+	}
+
+	var lastErr error
+	for attempt := 0; attempt < e.maxRetries; attempt++ {
+		if attempt > 0 {
+			wait := e.retryDelay * time.Duration(attempt)
+			if e.logger != nil {
+				e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait))
+			}
+			select {
+			case <-ctx.Done():
+				return nil, ctx.Err()
+			case <-time.After(wait):
+			}
+		} else {
+			e.waitRateLimiter()
+		}
+
+		raw, err := e.eino.EmbedStrings(ctx, texts, opts...)
+		if err == nil {
+			out := make([][]float32, len(raw))
+			for i, row := range raw {
+				out[i] = make([]float32, len(row))
+				for j, v := range row {
+					out[i][j] = float32(v)
+				}
+			}
+			return out, nil
+		}
+		lastErr = err
+		if !e.isRetryableError(err) {
+			return nil, err
+		}
+		if e.logger != nil {
+			e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err))
+		}
+	}
+	return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
+}
+
+// EmbedTexts embeds in bulk as float32 (kept for older callers; one batched request to cut latency).
+func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
+	return e.EmbedStrings(ctx, texts)
+}
+
+func (e *Embedder) isRetryableError(err error) bool {
+	if err == nil {
+		return false
+	}
+	errStr := err.Error()
+	if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
+		return true
+	}
+	if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
+		strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
+		return true
+	}
+	if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
+		strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
+		return true
+	}
+	return false
+}
+
+// einoFloatEmbedder adapts the [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store.
+type einoFloatEmbedder struct {
+	inner *Embedder
+}
+
+func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
+	vec32, err := w.inner.EmbedStrings(ctx, texts, opts...)
+	if err != nil {
+		return nil, err
+	}
+	out := make([][]float64, len(vec32))
+	for i, row := range vec32 {
+		out[i] = make([]float64, len(row))
+		for j, v := range row {
+			out[i][j] = float64(v)
+		}
+	}
+	return out, nil
+}
+
+func (w *einoFloatEmbedder) GetType() string {
+	return "CyberStrikeKnowledgeEmbedder"
+}
+
+func (w *einoFloatEmbedder) IsCallbacksEnabled() bool {
+	return false
+}
+
+// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path
+// and produces the float64 vectors expected by generic Eino indexer helpers.
+func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder {
+	return &einoFloatEmbedder{inner: e}
+}
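
A minimal sketch of a one-off embedding through the new path, reusing FormatEmbeddingInput so query-side text matches what was indexed; cfg, logger, and ctx come from the caller, and the category/title strings are illustrative:

	emb, err := NewEmbedder(ctx, &cfg.Knowledge, &cfg.OpenAI, logger)
	if err != nil {
		return err
	}
	vec, err := emb.EmbedText(ctx, FormatEmbeddingInput("SQL注入", "检测方法", "chunk text"))
	if err != nil {
		return err
	}
	fmt.Println("dim:", len(vec)) // e.g. 1536 for text-embedding-3-small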
@@ -0,0 +1,91 @@
package knowledge

import (
	"context"
	"database/sql"
	"fmt"
	"strings"

	"cyberstrike-ai/internal/config"

	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/components/document"
	"github.com/cloudwego/eino/schema"
)

// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive".
func normalizeChunkStrategy(s string) string {
	v := strings.TrimSpace(strings.ToLower(s))
	switch v {
	case "recursive":
		return "recursive"
	case "markdown_then_recursive", "markdown_recursive", "markdown":
		return "markdown_then_recursive"
	case "":
		return "markdown_then_recursive"
	default:
		return "markdown_then_recursive"
	}
}

func buildKnowledgeIndexChain(
	ctx context.Context,
	indexingCfg *config.IndexingConfig,
	db *sql.DB,
	recursive document.Transformer,
	embeddingModel string,
) (compose.Runnable[[]*schema.Document, []string], error) {
	if recursive == nil {
		return nil, fmt.Errorf("recursive transformer is nil")
	}
	if db == nil {
		return nil, fmt.Errorf("db is nil")
	}
	strategy := normalizeChunkStrategy("markdown_then_recursive")
	batch := 64
	maxChunks := 0
	if indexingCfg != nil {
		strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy)
		if indexingCfg.BatchSize > 0 {
			batch = indexingCfg.BatchSize
		}
		maxChunks = indexingCfg.MaxChunksPerItem
	}

	si := NewSQLiteIndexer(db, batch, embeddingModel)
	ch := compose.NewChain[[]*schema.Document, []string]()
	if strategy != "recursive" {
		md, err := newMarkdownHeaderSplitter(ctx)
		if err != nil {
			return nil, fmt.Errorf("markdown splitter: %w", err)
		}
		ch.AppendDocumentTransformer(md)
	}
	ch.AppendDocumentTransformer(recursive)
	ch.AppendLambda(newChunkEnrichLambda(maxChunks))
	ch.AppendIndexer(si)
	return ch.Compile(ctx)
}

func newChunkEnrichLambda(maxChunks int) *compose.Lambda {
	return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) {
		_ = ctx
		out := make([]*schema.Document, 0, len(docs))
		for _, d := range docs {
			if d == nil || strings.TrimSpace(d.Content) == "" {
				continue
			}
			out = append(out, d)
		}
		if maxChunks > 0 && len(out) > maxChunks {
			out = out[:maxChunks]
		}
		for i, d := range out {
			if d.MetaData == nil {
				d.MetaData = make(map[string]any)
			}
			d.MetaData[metaKBChunkIndex] = i
		}
		return out, nil
	})
}
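
The enrich Lambda is what makes chunk bookkeeping deterministic before the SQLite write. A sketch of its contract, with illustrative values rather than a verbatim test from this commit:

	// Chunks as they leave the splitters:
	in := []*schema.Document{
		{Content: "intro"}, {Content: "   "}, {Content: "details"}, {Content: "appendix"},
	}
	// After newChunkEnrichLambda(2) runs inside the chain:
	//   - the all-whitespace chunk is dropped,
	//   - the survivor list is capped at maxChunks = 2,
	//   - kb_chunk_index is renumbered densely: "intro" -> 0, "details" -> 1.
	_ = in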
@@ -0,0 +1,21 @@
package knowledge

import "testing"

func TestNormalizeChunkStrategy(t *testing.T) {
	cases := []struct {
		in, want string
	}{
		{"", "markdown_then_recursive"},
		{"recursive", "recursive"},
		{"RECURSIVE", "recursive"},
		{"markdown_then_recursive", "markdown_then_recursive"},
		{"markdown", "markdown_then_recursive"},
		{"unknown", "markdown_then_recursive"},
	}
	for _, tc := range cases {
		if got := normalizeChunkStrategy(tc.in); got != tc.want {
			t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want)
		}
	}
}
|
||||
+154
-562
@@ -3,596 +3,203 @@ package knowledge

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
	"sync"
	"time"

	"cyberstrike-ai/internal/config"

	"github.com/google/uuid"
	fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
	"github.com/cloudwego/eino/compose"
	"github.com/cloudwego/eino/components/document"
	"github.com/cloudwego/eino/components/indexer"
	"github.com/cloudwego/eino/schema"
	"go.uber.org/zap"
)

// Indexer splits knowledge items into chunks and vectorizes them.
// Indexer runs an Eino Compose index chain (Markdown/recursive splitting, a Lambda enrich step, SQLite indexing) plus embedding writes.
type Indexer struct {
	db        *sql.DB
	embedder  *Embedder
	logger    *zap.Logger
	chunkSize int // maximum (estimated) token count per chunk
	overlap   int // overlapping token count between chunks
	maxChunks int // maximum number of chunks per knowledge item (0 = unlimited)
	db          *sql.DB
	embedder    *Embedder
	logger      *zap.Logger
	chunkSize   int
	overlap     int
	indexingCfg *config.IndexingConfig

	indexChain compose.Runnable[[]*schema.Document, []string]
	fileLoader *fileloader.FileLoader

	// error tracking
	mu            sync.RWMutex
	lastError     string    // most recent error message
	lastErrorTime time.Time // time of the most recent error
	errorCount    int       // consecutive error count
	lastError     string
	lastErrorTime time.Time
	errorCount    int

	// rebuild-index progress tracking
	rebuildMu         sync.RWMutex
	isRebuilding      bool      // whether an index rebuild is in progress
	rebuildTotalItems int       // total number of items to rebuild
	rebuildCurrent    int       // number of items processed so far
	rebuildFailed     int       // number of items that failed during rebuild
	rebuildStartTime  time.Time // rebuild start time
	rebuildLastItemID string    // ID of the most recently processed item
	rebuildLastChunks int       // chunk count of the most recently processed item
	isRebuilding      bool
	rebuildTotalItems int
	rebuildCurrent    int
	rebuildFailed     int
	rebuildStartTime  time.Time
	rebuildLastItemID string
	rebuildLastChunks int
}

// NewIndexer creates a new indexer.
func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger, indexingCfg *config.IndexingConfig) *Indexer {
// NewIndexer creates the indexer and compiles the Eino index chain; kcfg is the full knowledge-base config (including indexing and path-related behavior).
func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) {
	if db == nil {
		return nil, fmt.Errorf("db is nil")
	}
	if embedder == nil {
		return nil, fmt.Errorf("embedder is nil")
	}
	if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil {
		return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err)
	}
	if kcfg == nil {
		kcfg = &config.KnowledgeConfig{}
	}
	indexingCfg := &kcfg.Indexing

	chunkSize := 512
	overlap := 50
	maxChunks := 0
	if indexingCfg != nil {
		if indexingCfg.ChunkSize > 0 {
			chunkSize = indexingCfg.ChunkSize
		}
		if indexingCfg.ChunkOverlap >= 0 {
			overlap = indexingCfg.ChunkOverlap
		}
		if indexingCfg.MaxChunksPerItem > 0 {
			maxChunks = indexingCfg.MaxChunksPerItem
		}
	if indexingCfg.ChunkSize > 0 {
		chunkSize = indexingCfg.ChunkSize
	}
	if indexingCfg.ChunkOverlap >= 0 {
		overlap = indexingCfg.ChunkOverlap
	}

	embedModel := embedder.EmbeddingModelName()
	splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel)
	if err != nil {
		return nil, fmt.Errorf("eino recursive splitter: %w", err)
	}

	chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel)
	if err != nil {
		return nil, fmt.Errorf("knowledge index chain: %w", err)
	}

	var fl *fileloader.FileLoader
	fl, err = fileloader.NewFileLoader(ctx, nil)
	if err != nil {
		if logger != nil {
			logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err))
		}
		fl = nil
		err = nil
	}

	return &Indexer{
		db:        db,
		embedder:  embedder,
		logger:    logger,
		chunkSize: chunkSize,
		overlap:   overlap,
		maxChunks: maxChunks,
	}
		db:          db,
		embedder:    embedder,
		logger:      logger,
		chunkSize:   chunkSize,
		overlap:     overlap,
		indexingCfg: indexingCfg,
		indexChain:  chain,
		fileLoader:  fl,
	}, nil
}

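A hedged wiring sketch (not part of this commit) showing how a caller might construct and drive the indexer; db, emb, zlog, and kcfg are assumed to already exist in the caller, and "item-123" is a hypothetical knowledge-item ID.

idx, err := NewIndexer(ctx, db, emb, zlog, kcfg)
if err != nil {
	return err
}
if err := idx.IndexItem(ctx, "item-123"); err != nil {
	zlog.Warn("index item failed", zap.Error(err))
}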
// ChunkText splits text into chunks (with overlap, preserving heading context).
func (idx *Indexer) ChunkText(text string) []string {
	// Split by Markdown headings to get sections that carry their headings.
	sections := idx.splitByMarkdownHeadersWithContent(text)

	// Process each section.
	result := make([]string, 0)
	for _, section := range sections {
		// Build the parent heading path (excluding the last-level heading, which the content already contains).
		// e.g. ["# A", "## B", "### C"] -> "[# A > ## B]"
		var parentHeaderPath string
		if len(section.HeaderPath) > 1 {
			parentHeaderPath = strings.Join(section.HeaderPath[:len(section.HeaderPath)-1], " > ")
		}

		// Take the first line of the content as the heading (e.g. "# Prompt Injection").
		firstLine, remainingContent := extractFirstLine(section.Content)

		// If the remaining content is empty or whitespace-only, the section is heading-only; skip it.
		if strings.TrimSpace(remainingContent) == "" {
			continue
		}

		// If the section is too large, split it further.
		if idx.estimateTokens(section.Content) <= idx.chunkSize {
			// Section fits; prepend the parent heading path.
			if parentHeaderPath != "" {
				result = append(result, fmt.Sprintf("[%s] %s", parentHeaderPath, section.Content))
			} else {
				result = append(result, section.Content)
			}
		} else {
			// Section too large: split by sub-headings or paragraphs while keeping heading context.
			// First try splitting by sub-headings (preserving sub-heading structure).
			subSections := idx.splitBySubHeaders(section.Content, firstLine, parentHeaderPath)
			if len(subSections) > 1 {
				// Sub-heading split succeeded; process each sub-section recursively.
				for _, sub := range subSections {
					if idx.estimateTokens(sub) <= idx.chunkSize {
						result = append(result, sub)
					} else {
						// Sub-section still too large: split by paragraphs (keeping the heading prefix).
						paragraphs := idx.splitByParagraphsWithHeader(sub, parentHeaderPath)
						for _, para := range paragraphs {
							if idx.estimateTokens(para) <= idx.chunkSize {
								result = append(result, para)
							} else {
								// Paragraph still too large: split by sentences.
								sentenceChunks := idx.splitBySentencesWithOverlap(para)
								for _, chunk := range sentenceChunks {
									result = append(result, chunk)
								}
							}
						}
					}
				}
			} else {
				// No sub-headings: split by paragraphs (keeping the heading prefix).
				paragraphs := idx.splitByParagraphsWithHeader(section.Content, parentHeaderPath)
				for _, para := range paragraphs {
					if idx.estimateTokens(para) <= idx.chunkSize {
						result = append(result, para)
					} else {
						// Paragraph still too large: split by sentences.
						sentenceChunks := idx.splitBySentencesWithOverlap(para)
						for _, chunk := range sentenceChunks {
							result = append(result, chunk)
						}
					}
				}
			}
		}
	}
// RecompileIndexChain rebuilds the Eino index chain after config or embedding-model changes (no process restart needed).
func (idx *Indexer) RecompileIndexChain(ctx context.Context) error {
	if idx == nil || idx.db == nil || idx.embedder == nil {
		return fmt.Errorf("indexer 未初始化")
	}

	return result
	if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil {
		return err
	}
	embedModel := idx.embedder.EmbeddingModelName()
	splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel)
	if err != nil {
		return fmt.Errorf("eino recursive splitter: %w", err)
	}
	chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel)
	if err != nil {
		return fmt.Errorf("knowledge index chain: %w", err)
	}
	idx.indexChain = chain
	return nil
}

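A short usage sketch (not part of this commit): after the embedding model or indexing config changes at runtime, the Indexer owner can swap in a freshly compiled chain without restarting the process. The idx and zlog variables are assumed.

if err := idx.RecompileIndexChain(ctx); err != nil {
	zlog.Error("recompile knowledge index chain failed", zap.Error(err))
}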
// extractFirstLine returns the first line of content and the remainder.
func extractFirstLine(content string) (firstLine, remaining string) {
	lines := strings.SplitN(content, "\n", 2)
	if len(lines) == 0 {
		return "", ""
	}
	if len(lines) == 1 {
		return lines[0], ""
	}
	return lines[0], lines[1]
}

// splitBySubHeaders tries to split content by sub-headings (for oversized sections).
// headerPrefix is the parent heading path prepended to each sub-section.
func (idx *Indexer) splitBySubHeaders(content, headerPrefix, parentPath string) []string {
	// Match Markdown sub-headings (## and deeper).
	subHeaderRegex := regexp.MustCompile(`(?m)^#{2,6}\s+.+$`)
	matches := subHeaderRegex.FindAllStringIndex(content, -1)

	if len(matches) == 0 {
		// No sub-headings: return the content unchanged.
		return []string{content}
	}

	result := make([]string, 0, len(matches))
	for i, match := range matches {
		start := match[0]
		nextStart := len(content)
		if i+1 < len(matches) {
			nextStart = matches[i+1][0]
		}

		subContent := strings.TrimSpace(content[start:nextStart])

		// Prepend the parent path.
		if parentPath != "" {
			result = append(result, fmt.Sprintf("[%s] %s", parentPath, subContent))
		} else {
			result = append(result, subContent)
		}
	}

	return result
}

// splitByParagraphsWithHeader splits by paragraph, prefixing each paragraph with the heading (to keep context).
func (idx *Indexer) splitByParagraphsWithHeader(content, parentPath string) []string {
	// Take the first line as the heading.
	firstLine, _ := extractFirstLine(content)

	paragraphs := strings.Split(content, "\n\n")
	result := make([]string, 0)

	for i, p := range paragraphs {
		trimmed := strings.TrimSpace(p)
		if trimmed == "" {
			continue
		}

		// Drop heading-only paragraphs (no actual content).
		if strings.TrimSpace(trimmed) == strings.TrimSpace(firstLine) {
			continue
		}

		// The first paragraph already contains the heading; don't add it again.
		if i == 0 && strings.Contains(trimmed, firstLine) {
			if parentPath != "" {
				result = append(result, fmt.Sprintf("[%s] %s", parentPath, trimmed))
			} else {
				result = append(result, trimmed)
			}
		} else {
			// Other paragraphs get the heading prefix to keep context.
			if parentPath != "" {
				result = append(result, fmt.Sprintf("[%s] %s\n%s", parentPath, firstLine, trimmed))
			} else {
				result = append(result, fmt.Sprintf("%s\n%s", firstLine, trimmed))
			}
		}
	}

	return result
}

// Section is a text block with its heading path.
type Section struct {
	HeaderPath []string // heading path (e.g. ["# SQL Injection", "## Detection"])
	Content    string   // section content
}

// splitByMarkdownHeadersWithContent splits by Markdown headings and returns sections with heading paths.
// Each section's content includes its own heading, for embedding and retrieval.
//
// For example, given the Markdown:
// # Prompt Injection
// intro text
// ## Summary
// table of contents
//
// it returns:
// [{HeaderPath: ["# Prompt Injection"], Content: "# Prompt Injection\nintro text"},
//  {HeaderPath: ["# Prompt Injection", "## Summary"], Content: "## Summary\ntable of contents"}]
func (idx *Indexer) splitByMarkdownHeadersWithContent(text string) []Section {
	// Match Markdown headings (#, ##, ###, ...).
	headerRegex := regexp.MustCompile(`(?m)^#{1,6}\s+.+$`)

	// Find all heading positions.
	matches := headerRegex.FindAllStringIndex(text, -1)
	if len(matches) == 0 {
		// No headings: return the whole text.
		return []Section{{HeaderPath: []string{}, Content: text}}
	}

	sections := make([]Section, 0, len(matches))
	currentHeaderPath := []string{}

	for i, match := range matches {
		start := match[0]
		end := match[1]
		nextStart := len(text)

		// Find the next heading's position.
		if i+1 < len(matches) {
			nextStart = matches[i+1][0]
		}

		// Extract the current heading line.
		headerLine := strings.TrimSpace(text[start:end])

		// Compute the heading level (number of #).
		level := 0
		for _, ch := range headerLine {
			if ch == '#' {
				level++
			} else {
				break
			}
		}

		// Update the heading path: drop headings at the same or deeper level, then append the current one.
		newPath := make([]string, 0, len(currentHeaderPath)+1)
		for _, h := range currentHeaderPath {
			hLevel := 0
			for _, ch := range h {
				if ch == '#' {
					hLevel++
				} else {
					break
				}
			}
			if hLevel < level {
				newPath = append(newPath, h)
			}
		}
		newPath = append(newPath, headerLine)
		currentHeaderPath = newPath

		// Extract the content from this heading up to the next one (heading included).
		content := strings.TrimSpace(text[start:nextStart])

		// Create the section with the current heading path (including the current heading).
		sections = append(sections, Section{
			HeaderPath: append([]string(nil), currentHeaderPath...),
			Content:    content,
		})
	}

	// Filter empty sections.
	result := make([]Section, 0, len(sections))
	for _, section := range sections {
		if strings.TrimSpace(section.Content) != "" {
			result = append(result, section)
		}
	}

	if len(result) == 0 {
		return []Section{{HeaderPath: []string{}, Content: text}}
	}

	return result
}

// splitByParagraphs splits text by paragraph.
func (idx *Indexer) splitByParagraphs(text string) []string {
	paragraphs := strings.Split(text, "\n\n")
	result := make([]string, 0)
	for _, p := range paragraphs {
		if strings.TrimSpace(p) != "" {
			result = append(result, strings.TrimSpace(p))
		}
	}
	return result
}

// splitBySentences splits by sentence (internal use; no overlap logic).
func (idx *Indexer) splitBySentences(text string) []string {
	// Simple sentence split on periods, question marks, and exclamation marks, both English and Chinese:
	// . ! ? = ASCII punctuation
	// \u3002 = 。 (Chinese period)
	// \uFF01 = ! (Chinese exclamation mark)
	// \uFF1F = ? (Chinese question mark)
	sentenceRegex := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`)
	sentences := sentenceRegex.Split(text, -1)
	result := make([]string, 0)
	for _, s := range sentences {
		if strings.TrimSpace(s) != "" {
			result = append(result, strings.TrimSpace(s))
		}
	}
	return result
}

// splitBySentencesWithOverlap splits by sentence and applies the overlap strategy.
func (idx *Indexer) splitBySentencesWithOverlap(text string) []string {
	if idx.overlap <= 0 {
		// No overlap configured: use the simple split.
		return idx.splitBySentencesSimple(text)
	}

	sentences := idx.splitBySentences(text)
	if len(sentences) == 0 {
		return []string{}
	}

	result := make([]string, 0)
	currentChunk := ""

	for _, sentence := range sentences {
		testChunk := currentChunk
		if testChunk != "" {
			testChunk += "\n"
		}
		testChunk += sentence

		testTokens := idx.estimateTokens(testChunk)

		if testTokens > idx.chunkSize && currentChunk != "" {
			// The current chunk has reached the size limit; flush it.
			result = append(result, currentChunk)

			// Take the overlap from the tail of the current chunk.
			overlapText := idx.extractLastTokens(currentChunk, idx.overlap)
			if overlapText != "" {
				// Seed the next chunk with the overlap text.
				currentChunk = overlapText + "\n" + sentence
			} else {
				// Not enough overlap text; start the next chunk with the current sentence.
				currentChunk = sentence
			}
		} else {
			currentChunk = testChunk
		}
	}

	// Append the final chunk.
	if strings.TrimSpace(currentChunk) != "" {
		result = append(result, currentChunk)
	}

	// Filter empty chunks.
	filtered := make([]string, 0)
	for _, chunk := range result {
		if strings.TrimSpace(chunk) != "" {
			filtered = append(filtered, chunk)
		}
	}

	return filtered
}

// splitBySentencesSimple splits by sentence (simple version, no overlap).
func (idx *Indexer) splitBySentencesSimple(text string) []string {
	sentences := idx.splitBySentences(text)
	result := make([]string, 0)
	currentChunk := ""

	for _, sentence := range sentences {
		testChunk := currentChunk
		if testChunk != "" {
			testChunk += "\n"
		}
		testChunk += sentence

		if idx.estimateTokens(testChunk) > idx.chunkSize && currentChunk != "" {
			result = append(result, currentChunk)
			currentChunk = sentence
		} else {
			currentChunk = testChunk
		}
	}
	if currentChunk != "" {
		result = append(result, currentChunk)
	}

	return result
}

// extractLastTokens extracts roughly tokenCount tokens of content from the end of the text.
func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
	if tokenCount <= 0 || text == "" {
		return ""
	}

	// Estimate characters (1 token ≈ 4 characters).
	charCount := tokenCount * 4
	runes := []rune(text)

	if len(runes) <= charCount {
		return text
	}

	// Take that many characters from the end.
	startPos := len(runes) - charCount
	extracted := string(runes[startPos:])

	// Try to find the first sentence boundary (Chinese and English punctuation).
	sentenceBoundary := regexp.MustCompile(`[.!?\x{3002}\x{FF01}\x{FF1F}]+`)
	matches := sentenceBoundary.FindStringIndex(extracted)
	if len(matches) > 0 && matches[0] > 0 {
		// Cut at the sentence boundary so only complete sentences remain.
		extracted = extracted[matches[0]:]
	}

	return strings.TrimSpace(extracted)
}

// estimateTokens estimates the token count (rough rule: 1 token ≈ 4 characters).
func (idx *Indexer) estimateTokens(text string) int {
	return len([]rune(text)) / 4
}

// IndexItem indexes a knowledge item (chunking + vectorizing).
// IndexItem indexes a single knowledge item: it clears the item's old vectors, then runs the Compose chain (chunking, embedding, writing).
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
	// Fetch the knowledge item (including category and title, used for embedding).
	var content, category, title string
	err := idx.db.QueryRow("SELECT content, category, title FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title)
	if idx.indexChain == nil {
		return fmt.Errorf("索引链未初始化")
	}
	if idx.embedder == nil {
		return fmt.Errorf("嵌入器未初始化")
	}

	var content, category, title, filePath string
	err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath)
	if err != nil {
		return fmt.Errorf("获取知识项失败:%w", err)
	}

	// Delete the old vectors (RebuildIndex already clears them in bulk; kept here so standalone IndexItem calls stay correct).
	_, err = idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID)
	if err != nil {
	if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil {
		return fmt.Errorf("删除旧向量失败:%w", err)
	}

	// Split into chunks.
	chunks := idx.ChunkText(content)

	// Apply the max-chunk limit.
	if idx.maxChunks > 0 && len(chunks) > idx.maxChunks {
		idx.logger.Info("知识项块数量超过限制,已截断",
			zap.String("itemId", itemID),
			zap.Int("originalChunks", len(chunks)),
			zap.Int("maxChunks", idx.maxChunks))
		chunks = chunks[:idx.maxChunks]
	}

	idx.logger.Info("知识项分块完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))

	// Track this item's errors.
	itemErrorCount := 0
	var firstError error
	firstErrorChunkIndex := -1

	// Embed each chunk (with category and title included so vector retrieval can match the risk type).
	for i, chunk := range chunks {
		// Include category and title in the embedded text.
		// Format: "[风险类型:{category}] [标题:{title}]\n{chunk content}"
		// The embedding then carries the risk-type signal, so even if SQL filtering misses, vector similarity can still match.
		textForEmbedding := fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunk)

		embedding, err := idx.embedder.EmbedText(ctx, textForEmbedding)
		if err != nil {
			itemErrorCount++
			if firstError == nil {
				firstError = err
				firstErrorChunkIndex = i
				// Log details only for the first failing chunk.
				chunkPreview := chunk
				if len(chunkPreview) > 200 {
					chunkPreview = chunkPreview[:200] + "..."
	body := strings.TrimSpace(content)
	if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil {
		docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)})
		if lerr == nil && len(docs) > 0 {
			var b strings.Builder
			for i, d := range docs {
				if d == nil {
					continue
				}
				idx.logger.Warn("向量化失败",
					zap.String("itemId", itemID),
					zap.Int("chunkIndex", i),
					zap.Int("totalChunks", len(chunks)),
					zap.String("chunkPreview", chunkPreview),
					zap.Error(err),
				)

				// Update global error tracking.
				errorMsg := fmt.Sprintf("向量化失败 (知识项:%s): %v", itemID, err)
				idx.mu.Lock()
				idx.lastError = errorMsg
				idx.lastErrorTime = time.Now()
				idx.mu.Unlock()
				if i > 0 {
					b.WriteString("\n\n")
				}
				b.WriteString(d.Content)
			}

			// Stop this item after 5 consecutive failed chunks:
			// this avoids wasting further API calls and surfaces configuration problems faster.
			// For large documents (more than 10 chunks), allow a failure ratio of at most 50%.
			maxConsecutiveFailures := 5
			if len(chunks) > 10 && itemErrorCount > len(chunks)/2 {
				idx.logger.Error("知识项向量化失败比例过高,停止处理",
					zap.String("itemId", itemID),
					zap.Int("totalChunks", len(chunks)),
					zap.Int("failedChunks", itemErrorCount),
					zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
					zap.Error(firstError),
				)
				return fmt.Errorf("知识项向量化失败比例过高 (%d/%d个块失败): %v", itemErrorCount, len(chunks), firstError)
			if s := strings.TrimSpace(b.String()); s != "" {
				body = s
			}
			if itemErrorCount >= maxConsecutiveFailures {
				idx.logger.Error("知识项连续向量化失败,停止处理",
					zap.String("itemId", itemID),
					zap.Int("totalChunks", len(chunks)),
					zap.Int("failedChunks", itemErrorCount),
					zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
					zap.Error(firstError),
				)
				return fmt.Errorf("知识项连续向量化失败 (%d个块失败): %v", itemErrorCount, firstError)
			}
			continue
		}

		// Save the vector.
		chunkID := uuid.New().String()
		embeddingJSON, _ := json.Marshal(embedding)

		_, err = idx.db.Exec(
			"INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, created_at) VALUES (?, ?, ?, ?, ?, datetime('now'))",
			chunkID, itemID, i, chunk, string(embeddingJSON),
		)
		if err != nil {
			idx.logger.Warn("保存向量失败", zap.String("itemId", itemID), zap.Int("chunkIndex", i), zap.Error(err))
			continue
		} else if idx.logger != nil {
			idx.logger.Warn("优先源文件读取失败,使用数据库正文",
				zap.String("itemId", itemID),
				zap.String("path", filePath),
				zap.Error(lerr))
		}
	}

	idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(chunks)))
	root := &schema.Document{
		ID:      itemID,
		Content: body,
		MetaData: map[string]any{
			metaKBCategory: category,
			metaKBTitle:    title,
			metaKBItemID:   itemID,
		},
	}

	// Update the most recent processing info in the rebuild state.
	idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())}
	if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 {
		idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes))
	}

	ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...))
	if err != nil {
		msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err)
		idx.mu.Lock()
		idx.lastError = msg
		idx.lastErrorTime = time.Now()
		idx.mu.Unlock()
		return err
	}

	if idx.logger != nil {
		idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids)))
	}
	idx.rebuildMu.Lock()
	idx.rebuildLastItemID = itemID
	idx.rebuildLastChunks = len(chunks)
	idx.rebuildLastChunks = len(ids)
	idx.rebuildMu.Unlock()

	return nil
}

@@ -608,7 +215,6 @@ func (idx *Indexer) HasIndex() (bool, error) {

// RebuildIndex rebuilds all indexes.
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
	// Set rebuild state.
	idx.rebuildMu.Lock()
	idx.isRebuilding = true
	idx.rebuildTotalItems = 0
@@ -619,7 +225,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
	idx.rebuildLastChunks = 0
	idx.rebuildMu.Unlock()

	// Reset error tracking.
	idx.mu.Lock()
	idx.lastError = ""
	idx.lastErrorTime = time.Time{}
@@ -628,7 +233,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {

	rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
	if err != nil {
		// Reset rebuild state.
		idx.rebuildMu.Lock()
		idx.isRebuilding = false
		idx.rebuildMu.Unlock()
@@ -640,7 +244,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			// Reset rebuild state.
			idx.rebuildMu.Lock()
			idx.isRebuilding = false
			idx.rebuildMu.Unlock()
@@ -655,13 +258,9 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {

	idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))

	// Note: the whole old index is no longer wiped; updates are incremental.
	// Each item deletes its own old vectors inside IndexItem before inserting new ones,
	// so after a config change only the changed items are re-indexed and the rest keep their index.

	failedCount := 0
	consecutiveFailures := 0
	maxConsecutiveFailures := 5 // stop after 5 consecutive failures (tolerating occasional transient errors)
	maxConsecutiveFailures := 5
	firstFailureItemID := ""
	var firstFailureError error

@@ -670,7 +269,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
			failedCount++
			consecutiveFailures++

			// Log details only for the first failure.
			if consecutiveFailures == 1 {
				firstFailureItemID = itemID
				firstFailureError = err
@@ -681,7 +279,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
				)
			}

			// Too many consecutive failures likely means a configuration problem; stop indexing immediately.
			if consecutiveFailures >= maxConsecutiveFailures {
				errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError)
				idx.mu.Lock()
@@ -699,7 +296,6 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
				return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError)
			}

			// Too many failed items: log a warning but keep processing (threshold lowered to 30%).
			if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
				errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
				idx.mu.Lock()
@@ -717,26 +313,22 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
			continue
		}

		// On success, reset the consecutive-failure counter and first-failure info.
		if consecutiveFailures > 0 {
			consecutiveFailures = 0
			firstFailureItemID = ""
			firstFailureError = nil
		}

		// Update rebuild progress.
		idx.rebuildMu.Lock()
		idx.rebuildCurrent = i + 1
		idx.rebuildFailed = failedCount
		idx.rebuildMu.Unlock()

		// Reduce progress-log frequency (every 10 items or every 10%).
		if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
			idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
		}
	}

	// Reset rebuild state.
	idx.rebuildMu.Lock()
	idx.isRebuilding = false
	idx.rebuildMu.Unlock()

@@ -0,0 +1,213 @@
package knowledge

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
	"sync"
	"unicode"
	"unicode/utf8"

	"cyberstrike-ai/internal/config"

	"github.com/cloudwego/eino/schema"
	"github.com/pkoukk/tiktoken-go"
)

// postRetrieveMaxPrefetchCap caps the vector candidate count per query, so a misconfiguration cannot trigger heavy full-table scans.
const postRetrieveMaxPrefetchCap = 200

// DocumentReranker is an optional reranker (e.g. a cross-encoder or a third-party rerank API), injected via [Retriever.SetDocumentReranker]; on failure the adapter layer falls back to vector order.
type DocumentReranker interface {
	Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
}

// NopDocumentReranker is a placeholder implementation, handy for tests or for explicitly injecting "no rerank".
type NopDocumentReranker struct{}

// Rerank implements [DocumentReranker] as no-op.
func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) {
	return docs, nil
}

var tiktokenEncMu sync.Mutex
var tiktokenEncCache = map[string]*tiktoken.Tiktoken{}

func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) {
	m := strings.TrimSpace(model)
	if m == "" {
		m = "gpt-4"
	}
	tiktokenEncMu.Lock()
	defer tiktokenEncMu.Unlock()
	if enc, ok := tiktokenEncCache[m]; ok {
		return enc, nil
	}
	enc, err := tiktoken.EncodingForModel(m)
	if err != nil {
		enc, err = tiktoken.GetEncoding("cl100k_base")
		if err != nil {
			return nil, err
		}
	}
	tiktokenEncCache[m] = enc
	return enc, nil
}

func countDocTokens(text, model string) (int, error) {
	enc, err := encodingForTokenizerModel(model)
	if err != nil {
		return 0, err
	}
	toks := enc.Encode(text, nil, nil)
	return len(toks), nil
}
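A hedged example (not part of this commit) of the token-counting helper: an empty model name falls back to "gpt-4", and unknown models fall back to the cl100k_base encoding, so the call below should succeed even for an unrecognized model string.

n, err := countDocTokens("hello world", "gpt-4")
if err != nil {
	return err // tiktoken could not load any encoding
}
fmt.Printf("tokens: %d\n", n) // small positive count; the exact value depends on the encoding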
// normalizeContentFingerprintKey builds the dedup key: trim + whitespace folding (case is preserved, so code snippets differing only by case are not merged).
func normalizeContentFingerprintKey(s string) string {
	s = strings.TrimSpace(s)
	var b strings.Builder
	b.Grow(len(s))
	prevSpace := false
	for _, r := range s {
		if unicode.IsSpace(r) {
			if !prevSpace {
				b.WriteByte(' ')
				prevSpace = true
			}
			continue
		}
		prevSpace = false
		b.WriteRune(r)
	}
	return b.String()
}
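A small illustration (not part of this commit) of the fingerprint normalization: interior whitespace runs collapse to single spaces while case is preserved, so the first two inputs share a key but the third does not.

k1 := normalizeContentFingerprintKey("Foo  bar\n baz ") // "Foo bar baz"
k2 := normalizeContentFingerprintKey("Foo bar baz")     // "Foo bar baz"
k3 := normalizeContentFingerprintKey("foo bar baz")     // "foo bar baz" (case kept, so distinct)
_ = k1 == k2 // true
_ = k1 == k3 // false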
func contentNormKey(d *schema.Document) string {
	if d == nil {
		return ""
	}
	n := normalizeContentFingerprintKey(d.Content)
	if n == "" {
		return ""
	}
	sum := sha256.Sum256([]byte(n))
	return hex.EncodeToString(sum[:])
}

// dedupeByNormalizedContent dedupes by normalized content, keeping the first occurrence in vector-retrieval order (one document per distinct body).
func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document {
	if len(docs) < 2 {
		return docs
	}
	seen := make(map[string]struct{}, len(docs))
	out := make([]*schema.Document, 0, len(docs))
	for _, d := range docs {
		if d == nil {
			continue
		}
		k := contentNormKey(d)
		if k == "" {
			out = append(out, d)
			continue
		}
		if _, ok := seen[k]; ok {
			continue
		}
		seen[k] = struct{}{}
		out = append(out, d)
	}
	return out
}

// truncateDocumentsByBudget keeps whole documents in retrieval order and stops once the character or token budget (whichever is enabled) would be exceeded.
func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) {
	if len(docs) == 0 {
		return docs, nil
	}
	unlimitedChars := maxRunes <= 0
	unlimitedTok := maxTokens <= 0
	if unlimitedChars && unlimitedTok {
		return docs, nil
	}

	remRunes := maxRunes
	remTok := maxTokens
	out := make([]*schema.Document, 0, len(docs))

	for _, d := range docs {
		if d == nil || strings.TrimSpace(d.Content) == "" {
			continue
		}
		runes := utf8.RuneCountInString(d.Content)
		if !unlimitedChars && runes > remRunes {
			break
		}
		var tok int
		var err error
		if !unlimitedTok {
			tok, err = countDocTokens(d.Content, tokenModel)
			if err != nil {
				return nil, fmt.Errorf("token count: %w", err)
			}
			if tok > remTok {
				break
			}
		}
		out = append(out, d)
		if !unlimitedChars {
			remRunes -= runes
		}
		if !unlimitedTok {
			remTok -= tok
		}
	}
	return out, nil
}
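A hedged example (not part of this commit) of whole-document budget truncation: with a 10-rune budget and no token limit, documents are kept in retrieval order until one would overflow, and the loop stops there rather than splitting a document or skipping ahead to smaller ones.

docs := []*schema.Document{
	{Content: "12345"},  // 5 runes, fits (budget left: 5)
	{Content: "123456"}, // 6 runes, would overflow -> loop breaks
	{Content: "12"},     // never reached
}
kept, err := truncateDocumentsByBudget(docs, 10, 0, "") // tokenModel is unused when maxTokens <= 0
if err == nil {
	_ = len(kept) // 1
}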
// EffectivePrefetchTopK computes how many candidates vector retrieval should fetch (feeding coarse ranking / dedup / rerank).
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
	if topK < 1 {
		topK = 5
	}
	fetch := topK
	if po != nil && po.PrefetchTopK > fetch {
		fetch = po.PrefetchTopK
	}
	if fetch > postRetrieveMaxPrefetchCap {
		fetch = postRetrieveMaxPrefetchCap
	}
	return fetch
}

// ApplyPostRetrieve post-processes retrieval results: normalized-content dedup → budget truncation → final TopK. Reranking runs separately in [VectorEinoRetriever] so it can degrade gracefully on failure.
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
	if finalTopK < 1 {
		finalTopK = 5
	}
	if len(docs) == 0 {
		return docs, nil
	}

	maxChars := 0
	maxTok := 0
	if po != nil {
		maxChars = po.MaxContextChars
		maxTok = po.MaxContextTokens
	}

	out := dedupeByNormalizedContent(docs)

	var err error
	out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel)
	if err != nil {
		return nil, err
	}

	if len(out) > finalTopK {
		out = out[:finalTopK]
	}
	return out, nil
}
@@ -0,0 +1,62 @@
package knowledge

import (
	"testing"

	"cyberstrike-ai/internal/config"

	"github.com/cloudwego/eino/schema"
)

func doc(id, content string, score float64) *schema.Document {
	d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}}
	d.WithScore(score)
	return d
}

func TestDedupeByNormalizedContent(t *testing.T) {
	a := doc("1", "hello world", 0.9)
	b := doc("2", "hello world", 0.8)
	c := doc("3", "other", 0.7)
	out := dedupeByNormalizedContent([]*schema.Document{a, b, c})
	if len(out) != 2 {
		t.Fatalf("len=%d want 2", len(out))
	}
	if out[0].ID != "1" || out[1].ID != "3" {
		t.Fatalf("order/ids wrong: %#v", out)
	}
}

func TestEffectivePrefetchTopK(t *testing.T) {
	if g := EffectivePrefetchTopK(5, nil); g != 5 {
		t.Fatalf("got %d", g)
	}
	if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
		t.Fatalf("got %d", g)
	}
	if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap {
		t.Fatalf("cap: got %d", g)
	}
}

func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) {
	d1 := doc("1", "ab", 0.9)
	d2 := doc("2", "cd", 0.8)
	d3 := doc("3", "ef", 0.7)
	po := &config.PostRetrieveConfig{MaxContextChars: 3}
	out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5)
	if err != nil {
		t.Fatal(err)
	}
	if len(out) != 1 || out[0].ID != "1" {
		t.Fatalf("got %#v", out)
	}

	out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2)
	if err != nil {
		t.Fatal(err)
	}
	if len(out2) != 2 {
		t.Fatalf("topk: len=%d", len(out2))
	}
}
+174
-545
@@ -8,23 +8,34 @@ import (
	"math"
	"sort"
	"strings"
	"sync"

	"cyberstrike-ai/internal/config"

	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"go.uber.org/zap"
)

// Retriever is the retriever.
// Retriever: vectors stored in SQLite + Eino embeddings, **pure vector retrieval** (cosine similarity, TopK, threshold),
// with semantics matching the [retriever.Retriever] adapter layer [VectorEinoRetriever].
type Retriever struct {
	db       *sql.DB
	embedder *Embedder
	config   *RetrievalConfig
	logger   *zap.Logger

	rerankMu sync.RWMutex
	reranker DocumentReranker
}

// RetrievalConfig is the retrieval configuration.
type RetrievalConfig struct {
	TopK                int
	SimilarityThreshold float64
	HybridWeight        float64
	// SubIndexFilter, when non-empty, restricts retrieval to rows whose sub_indexes contains the tag (one of its comma-separated values); legacy rows with empty sub_indexes are still returned for compatibility.
	SubIndexFilter string
	PostRetrieve   config.PostRetrieveConfig
}

// NewRetriever creates a new retriever.
@@ -38,18 +49,41 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge
}

// UpdateConfig updates the retrieval configuration.
func (r *Retriever) UpdateConfig(config *RetrievalConfig) {
	if config != nil {
		r.config = config
		r.logger.Info("检索器配置已更新",
			zap.Int("top_k", config.TopK),
			zap.Float64("similarity_threshold", config.SimilarityThreshold),
			zap.Float64("hybrid_weight", config.HybridWeight),
		)
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
	if cfg != nil {
		r.config = cfg
		if r.logger != nil {
			r.logger.Info("检索器配置已更新",
				zap.Int("top_k", cfg.TopK),
				zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
				zap.String("sub_index_filter", cfg.SubIndexFilter),
				zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
				zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
				zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
			)
		}
	}
}

// cosineSimilarity computes cosine similarity.
// SetDocumentReranker injects the optional reranker (concurrency-safe); nil disables reranking.
func (r *Retriever) SetDocumentReranker(rr DocumentReranker) {
	if r == nil {
		return
	}
	r.rerankMu.Lock()
	defer r.rerankMu.Unlock()
	r.reranker = rr
}

func (r *Retriever) documentReranker() DocumentReranker {
	if r == nil {
		return nil
	}
	r.rerankMu.RLock()
	defer r.rerankMu.RUnlock()
	return r.reranker
}

func cosineSimilarity(a, b []float32) float64 {
	if len(a) != len(b) {
		return 0.0
@@ -69,608 +103,203 @@ func cosineSimilarity(a, b []float32) float64 {
	return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}

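A worked example (not part of this commit): cosine similarity is the dot product over the product of norms, so orthogonal vectors score 0 and parallel vectors score 1 regardless of magnitude, and the length-mismatch guard returns 0.

s1 := cosineSimilarity([]float32{1, 0}, []float32{0, 1}) // 0.0 (orthogonal)
s2 := cosineSimilarity([]float32{1, 2}, []float32{2, 4}) // 1.0 (parallel: 10 / (sqrt(5)*sqrt(20)))
s3 := cosineSimilarity([]float32{1, 0}, []float32{1})    // 0.0 (length mismatch guard)
_, _, _ = s1, s2, s3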
// bm25Score computes a BM25 score (improved version with caching).
// Note: lacking global document statistics, it uses a simplified IDF.
func (r *Retriever) bm25Score(query, text string) float64 {
	queryTerms := strings.Fields(strings.ToLower(query))
	if len(queryTerms) == 0 {
		return 0.0
// Search queries the knowledge base, routed uniformly through [VectorEinoRetriever] (the Eino retriever.Retriever boundary).
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
	if req == nil {
		return nil, fmt.Errorf("请求不能为空")
	}

	textLower := strings.ToLower(text)
	textTerms := strings.Fields(textLower)
	if len(textTerms) == 0 {
		return 0.0
	q := strings.TrimSpace(req.Query)
	if q == "" {
		return nil, fmt.Errorf("查询不能为空")
	}

	// BM25 parameters (standard values).
	k1 := 1.2             // term-frequency saturation (standard range 1.2-2.0)
	b := 0.75             // length normalization (standard value)
	avgDocLength := 150.0 // estimated average document length (typical knowledge-chunk size)
	docLength := float64(len(textTerms))

	// Build the term-frequency map.
	textTermFreq := make(map[string]int, len(textTerms))
	for _, term := range textTerms {
		textTermFreq[term]++
	opts := r.einoRetrieverOptions(req)
	docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
	if err != nil {
		return nil, err
	}

	score := 0.0
	matchedQueryTerms := 0

	for _, term := range queryTerms {
		termFreq, exists := textTermFreq[term]
		if !exists || termFreq == 0 {
			continue
		}
		matchedQueryTerms++

		// BM25 TF formula.
		tf := float64(termFreq)
		lengthNorm := 1 - b + b*(docLength/avgDocLength)
		tfScore := tf / (tf + k1*lengthNorm)

		// Improved IDF estimate, based on term length and occurrence count:
		// short terms (2-3 chars) are usually more important; long terms get slightly lower IDF.
		idfWeight := 1.0
		termLen := len(term)
		if termLen <= 2 {
			// Very short terms (e.g. go, js) get extra weight.
			idfWeight = 1.2 + math.Log(1.0+float64(termFreq)/20.0)
		} else if termLen <= 4 {
			// Short terms (4 chars): standard weight.
			idfWeight = 1.0 + math.Log(1.0+float64(termFreq)/15.0)
		} else {
			// Long terms: slightly reduced weight.
			idfWeight = 0.9 + math.Log(1.0+float64(termFreq)/10.0)
		}

		score += tfScore * idfWeight
	}

	// Normalize, factoring in the fraction of matched query terms.
	if len(queryTerms) > 0 {
		// Use the match ratio as an extra factor.
		matchRatio := float64(matchedQueryTerms) / float64(len(queryTerms))
		score = (score / float64(len(queryTerms))) * (1 + matchRatio) / 2
	}

	return math.Min(score, 1.0)
	return documentsToRetrievalResults(docs)
}

// Search queries the knowledge base.
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option {
	var opts []retriever.Option
	if req.TopK > 0 {
		opts = append(opts, retriever.WithTopK(req.TopK))
	}
	dsl := map[string]any{}
	if strings.TrimSpace(req.RiskType) != "" {
		dsl[DSLRiskType] = strings.TrimSpace(req.RiskType)
	}
	if req.Threshold > 0 {
		dsl[DSLSimilarityThreshold] = req.Threshold
	}
	if strings.TrimSpace(req.SubIndexFilter) != "" {
		dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter)
	}
	if len(dsl) > 0 {
		opts = append(opts, retriever.WithDSLInfo(dsl))
	}
	return opts
}

// EinoRetrieve returns [schema.Document] values directly, for use from Eino Graphs / Chains.
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
}

func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
	q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title
	FROM knowledge_embeddings e
	JOIN knowledge_base_items i ON e.item_id = i.id
	WHERE 1=1`
	var args []interface{}
	if strings.TrimSpace(riskType) != "" {
		q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE`
		args = append(args, riskType)
	}
	if tag := strings.TrimSpace(subIndexFilter); tag != "" {
		tag = strings.ToLower(strings.ReplaceAll(tag, " ", ""))
		q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)`
		args = append(args, tag)
	}
	return q, args
}

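An illustrative trace (not part of this commit) of the SQL builder above for riskType = "sqli" and subIndexFilter = "web apps": the category filter is exact but case-insensitive, while the sub-index filter lowercases the tag, strips spaces, and matches it inside the comma-joined sub_indexes column, still admitting legacy rows whose sub_indexes is empty.

q, args := r.knowledgeEmbeddingSelectSQL("sqli", "web apps")
// q ends with:
//   ... WHERE 1=1 AND TRIM(i.category) = TRIM(?) COLLATE NOCASE
//   AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR
//        INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)
// args == []interface{}{"sqli", "webapps"}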
// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。
|
||||
func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
|
||||
if req.Query == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
|
||||
topK := req.TopK
|
||||
if topK <= 0 {
|
||||
if topK <= 0 && r.config != nil {
|
||||
topK = r.config.TopK
|
||||
}
|
||||
if topK == 0 {
|
||||
if topK <= 0 {
|
||||
topK = 5
|
||||
}
|
||||
|
||||
threshold := req.Threshold
|
||||
if threshold <= 0 {
|
||||
if threshold <= 0 && r.config != nil {
|
||||
threshold = r.config.SimilarityThreshold
|
||||
}
|
||||
if threshold == 0 {
|
||||
if threshold <= 0 {
|
||||
threshold = 0.7
|
||||
}
|
||||
|
||||
// 向量化查询(如果提供了risk_type,也包含在查询文本中,以便更好地匹配)
|
||||
queryText := req.Query
|
||||
if req.RiskType != "" {
|
||||
// 将risk_type信息包含到查询中,格式与索引时保持一致
|
||||
queryText = fmt.Sprintf("[风险类型: %s] %s", req.RiskType, req.Query)
|
||||
subIdxFilter := strings.TrimSpace(req.SubIndexFilter)
|
||||
if subIdxFilter == "" && r.config != nil {
|
||||
subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter)
|
||||
}
|
||||
|
||||
queryText := FormatQueryEmbeddingText(req.RiskType, req.Query)
|
||||
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量化查询失败: %w", err)
|
||||
}
|
||||
|
||||
// 查询所有向量(或按风险类型过滤)
|
||||
// 使用精确匹配(=)以提高性能和准确性
|
||||
// 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称
|
||||
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
|
||||
var rows *sql.Rows
|
||||
if req.RiskType != "" {
|
||||
// 使用精确匹配(=),性能更好且更准确
|
||||
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
|
||||
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
|
||||
// 建议用户先调用相应的内置工具获取准确的category名称
|
||||
rows, err = r.db.Query(`
|
||||
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
JOIN knowledge_base_items i ON e.item_id = i.id
|
||||
WHERE TRIM(i.category) = TRIM(?) COLLATE NOCASE
|
||||
`, req.RiskType)
|
||||
} else {
|
||||
rows, err = r.db.Query(`
|
||||
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
JOIN knowledge_base_items i ON e.item_id = i.id
|
||||
`)
|
||||
queryDim := len(queryEmbedding)
|
||||
expectedModel := ""
|
||||
if r.embedder != nil {
|
||||
expectedModel = r.embedder.EmbeddingModelName()
|
||||
}
|
||||
|
||||
sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter)
|
||||
rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询向量失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 计算相似度
|
||||
type candidate struct {
|
||||
chunk *KnowledgeChunk
|
||||
item *KnowledgeItem
|
||||
similarity float64
|
||||
bm25Score float64
|
||||
hasStrongKeywordMatch bool
|
||||
hybridScore float64 // 混合分数,用于最终排序
|
||||
chunk *KnowledgeChunk
|
||||
item *KnowledgeItem
|
||||
similarity float64
|
||||
}
|
||||
|
||||
candidates := make([]candidate, 0)
|
||||
|
||||
rowNum := 0
|
||||
for rows.Next() {
|
||||
var chunkID, itemID, chunkText, embeddingJSON, category, title string
|
||||
var chunkIndex int
|
||||
rowNum++
|
||||
if rowNum%48 == 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &category, &title); err != nil {
|
||||
var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string
|
||||
var chunkIndex, rowDim int
|
||||
|
||||
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil {
|
||||
r.logger.Warn("扫描向量失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析向量
|
||||
var embedding []float32
|
||||
if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil {
|
||||
r.logger.Warn("解析向量失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算余弦相似度
|
||||
similarity := cosineSimilarity(queryEmbedding, embedding)
|
||||
|
||||
// 计算BM25分数(考虑chunk文本、category和title)
|
||||
// category和title是结构化字段,完全匹配时应该被优先考虑
|
||||
chunkBM25 := r.bm25Score(req.Query, chunkText)
|
||||
categoryBM25 := r.bm25Score(req.Query, category)
|
||||
titleBM25 := r.bm25Score(req.Query, title)
|
||||
|
||||
// 检查category或title是否有显著匹配(这对于结构化字段很重要)
|
||||
hasStrongKeywordMatch := categoryBM25 > 0.3 || titleBM25 > 0.3
|
||||
|
||||
// 综合BM25分数(用于后续排序)
|
||||
bm25Score := math.Max(math.Max(chunkBM25, categoryBM25), titleBM25)
|
||||
|
||||
// 收集所有候选(先不严格过滤,以便后续智能处理跨语言情况)
|
||||
// 只过滤掉相似度极低的结果(< 0.1),避免噪音
|
||||
if similarity < 0.1 {
|
||||
if rowDim > 0 && len(embedding) != rowDim {
|
||||
r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding)))
|
||||
continue
|
||||
}
|
||||
if queryDim > 0 && len(embedding) != queryDim {
|
||||
r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding)))
|
||||
continue
|
||||
}
|
||||
if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel {
|
||||
r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel))
|
||||
continue
|
||||
}
|
||||
|
||||
chunk := &KnowledgeChunk{
|
||||
ID: chunkID,
|
||||
ItemID: itemID,
|
||||
ChunkIndex: chunkIndex,
|
||||
ChunkText: chunkText,
|
||||
Embedding: embedding,
|
||||
}
|
||||
|
||||
item := &KnowledgeItem{
|
||||
ID: itemID,
|
||||
Category: category,
|
||||
Title: title,
|
||||
}
|
||||
|
||||
similarity := cosineSimilarity(queryEmbedding, embedding)
|
||||
candidates = append(candidates, candidate{
|
||||
chunk: chunk,
|
||||
item: item,
|
||||
similarity: similarity,
|
||||
bm25Score: bm25Score,
|
||||
hasStrongKeywordMatch: hasStrongKeywordMatch,
|
||||
chunk: &KnowledgeChunk{
|
||||
ID: chunkID,
|
||||
ItemID: itemID,
|
||||
ChunkIndex: chunkIndex,
|
||||
ChunkText: chunkText,
|
||||
Embedding: embedding,
|
||||
},
|
||||
item: &KnowledgeItem{
|
||||
ID: itemID,
|
||||
Category: category,
|
||||
Title: title,
|
||||
},
|
||||
similarity: similarity,
|
||||
})
|
||||
}
|
||||
|
||||
// 先按相似度排序(使用更高效的排序)
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].similarity > candidates[j].similarity
|
||||
})
|
||||
|
||||
// 智能过滤策略:优先保留关键词匹配的结果,对跨语言查询使用更宽松的阈值
|
||||
filteredCandidates := make([]candidate, 0)
|
||||
|
||||
// 检查是否有任何关键词匹配(用于判断是否是跨语言查询)
|
||||
hasAnyKeywordMatch := false
|
||||
for _, cand := range candidates {
|
||||
if cand.hasStrongKeywordMatch {
|
||||
hasAnyKeywordMatch = true
|
||||
break
|
||||
filtered := make([]candidate, 0, len(candidates))
|
||||
for _, c := range candidates {
|
||||
if c.similarity >= threshold {
|
||||
filtered = append(filtered, c)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查最高相似度,用于判断是否确实有相关内容
|
||||
maxSimilarity := 0.0
|
||||
if len(candidates) > 0 {
|
||||
maxSimilarity = candidates[0].similarity
|
||||
if len(filtered) > topK {
|
||||
filtered = filtered[:topK]
|
||||
}
|
||||
|
||||
// 应用智能过滤
|
||||
// 如果用户设置了高阈值(>=0.8),更严格地遵守阈值,减少自动放宽
|
||||
strictMode := threshold >= 0.8
|
||||
|
||||
// 根据是否有关键词匹配,采用不同的阈值策略
|
||||
// 严格模式下,禁用跨语言放宽策略,严格遵守用户设置的阈值
|
||||
effectiveThreshold := threshold
|
||||
if !strictMode && !hasAnyKeywordMatch {
|
||||
// 非严格模式下,没有关键词匹配,可能是跨语言查询,适度放宽阈值
|
||||
// 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性
|
||||
// 跨语言阈值设为0.6,确保返回的结果至少有一定相关性
|
||||
effectiveThreshold = math.Max(threshold*0.85, 0.6)
|
||||
r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值",
|
||||
zap.Float64("originalThreshold", threshold),
|
||||
zap.Float64("effectiveThreshold", effectiveThreshold),
|
||||
)
|
||||
} else if strictMode {
|
||||
// 严格模式下,即使没有关键词匹配,也严格遵守阈值
|
||||
r.logger.Debug("严格模式:严格遵守用户设置的阈值",
|
||||
zap.Float64("threshold", threshold),
|
||||
zap.Bool("hasKeywordMatch", hasAnyKeywordMatch),
|
||||
)
|
||||
}
|
||||
for _, cand := range candidates {
|
||||
if cand.similarity >= effectiveThreshold {
|
||||
// 达到阈值,直接通过
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
} else if !strictMode && cand.hasStrongKeywordMatch {
|
||||
// 非严格模式下,有关键词匹配但相似度略低于阈值,适当放宽
|
||||
// 严格模式下,即使有关键词匹配,也严格遵守阈值
|
||||
relaxedThreshold := math.Max(effectiveThreshold*0.85, 0.55)
|
||||
if cand.similarity >= relaxedThreshold {
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
}
|
||||
}
|
||||
// 如果既没有关键词匹配,相似度又低于阈值,则过滤掉
|
||||
}
|
||||
|
||||
// 智能兜底策略:只有在最高相似度达到合理水平时,才考虑返回结果
|
||||
// 如果最高相似度都很低(<0.55),说明确实没有相关内容,应该返回空
|
||||
// 严格模式下(阈值>=0.8),禁用兜底策略,严格遵守用户设置的阈值
|
||||
if len(filteredCandidates) == 0 && len(candidates) > 0 && !strictMode {
|
||||
// 即使没有通过阈值过滤,如果最高相似度还可以(>=0.55),可以考虑返回Top-K
|
||||
// 但这是最后的兜底,只在确实有一定相关性时才使用
|
||||
// 严格模式下不使用兜底策略
|
||||
minAcceptableSimilarity := 0.55
|
||||
if maxSimilarity >= minAcceptableSimilarity {
|
||||
r.logger.Debug("过滤后无结果,但最高相似度可接受,返回Top-K结果",
|
||||
zap.Int("totalCandidates", len(candidates)),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
zap.Float64("effectiveThreshold", effectiveThreshold),
|
||||
)
|
||||
maxResults := topK
|
||||
if len(candidates) < maxResults {
|
||||
maxResults = len(candidates)
|
||||
}
|
||||
// 只返回相似度 >= 0.55 的结果
|
||||
for _, cand := range candidates {
|
||||
if cand.similarity >= minAcceptableSimilarity && len(filteredCandidates) < maxResults {
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
r.logger.Debug("过滤后无结果,且最高相似度过低,返回空结果",
|
||||
zap.Int("totalCandidates", len(candidates)),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
zap.Float64("minAcceptableSimilarity", minAcceptableSimilarity),
|
||||
)
|
||||
}
|
||||
} else if len(filteredCandidates) == 0 && strictMode {
|
||||
// 严格模式下,如果过滤后无结果,直接返回空,不使用兜底策略
|
||||
r.logger.Debug("严格模式:过滤后无结果,严格遵守阈值,返回空结果",
|
||||
zap.Float64("threshold", threshold),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
)
|
||||
}
|
||||
|
||||
// 统一在最终返回前严格限制 Top-K 数量
|
||||
if len(filteredCandidates) > topK {
|
||||
// 如果过滤后结果太多,只取Top-K
|
||||
filteredCandidates = filteredCandidates[:topK]
|
||||
}
|
||||
|
||||
candidates = filteredCandidates
|
||||
|
||||
// 混合排序(向量相似度 + BM25)
|
||||
// 注意:hybridWeight可以是0.0(纯关键词检索),所以不设置默认值
|
||||
// 如果配置文件中未设置,应该在配置加载时使用默认值
|
||||
hybridWeight := r.config.HybridWeight
|
||||
// 如果未设置,使用默认值0.7(偏重向量检索)
|
||||
if hybridWeight < 0 || hybridWeight > 1 {
|
||||
r.logger.Warn("混合权重超出范围,使用默认值0.7",
|
||||
zap.Float64("provided", hybridWeight))
|
||||
hybridWeight = 0.7
|
||||
}
|
||||
|
||||
// 先计算混合分数并存储在candidate中,用于排序
|
||||
for i := range candidates {
|
||||
normalizedBM25 := math.Min(candidates[i].bm25Score, 1.0)
|
||||
candidates[i].hybridScore = hybridWeight*candidates[i].similarity + (1-hybridWeight)*normalizedBM25
|
||||
|
||||
// 调试日志:记录前几个候选的分数计算(仅在debug级别)
|
||||
if i < 3 {
|
||||
r.logger.Debug("混合分数计算",
|
||||
zap.Int("index", i),
|
||||
zap.Float64("similarity", candidates[i].similarity),
|
||||
zap.Float64("bm25Score", candidates[i].bm25Score),
|
||||
zap.Float64("normalizedBM25", normalizedBM25),
|
||||
zap.Float64("hybridWeight", hybridWeight),
|
||||
zap.Float64("hybridScore", candidates[i].hybridScore))
|
||||
}
|
||||
}
|
||||
|
||||
// 根据混合分数重新排序(这才是真正的混合检索)
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].hybridScore > candidates[j].hybridScore
|
||||
})
|
||||
|
||||
// 转换为结果
|
||||
results := make([]*RetrievalResult, len(candidates))
|
||||
for i, cand := range candidates {
|
||||
results := make([]*RetrievalResult, len(filtered))
|
||||
for i, c := range filtered {
|
||||
results[i] = &RetrievalResult{
|
||||
Chunk: cand.chunk,
|
||||
Item: cand.item,
|
||||
Similarity: cand.similarity,
|
||||
Score: cand.hybridScore,
|
||||
Chunk: c.chunk,
|
||||
Item: c.item,
|
||||
Similarity: c.similarity,
|
||||
Score: c.similarity,
|
||||
}
|
||||
}
|
||||
|
||||
// 上下文扩展:为每个匹配的chunk添加同一文档中的相关chunk
|
||||
// 这可以防止文本描述和payload被分开切分时,只返回描述而丢失payload的问题
|
||||
results = r.expandContext(ctx, results)
|
||||
|
||||
return results, nil
|
||||
}

// expandContext expands the context of the retrieval results.
// For every matched chunk it also includes related chunks from the same document
// (in particular chunks containing code blocks or payloads).
func (r *Retriever) expandContext(ctx context.Context, results []*RetrievalResult) []*RetrievalResult {
	if len(results) == 0 {
		return results
	}

	// Collect the IDs of all matched documents
	itemIDs := make(map[string]bool)
	for _, result := range results {
		itemIDs[result.Item.ID] = true
	}

	// Load all chunks for each document
	itemChunksMap := make(map[string][]*KnowledgeChunk)
	for itemID := range itemIDs {
		chunks, err := r.loadAllChunksForItem(itemID)
		if err != nil {
			r.logger.Warn("加载文档chunk失败", zap.String("itemId", itemID), zap.Error(err))
			continue
		}
		itemChunksMap[itemID] = chunks
	}

	// Group the results by document so each document is expanded only once
	resultsByItem := make(map[string][]*RetrievalResult)
	for _, result := range results {
		itemID := result.Item.ID
		resultsByItem[itemID] = append(resultsByItem[itemID], result)
	}

	// Expand the results of each document
	expandedResults := make([]*RetrievalResult, 0, len(results))
	processedChunkIDs := make(map[string]bool) // avoid adding duplicates

	for itemID, itemResults := range resultsByItem {
		// Fetch all chunks of this document
		allChunks, exists := itemChunksMap[itemID]
		if !exists {
			// If the chunks could not be loaded, keep the original results as-is
			for _, result := range itemResults {
				if !processedChunkIDs[result.Chunk.ID] {
					expandedResults = append(expandedResults, result)
					processedChunkIDs[result.Chunk.ID] = true
				}
			}
			continue
		}

		// Add the original results
		for _, result := range itemResults {
			if !processedChunkIDs[result.Chunk.ID] {
				expandedResults = append(expandedResults, result)
				processedChunkIDs[result.Chunk.ID] = true
			}
		}

		// Collect the neighboring chunks to expand for this document's matched chunks.
		// Strategy: expand only from the top 3 matched chunks by hybrid score to avoid over-expansion.
		// Sort by hybrid score first and expand only the top 3 (hybrid score, not similarity).
		sortedItemResults := make([]*RetrievalResult, len(itemResults))
		copy(sortedItemResults, itemResults)
		sort.Slice(sortedItemResults, func(i, j int) bool {
			return sortedItemResults[i].Score > sortedItemResults[j].Score
		})

		// Expand only the top 3 (or all of them, if there are fewer than 3)
		maxExpandFrom := 3
		if len(sortedItemResults) < maxExpandFrom {
			maxExpandFrom = len(sortedItemResults)
		}

		// Deduplicate with a map so the same chunk is not added twice
		relatedChunksMap := make(map[string]*KnowledgeChunk)

		for i := 0; i < maxExpandFrom; i++ {
			result := sortedItemResults[i]
			// Find related chunks (two on each side, excluding already-processed chunks)
			relatedChunks := r.findRelatedChunks(result.Chunk, allChunks, processedChunkIDs)
			for _, relatedChunk := range relatedChunks {
				// Deduplicate using the chunk ID as the key
				if !processedChunkIDs[relatedChunk.ID] {
					relatedChunksMap[relatedChunk.ID] = relatedChunk
				}
			}
		}

		// Cap the number of expanded chunks per document.
		// Strategy: at most 8 expanded chunks, no matter how many chunks matched.
		// This avoids expanding too much when the matches are scattered across the document.
		maxExpandPerItem := 8

		// Convert the related chunks to a slice and sort by index,
		// preferring those closest to a matched chunk
		relatedChunksList := make([]*KnowledgeChunk, 0, len(relatedChunksMap))
		for _, chunk := range relatedChunksMap {
			relatedChunksList = append(relatedChunksList, chunk)
		}

		// Sort by each related chunk's distance to the nearest matched chunk
		sort.Slice(relatedChunksList, func(i, j int) bool {
			// Distance to the nearest matched chunk
			minDistI := len(allChunks)
			minDistJ := len(allChunks)
			for _, result := range itemResults {
				distI := abs(relatedChunksList[i].ChunkIndex - result.Chunk.ChunkIndex)
				distJ := abs(relatedChunksList[j].ChunkIndex - result.Chunk.ChunkIndex)
				if distI < minDistI {
					minDistI = distI
				}
				if distJ < minDistJ {
					minDistJ = distJ
				}
			}
			return minDistI < minDistJ
		})

		// Enforce the cap
		if len(relatedChunksList) > maxExpandPerItem {
			relatedChunksList = relatedChunksList[:maxExpandPerItem]
		}

		// Add the deduplicated related chunks.
		// Use this document's highest-scoring result as the reference.
		maxScore := 0.0
		maxSimilarity := 0.0
		for _, result := range itemResults {
			if result.Score > maxScore {
				maxScore = result.Score
			}
			if result.Similarity > maxSimilarity {
				maxSimilarity = result.Similarity
			}
		}

		// Compute the hybrid score for the expanded chunks (same hybrid weight as retrieval)
		hybridWeight := r.config.HybridWeight
		expandedSimilarity := maxSimilarity * 0.8 // related chunks get a slightly lower similarity
		// Expanded chunks get a BM25 score of 0 (they are context expansion, not direct matches)
		expandedBM25 := 0.0
		expandedScore := hybridWeight*expandedSimilarity + (1-hybridWeight)*expandedBM25

		for _, relatedChunk := range relatedChunksList {
			expandedResult := &RetrievalResult{
				Chunk:      relatedChunk,
				Item:       itemResults[0].Item, // use the Item of the first result
				Similarity: expandedSimilarity,
				Score:      expandedScore, // properly weighted hybrid score
			}
			expandedResults = append(expandedResults, expandedResult)
			processedChunkIDs[relatedChunk.ID] = true
		}
	}

	return expandedResults
}
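
The selection logic above is worth restating on bare indices: expansion seeds are the top 3 matches per document, each contributes a ±2 window, duplicates are dropped, and at most 8 neighbors survive, ordered by distance to the nearest seed. A hypothetical, self-contained sketch of just that budget (not the production path; it ignores chunk IDs and scores):

package main

import (
	"fmt"
	"sort"
)

// nearest returns the distance from n to the closest seed index.
func nearest(n int, seeds []int) int {
	best := 1 << 30
	for _, s := range seeds {
		d := n - s
		if d < 0 {
			d = -d
		}
		if d < best {
			best = d
		}
	}
	return best
}

// expandBudget mirrors expandContext's selection on bare chunk indices:
// up to 3 seeds, a ±2 window around each, dedup, cap at the 8 nearest neighbors.
func expandBudget(seeds []int, total int) []int {
	if len(seeds) > 3 {
		seeds = seeds[:3]
	}
	seen := map[int]bool{}
	for _, s := range seeds {
		seen[s] = true // matched chunks are never re-added as neighbors
	}
	var neighbors []int
	for _, s := range seeds {
		for d := -2; d <= 2; d++ {
			n := s + d
			if d == 0 || n < 0 || n >= total || seen[n] {
				continue
			}
			seen[n] = true
			neighbors = append(neighbors, n)
		}
	}
	sort.Slice(neighbors, func(i, j int) bool {
		return nearest(neighbors[i], seeds) < nearest(neighbors[j], seeds)
	})
	if len(neighbors) > 8 {
		neighbors = neighbors[:8]
	}
	return neighbors
}

func main() {
	// Three scattered matches in a 20-chunk document yield 12 window candidates, capped to 8.
	fmt.Println(expandBudget([]int{2, 9, 15}, 20))
}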

// loadAllChunksForItem loads all chunks of a document.
func (r *Retriever) loadAllChunksForItem(itemID string) ([]*KnowledgeChunk, error) {
	rows, err := r.db.Query(`
		SELECT id, item_id, chunk_index, chunk_text, embedding
		FROM knowledge_embeddings
		WHERE item_id = ?
		ORDER BY chunk_index
	`, itemID)
	if err != nil {
		return nil, fmt.Errorf("查询chunk失败: %w", err)
	}
	defer rows.Close()

	var chunks []*KnowledgeChunk
	for rows.Next() {
		var chunkID, itemID, chunkText, embeddingJSON string
		var chunkIndex int

		if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON); err != nil {
			r.logger.Warn("扫描chunk失败", zap.Error(err))
			continue
		}

		// Parse the embedding (optional; not strictly needed here)
		var embedding []float32
		if embeddingJSON != "" {
			json.Unmarshal([]byte(embeddingJSON), &embedding)
		}

		chunk := &KnowledgeChunk{
			ID:         chunkID,
			ItemID:     itemID,
			ChunkIndex: chunkIndex,
			ChunkText:  chunkText,
			Embedding:  embedding,
		}
		chunks = append(chunks, chunk)
	}

	return chunks, nil
}
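
The query assumes a knowledge_embeddings table exposing at least the five selected columns, with the embedding stored as a JSON-encoded array. For local experimentation, a hypothetical DDL that satisfies the scan above (column types are inferred from the Scan call, not taken from the project's real migrations):

// Assumed shape only; the project's actual migration may differ.
const createKnowledgeEmbeddings = `
CREATE TABLE IF NOT EXISTS knowledge_embeddings (
	id          TEXT PRIMARY KEY,
	item_id     TEXT NOT NULL,
	chunk_index INTEGER NOT NULL,
	chunk_text  TEXT NOT NULL,
	embedding   TEXT NOT NULL DEFAULT '' -- JSON-encoded []float32
);
CREATE INDEX IF NOT EXISTS idx_ke_item ON knowledge_embeddings(item_id, chunk_index);`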

// findRelatedChunks finds the chunks related to the given chunk.
// Strategy: return only the two neighbors on each side (at most 4 in total),
// excluding already-processed chunks to avoid duplicates.
func (r *Retriever) findRelatedChunks(targetChunk *KnowledgeChunk, allChunks []*KnowledgeChunk, processedChunkIDs map[string]bool) []*KnowledgeChunk {
	related := make([]*KnowledgeChunk, 0)

	// Look for up to two neighboring chunks on each side
	for _, chunk := range allChunks {
		if chunk.ID == targetChunk.ID {
			continue
		}

		// Skip chunks that were already processed (they may already be in the results)
		if processedChunkIDs[chunk.ID] {
			continue
		}

		// Neighbor check: index difference within ±2 and non-zero
		indexDiff := chunk.ChunkIndex - targetChunk.ChunkIndex
		if indexDiff >= -2 && indexDiff <= 2 && indexDiff != 0 {
			related = append(related, chunk)
		}
	}

	// Sort by index distance, closest first
	sort.Slice(related, func(i, j int) bool {
		diffI := abs(related[i].ChunkIndex - targetChunk.ChunkIndex)
		diffJ := abs(related[j].ChunkIndex - targetChunk.ChunkIndex)
		return diffI < diffJ
	})

	// Return at most 4 (two on each side)
	if len(related) > 4 {
		related = related[:4]
	}

	return related
}
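
Concretely, for a match at ChunkIndex 5 the predicate admits indices 3, 4, 6 and 7 and never the match itself. A tiny check of just that window test:

package main

import "fmt"

func main() {
	target := 5 // ChunkIndex of the matched chunk
	for idx := 0; idx < 10; idx++ {
		diff := idx - target
		if diff >= -2 && diff <= 2 && diff != 0 {
			fmt.Println(idx) // prints 3, 4, 6, 7
		}
	}
}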

// abs returns the absolute value of an integer.
func abs(x int) int {
	if x < 0 {
		return -x
	}
	return x
}

// AsEinoRetriever exposes the pure vector retrieval as an Eino [retriever.Retriever].
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
	return NewVectorEinoRetriever(r)
}
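
A consumption sketch, assuming the cloudwego/eino retriever contract of Retrieve(ctx, query, opts...) returning []*schema.Document; the helper name and query string are illustrative:

// searchViaEino is a hypothetical caller (imports of context, fmt and the
// eino schema package assumed); only AsEinoRetriever comes from this file.
func searchViaEino(ctx context.Context, r *Retriever) error {
	docs, err := r.AsEinoRetriever().Retrieve(ctx, "SSRF 检测方法")
	if err != nil {
		return err
	}
	for _, doc := range docs {
		fmt.Println(doc.ID, doc.Content) // schema.Document fields per eino
	}
	return nil
}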

@@ -0,0 +1,51 @@
package knowledge

import (
	"database/sql"
	"fmt"
)

// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata.
func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error {
	if db == nil {
		return fmt.Errorf("db is nil")
	}
	var n int
	if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
		return err
	}
	if n == 0 {
		return nil
	}
	if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes",
		`ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil {
		return err
	}
	if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model",
		`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil {
		return err
	}
	if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim",
		`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil {
		return err
	}
	return nil
}

func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error {
	var colCount int
	q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
	if err := db.QueryRow(q, column).Scan(&colCount); err != nil {
		return err
	}
	if colCount > 0 {
		return nil
	}
	_, err := db.Exec(alterSQL)
	return err
}

// ensureKnowledgeEmbeddingsSubIndexesColumn is kept for backward compatibility; use [EnsureKnowledgeEmbeddingsSchema] instead.
func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error {
	return EnsureKnowledgeEmbeddingsSchema(db)
}
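
Because both existence checks go through sqlite_master and pragma_table_info, the migration is idempotent and safe to call on every start. A startup sketch under that assumption (the SQLite driver and the knowledge import path are guesses, not taken from the repository):

package main

import (
	"database/sql"
	"log"

	_ "github.com/mattn/go-sqlite3" // assumed driver; any database/sql SQLite driver works

	"cyberstrike-ai/internal/knowledge" // assumed import path, per the repo layout
)

func main() {
	db, err := sql.Open("sqlite3", "knowledge.db")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// No-op on a fresh database (table absent); otherwise adds each missing column at most once.
	if err := knowledge.EnsureKnowledgeEmbeddingsSchema(db); err != nil {
		log.Fatal(err)
	}
}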

@@ -81,8 +81,8 @@ func RegisterKnowledgeTool(
	// Register the second tool: knowledge-base search (original behavior preserved)
	searchTool := mcp.Tool{
		Name: builtin.ToolSearchKnowledgeBase,
		Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
		ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
		Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
		ShortDescription: "搜索知识库中的安全知识(向量语义检索)",
		InputSchema: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
@@ -123,7 +123,7 @@ func RegisterKnowledgeTool(
			zap.String("riskType", riskType),
		)

		// Run the search
		// All retrieval now goes through Retriever.Search → VectorEinoRetriever (Eino retriever semantics).
		searchReq := &SearchRequest{
			Query:    query,
			RiskType: riskType,
@@ -158,17 +158,16 @@ func RegisterKnowledgeTool(
		// Format the results
		var resultText strings.Builder

		// Sort by hybrid score first so the document order follows it (the core of hybrid retrieval)
		// Sort by cosine similarity (Score), descending
		sort.Slice(results, func(i, j int) bool {
			return results[i].Score > results[j].Score
		})

		// Group the results by document for better context presentation.
		// An ordered slice preserves the document order (by highest score).
		type itemGroup struct {
			itemID   string
			results  []*RetrievalResult
			maxScore float64 // the document's highest hybrid score
			maxScore float64 // the document's highest chunk similarity
		}
		itemGroups := make([]*itemGroup, 0)
		itemMap := make(map[string]*itemGroup)
@@ -191,7 +190,7 @@ func RegisterKnowledgeTool(
			}
		}

		// Sort the document groups by their highest hybrid score
		// Sort by the highest similarity within each document
		sort.Slice(itemGroups, func(i, j int) bool {
			return itemGroups[i].maxScore > itemGroups[j].maxScore
		})
@@ -199,12 +198,11 @@ func RegisterKnowledgeTool(
		// Collect the retrieved knowledge item IDs (for logging)
		retrievedItemIDs := make([]string, 0, len(itemGroups))

		resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results)))
		resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results)))

		resultIndex := 1
		for _, group := range itemGroups {
			itemResults := group.results
			// Take the highest-scoring result as the main one (by Score, not Similarity)
			mainResult := itemResults[0]
			maxScore := mainResult.Score
			for _, result := range itemResults {
@@ -219,9 +217,8 @@ func RegisterKnowledgeTool(
			return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
		})

		// Show the main result (highest hybrid score; report both similarity and hybrid score)
		resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
			resultIndex, mainResult.Similarity*100, mainResult.Score*100))
		resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n",
			resultIndex, mainResult.Similarity*100))
		resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))

		// Show all chunks in logical order (the main result plus the expanded chunks)

@@ -80,7 +80,7 @@ type RetrievalResult struct {
	Chunk      *KnowledgeChunk `json:"chunk"`
	Item       *KnowledgeItem  `json:"item"`
	Similarity float64         `json:"similarity"` // similarity score
	Score      float64         `json:"score"`      // combined score (hybrid retrieval)
	Score      float64         `json:"score"`      // same as Similarity: cosine similarity
}

// RetrievalLog is a retrieval log entry
@@ -115,8 +115,9 @@ type CategoryWithItems struct {

// SearchRequest is a search request
type SearchRequest struct {
	Query     string  `json:"query"`
	RiskType  string  `json:"riskType,omitempty"`  // optional: restrict to a risk type
	TopK      int     `json:"topK,omitempty"`      // number of top-K results, default 5
	Threshold float64 `json:"threshold,omitempty"` // similarity threshold, default 0.7
	Query          string  `json:"query"`
	RiskType       string  `json:"riskType,omitempty"`       // optional: restrict to a risk type
	SubIndexFilter string  `json:"subIndexFilter,omitempty"` // optional: keep only rows whose sub_indexes contain this tag (plus untagged legacy rows)
	TopK           int     `json:"topK,omitempty"`           // number of top-K results, default 5
	Threshold      float64 `json:"threshold,omitempty"`      // similarity threshold, default 0.7
}
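
For callers, the only addition is the optional SubIndexFilter; zero values keep the documented defaults (TopK 5, Threshold 0.7). A hypothetical request showing the new field:

req := &SearchRequest{
	Query:          "JWT none algorithm bypass",
	RiskType:       "authentication", // illustrative value
	SubIndexFilter: "owasp",          // keeps rows tagged "owasp" plus untagged legacy rows
	// TopK and Threshold left at zero to use the defaults (5 and 0.7)
}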