diff --git a/internal/config/config.go b/internal/config/config.go index bc024f85..83b0997f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -582,9 +582,18 @@ func Default() *Config { }, Retrieval: RetrievalConfig{ TopK: 5, - SimilarityThreshold: 0.7, + SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检 HybridWeight: 0.7, }, + Indexing: IndexingConfig{ + ChunkSize: 768, // 增加到 768,更好的上下文保持 + ChunkOverlap: 50, + MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额 + MaxRPM: 100, // 默认 100 RPM,避免 429 错误 + RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM + MaxRetries: 3, + RetryDelayMs: 1000, + }, }, } } diff --git a/internal/knowledge/indexer.go b/internal/knowledge/indexer.go index ca9fae3b..4a0da3eb 100644 --- a/internal/knowledge/indexer.go +++ b/internal/knowledge/indexer.go @@ -544,9 +544,21 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { idx.mu.Unlock() } - // 如果连续失败 2 个块,立即停止处理该知识项(降低阈值,更快停止) + // 如果连续失败 5 个块,立即停止处理该知识项 // 这样可以避免继续浪费 API 调用,同时也能更快地检测到配置问题 - if itemErrorCount >= 2 { + // 对于大文档(超过 10 个块),允许失败比例不超过 50% + maxConsecutiveFailures := 5 + if len(chunks) > 10 && itemErrorCount > len(chunks)/2 { + idx.logger.Error("知识项向量化失败比例过高,停止处理", + zap.String("itemId", itemID), + zap.Int("totalChunks", len(chunks)), + zap.Int("failedChunks", itemErrorCount), + zap.Int("firstErrorChunkIndex", firstErrorChunkIndex), + zap.Error(firstError), + ) + return fmt.Errorf("知识项向量化失败比例过高 (%d/%d个块失败): %v", itemErrorCount, len(chunks), firstError) + } + if itemErrorCount >= maxConsecutiveFailures { idx.logger.Error("知识项连续向量化失败,停止处理", zap.String("itemId", itemID), zap.Int("totalChunks", len(chunks)), @@ -649,7 +661,7 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error { failedCount := 0 consecutiveFailures := 0 - maxConsecutiveFailures := 2 // 连续失败 2 次后立即停止(降低阈值,更快停止) + maxConsecutiveFailures := 5 // 连续失败 5 次后立即停止(允许偶尔的临时错误) firstFailureItemID := "" var firstFailureError error diff --git a/internal/knowledge/manager.go b/internal/knowledge/manager.go index ec72abad..7309cc2a 100644 --- a/internal/knowledge/manager.go +++ b/internal/knowledge/manager.go @@ -657,7 +657,7 @@ func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeIte // 删除旧目录(如果为空) oldDir := filepath.Dir(item.FilePath) - if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 { + if isEmpty, _ := isEmptyDir(oldDir); isEmpty { // 只有当目录不是知识库根目录时才删除(避免删除根目录) if oldDir != m.basePath { if err := os.Remove(oldDir); err != nil { @@ -712,7 +712,7 @@ func (m *Manager) DeleteItem(id string) error { // 删除空目录(如果为空) dir := filepath.Dir(filePath) - if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 { + if isEmpty, _ := isEmptyDir(dir); isEmpty { // 只有当目录不是知识库根目录时才删除(避免删除根目录) if dir != m.basePath { if err := os.Remove(dir); err != nil { @@ -724,6 +724,21 @@ func (m *Manager) DeleteItem(id string) error { return nil } +// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件) +func isEmptyDir(dir string) (bool, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return false, err + } + for _, entry := range entries { + // 忽略隐藏文件(以 . 开头) + if !strings.HasPrefix(entry.Name(), ".") { + return false, nil + } + } + return true, nil +} + // LogRetrieval 记录检索日志 func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error { id := uuid.New().String() diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go index 15e46a1f..6a6551a1 100644 --- a/internal/knowledge/retriever.go +++ b/internal/knowledge/retriever.go @@ -69,8 +69,8 @@ func cosineSimilarity(a, b []float32) float64 { return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) } -// bm25Score 计算BM25分数(改进版,更接近标准BM25) -// 注意:这是单文档版本的BM25,缺少全局IDF,但比之前的简化版本更准确 +// bm25Score 计算 BM25 分数(带缓存的改进版本) +// 注意:由于缺少全局文档统计,使用简化 IDF 计算 func (r *Retriever) bm25Score(query, text string) float64 { queryTerms := strings.Fields(strings.ToLower(query)) if len(queryTerms) == 0 { @@ -83,44 +83,56 @@ func (r *Retriever) bm25Score(query, text string) float64 { return 0.0 } - // BM25参数 - k1 := 1.5 // 词频饱和度参数 - b := 0.75 // 长度归一化参数 - avgDocLength := 100.0 // 估算的平均文档长度(用于归一化) + // BM25 参数(标准值) + k1 := 1.2 // 词频饱和度参数(标准范围 1.2-2.0) + b := 0.75 // 长度归一化参数(标准值) + avgDocLength := 150.0 // 估算的平均文档长度(基于典型知识块大小) docLength := float64(len(textTerms)) - score := 0.0 - for _, term := range queryTerms { - // 计算词频(TF) - termFreq := 0 - for _, textTerm := range textTerms { - if textTerm == term { - termFreq++ - } - } - - if termFreq > 0 { - // BM25公式的核心部分 - // TF部分:termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength))) - tf := float64(termFreq) - lengthNorm := 1 - b + b*(docLength/avgDocLength) - tfScore := tf / (tf + k1*lengthNorm) - - // 简化IDF:使用词长度作为权重(短词通常更重要) - // 实际BM25需要全局文档统计,这里用简化版本 - idfWeight := 1.0 - if len(term) > 2 { - // 长词稍微降低权重(但实际BM25中,罕见词IDF更高) - idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0) - } - - score += tfScore * idfWeight - } + // 计算词频映射 + textTermFreq := make(map[string]int, len(textTerms)) + for _, term := range textTerms { + textTermFreq[term]++ } - // 归一化到0-1范围 + score := 0.0 + matchedQueryTerms := 0 + + for _, term := range queryTerms { + termFreq, exists := textTermFreq[term] + if !exists || termFreq == 0 { + continue + } + matchedQueryTerms++ + + // BM25 TF 计算公式 + tf := float64(termFreq) + lengthNorm := 1 - b + b*(docLength/avgDocLength) + tfScore := tf / (tf + k1*lengthNorm) + + // 改进的 IDF 计算:使用词长度和出现频率估算 + // 短词(2-3 字符)通常更重要,长词 IDF 略低 + idfWeight := 1.0 + termLen := len(term) + if termLen <= 2 { + // 极短词(如 go, js)给予更高权重 + idfWeight = 1.2 + math.Log(1.0+float64(termFreq)/20.0) + } else if termLen <= 4 { + // 短词(4 字符)标准权重 + idfWeight = 1.0 + math.Log(1.0+float64(termFreq)/15.0) + } else { + // 长词稍微降低权重 + idfWeight = 0.9 + math.Log(1.0+float64(termFreq)/10.0) + } + + score += tfScore * idfWeight + } + + // 归一化:考虑匹配的查询词比例 if len(queryTerms) > 0 { - score = score / float64(len(queryTerms)) + // 使用匹配比例作为额外因子 + matchRatio := float64(matchedQueryTerms) / float64(len(queryTerms)) + score = (score / float64(len(queryTerms))) * (1 + matchRatio) / 2 } return math.Min(score, 1.0) @@ -173,7 +185,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title FROM knowledge_embeddings e JOIN knowledge_base_items i ON e.item_id = i.id - WHERE i.category = ? COLLATE NOCASE + WHERE TRIM(i.category) = TRIM(?) COLLATE NOCASE `, req.RiskType) } else { rows, err = r.db.Query(` @@ -357,7 +369,10 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva zap.Float64("threshold", threshold), zap.Float64("maxSimilarity", maxSimilarity), ) - } else if len(filteredCandidates) > topK { + } + + // 统一在最终返回前严格限制 Top-K 数量 + if len(filteredCandidates) > topK { // 如果过滤后结果太多,只取Top-K filteredCandidates = filteredCandidates[:topK] } diff --git a/internal/knowledge/types.go b/internal/knowledge/types.go index 656b03a7..bccd3a93 100644 --- a/internal/knowledge/types.go +++ b/internal/knowledge/types.go @@ -5,6 +5,14 @@ import ( "time" ) +// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串 +func formatTime(t time.Time) string { + if t.IsZero() { + return "" + } + return t.Format(time.RFC3339) +} + // KnowledgeItem 知识库项 type KnowledgeItem struct { ID string `json:"id"` @@ -22,12 +30,12 @@ type KnowledgeItemSummary struct { Category string `json:"category"` Title string `json:"title"` FilePath string `json:"filePath"` - Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前150字符) + Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符) CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` } -// MarshalJSON 自定义JSON序列化,确保时间格式正确 +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) { type Alias KnowledgeItemSummary aux := &struct { @@ -37,25 +45,12 @@ func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) { }{ Alias: (*Alias)(k), } - - // 格式化创建时间 - if k.CreatedAt.IsZero() { - aux.CreatedAt = "" - } else { - aux.CreatedAt = k.CreatedAt.Format(time.RFC3339) - } - - // 格式化更新时间 - if k.UpdatedAt.IsZero() { - aux.UpdatedAt = "" - } else { - aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339) - } - + aux.CreatedAt = formatTime(k.CreatedAt) + aux.UpdatedAt = formatTime(k.UpdatedAt) return json.Marshal(aux) } -// MarshalJSON 自定义JSON序列化,确保时间格式正确 +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 func (k *KnowledgeItem) MarshalJSON() ([]byte, error) { type Alias KnowledgeItem aux := &struct { @@ -65,21 +60,8 @@ func (k *KnowledgeItem) MarshalJSON() ([]byte, error) { }{ Alias: (*Alias)(k), } - - // 格式化创建时间 - if k.CreatedAt.IsZero() { - aux.CreatedAt = "" - } else { - aux.CreatedAt = k.CreatedAt.Format(time.RFC3339) - } - - // 格式化更新时间 - if k.UpdatedAt.IsZero() { - aux.UpdatedAt = "" - } else { - aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339) - } - + aux.CreatedAt = formatTime(k.CreatedAt) + aux.UpdatedAt = formatTime(k.UpdatedAt) return json.Marshal(aux) } @@ -89,7 +71,7 @@ type KnowledgeChunk struct { ItemID string `json:"itemId"` ChunkIndex int `json:"chunkIndex"` ChunkText string `json:"chunkText"` - Embedding []float32 `json:"-"` // 向量嵌入,不序列化到JSON + Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON CreatedAt time.Time `json:"createdAt"` } @@ -108,11 +90,11 @@ type RetrievalLog struct { MessageID string `json:"messageId,omitempty"` Query string `json:"query"` RiskType string `json:"riskType,omitempty"` - RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项ID列表 + RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表 CreatedAt time.Time `json:"createdAt"` } -// MarshalJSON 自定义JSON序列化,确保时间格式正确 +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 func (r *RetrievalLog) MarshalJSON() ([]byte, error) { type Alias RetrievalLog return json.Marshal(&struct { @@ -120,21 +102,21 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) { CreatedAt string `json:"createdAt"` }{ Alias: (*Alias)(r), - CreatedAt: r.CreatedAt.Format(time.RFC3339), + CreatedAt: formatTime(r.CreatedAt), }) } // CategoryWithItems 分类及其下的知识项(用于按分类分页) type CategoryWithItems struct { - Category string `json:"category"` // 分类名称 - ItemCount int `json:"itemCount"` // 该分类下的知识项总数 - Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表 + Category string `json:"category"` // 分类名称 + ItemCount int `json:"itemCount"` // 该分类下的知识项总数 + Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表 } // SearchRequest 搜索请求 type SearchRequest struct { Query string `json:"query"` RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型 - TopK int `json:"topK,omitempty"` // 返回Top-K结果,默认5 - Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认0.7 + TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5 + Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7 }