From ea3dc216c1ee6949314c874f8a343b720abca120 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sun, 21 Dec 2025 15:21:44 +0800 Subject: [PATCH] Add files via upload --- internal/knowledge/retriever.go | 238 ++++++++++++++++++++++++++++++++ internal/knowledge/tool.go | 143 +++++++------------ 2 files changed, 290 insertions(+), 91 deletions(-) diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go index 090313cf..ad7245e9 100644 --- a/internal/knowledge/retriever.go +++ b/internal/knowledge/retriever.go @@ -308,5 +308,243 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva } } + // 上下文扩展:为每个匹配的chunk添加同一文档中的相关chunk + // 这可以防止文本描述和payload被分开切分时,只返回描述而丢失payload的问题 + results = r.expandContext(ctx, results) + return results, nil } + +// expandContext 扩展检索结果的上下文 +// 对于每个匹配的chunk,自动包含同一文档中的相关chunk(特别是包含代码块、payload的chunk) +func (r *Retriever) expandContext(ctx context.Context, results []*RetrievalResult) []*RetrievalResult { + if len(results) == 0 { + return results + } + + // 收集所有匹配到的文档ID + itemIDs := make(map[string]bool) + for _, result := range results { + itemIDs[result.Item.ID] = true + } + + // 为每个文档加载所有chunk + itemChunksMap := make(map[string][]*KnowledgeChunk) + for itemID := range itemIDs { + chunks, err := r.loadAllChunksForItem(itemID) + if err != nil { + r.logger.Warn("加载文档chunk失败", zap.String("itemId", itemID), zap.Error(err)) + continue + } + itemChunksMap[itemID] = chunks + } + + // 按文档分组结果,每个文档只扩展一次 + resultsByItem := make(map[string][]*RetrievalResult) + for _, result := range results { + itemID := result.Item.ID + resultsByItem[itemID] = append(resultsByItem[itemID], result) + } + + // 扩展每个文档的结果 + expandedResults := make([]*RetrievalResult, 0, len(results)) + processedChunkIDs := make(map[string]bool) // 避免重复添加 + + for itemID, itemResults := range resultsByItem { + // 获取该文档的所有chunk + allChunks, exists := itemChunksMap[itemID] + if !exists { + // 如果无法加载chunk,直接添加原始结果 + for _, result := range itemResults { + if !processedChunkIDs[result.Chunk.ID] { + expandedResults = append(expandedResults, result) + processedChunkIDs[result.Chunk.ID] = true + } + } + continue + } + + // 添加原始结果 + for _, result := range itemResults { + if !processedChunkIDs[result.Chunk.ID] { + expandedResults = append(expandedResults, result) + processedChunkIDs[result.Chunk.ID] = true + } + } + + // 为该文档的匹配chunk收集需要扩展的相邻chunk + // 策略:只对相似度最高的前3个匹配chunk进行扩展,避免扩展过多 + // 先按相似度排序,只扩展前3个 + sortedItemResults := make([]*RetrievalResult, len(itemResults)) + copy(sortedItemResults, itemResults) + sort.Slice(sortedItemResults, func(i, j int) bool { + return sortedItemResults[i].Similarity > sortedItemResults[j].Similarity + }) + + // 只扩展前3个(或所有,如果少于3个) + maxExpandFrom := 3 + if len(sortedItemResults) < maxExpandFrom { + maxExpandFrom = len(sortedItemResults) + } + + // 使用map去重,避免同一个chunk被多次添加 + relatedChunksMap := make(map[string]*KnowledgeChunk) + + for i := 0; i < maxExpandFrom; i++ { + result := sortedItemResults[i] + // 查找相关chunk(上下各2个,排除已处理的chunk) + relatedChunks := r.findRelatedChunks(result.Chunk, allChunks, processedChunkIDs) + for _, relatedChunk := range relatedChunks { + // 使用chunk ID作为key去重 + if !processedChunkIDs[relatedChunk.ID] { + relatedChunksMap[relatedChunk.ID] = relatedChunk + } + } + } + + // 限制每个文档最多扩展的chunk数量(避免扩展过多) + // 策略:最多扩展8个chunk,无论匹配了多少个chunk + // 这样可以避免当多个匹配chunk分散在文档不同位置时,扩展出过多chunk + maxExpandPerItem := 8 + + // 将相关chunk转换为切片并按索引排序,优先选择距离匹配chunk最近的 + relatedChunksList := make([]*KnowledgeChunk, 0, len(relatedChunksMap)) + for _, chunk := range relatedChunksMap { + relatedChunksList = append(relatedChunksList, chunk) + } + + // 计算每个相关chunk到最近匹配chunk的距离,按距离排序 + sort.Slice(relatedChunksList, func(i, j int) bool { + // 计算到最近匹配chunk的距离 + minDistI := len(allChunks) + minDistJ := len(allChunks) + for _, result := range itemResults { + distI := abs(relatedChunksList[i].ChunkIndex - result.Chunk.ChunkIndex) + distJ := abs(relatedChunksList[j].ChunkIndex - result.Chunk.ChunkIndex) + if distI < minDistI { + minDistI = distI + } + if distJ < minDistJ { + minDistJ = distJ + } + } + return minDistI < minDistJ + }) + + // 限制数量 + if len(relatedChunksList) > maxExpandPerItem { + relatedChunksList = relatedChunksList[:maxExpandPerItem] + } + + // 添加去重后的相关chunk + // 使用该文档中相似度最高的结果作为参考 + maxSimilarity := 0.0 + for _, result := range itemResults { + if result.Similarity > maxSimilarity { + maxSimilarity = result.Similarity + } + } + + for _, relatedChunk := range relatedChunksList { + expandedResult := &RetrievalResult{ + Chunk: relatedChunk, + Item: itemResults[0].Item, // 使用第一个结果的Item信息 + Similarity: maxSimilarity * 0.8, // 相关chunk的相似度略低 + Score: maxSimilarity * 0.8, + } + expandedResults = append(expandedResults, expandedResult) + processedChunkIDs[relatedChunk.ID] = true + } + } + + return expandedResults +} + +// loadAllChunksForItem 加载文档的所有chunk +func (r *Retriever) loadAllChunksForItem(itemID string) ([]*KnowledgeChunk, error) { + rows, err := r.db.Query(` + SELECT id, item_id, chunk_index, chunk_text, embedding + FROM knowledge_embeddings + WHERE item_id = ? + ORDER BY chunk_index + `, itemID) + if err != nil { + return nil, fmt.Errorf("查询chunk失败: %w", err) + } + defer rows.Close() + + var chunks []*KnowledgeChunk + for rows.Next() { + var chunkID, itemID, chunkText, embeddingJSON string + var chunkIndex int + + if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON); err != nil { + r.logger.Warn("扫描chunk失败", zap.Error(err)) + continue + } + + // 解析向量(可选,这里不需要) + var embedding []float32 + if embeddingJSON != "" { + json.Unmarshal([]byte(embeddingJSON), &embedding) + } + + chunk := &KnowledgeChunk{ + ID: chunkID, + ItemID: itemID, + ChunkIndex: chunkIndex, + ChunkText: chunkText, + Embedding: embedding, + } + chunks = append(chunks, chunk) + } + + return chunks, nil +} + +// findRelatedChunks 查找与给定chunk相关的其他chunk +// 策略:只返回上下各2个相邻的chunk(共最多4个) +// 排除已处理的chunk,避免重复添加 +func (r *Retriever) findRelatedChunks(targetChunk *KnowledgeChunk, allChunks []*KnowledgeChunk, processedChunkIDs map[string]bool) []*KnowledgeChunk { + related := make([]*KnowledgeChunk, 0) + + // 查找上下各2个相邻chunk + for _, chunk := range allChunks { + if chunk.ID == targetChunk.ID { + continue + } + + // 检查是否已经被处理过(可能已经在检索结果中) + if processedChunkIDs[chunk.ID] { + continue + } + + // 检查是否是相邻chunk(索引相差不超过2,且不为0) + indexDiff := chunk.ChunkIndex - targetChunk.ChunkIndex + if indexDiff >= -2 && indexDiff <= 2 && indexDiff != 0 { + related = append(related, chunk) + } + } + + // 按索引距离排序,优先选择最近的 + sort.Slice(related, func(i, j int) bool { + diffI := abs(related[i].ChunkIndex - targetChunk.ChunkIndex) + diffJ := abs(related[j].ChunkIndex - targetChunk.ChunkIndex) + return diffI < diffJ + }) + + // 限制最多返回4个(上下各2个) + if len(related) > 4 { + related = related[:4] + } + + return related +} + +// abs 返回整数的绝对值 +func abs(x int) int { + if x < 0 { + return -x + } + return x +} diff --git a/internal/knowledge/tool.go b/internal/knowledge/tool.go index 8c5774ef..075c35cb 100644 --- a/internal/knowledge/tool.go +++ b/internal/knowledge/tool.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "sort" "strings" "cyberstrike-ai/internal/mcp" @@ -98,19 +99,62 @@ func RegisterKnowledgeTool( // 格式化结果 var resultText strings.Builder - resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识:\n\n", len(results))) + + // 按文档分组结果,以便更好地展示上下文 + resultsByItem := make(map[string][]*RetrievalResult) + for _, result := range results { + itemID := result.Item.ID + resultsByItem[itemID] = append(resultsByItem[itemID], result) + } // 收集检索到的知识项ID(用于日志) - retrievedItemIDs := make([]string, 0, len(results)) + retrievedItemIDs := make([]string, 0, len(resultsByItem)) - for i, result := range results { - resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", i+1, result.Similarity*100)) - resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", result.Item.Category, result.Item.Title, result.Item.ID)) - resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n\n", result.Chunk.ChunkText)) + resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results))) - if !contains(retrievedItemIDs, result.Item.ID) { - retrievedItemIDs = append(retrievedItemIDs, result.Item.ID) + resultIndex := 1 + for itemID, itemResults := range resultsByItem { + // 找到相似度最高的作为主结果 + mainResult := itemResults[0] + maxSimilarity := mainResult.Similarity + for _, result := range itemResults { + if result.Similarity > maxSimilarity { + maxSimilarity = result.Similarity + mainResult = result + } } + + // 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序) + sort.Slice(itemResults, func(i, j int) bool { + return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex + }) + + // 显示主结果(相似度最高的) + resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", resultIndex, mainResult.Similarity*100)) + resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID)) + + // 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk) + if len(itemResults) == 1 { + // 只有一个chunk,直接显示 + resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText)) + } else { + // 多个chunk,按逻辑顺序显示 + resultText.WriteString("内容片段(按文档顺序):\n") + for i, result := range itemResults { + // 标记主结果 + marker := "" + if result.Chunk.ID == mainResult.Chunk.ID { + marker = " [主匹配]" + } + resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText)) + } + } + resultText.WriteString("\n") + + if !contains(retrievedItemIDs, itemID) { + retrievedItemIDs = append(retrievedItemIDs, itemID) + } + resultIndex++ } // 在结果末尾添加元数据(JSON格式,用于提取知识项ID) @@ -140,89 +184,6 @@ func RegisterKnowledgeTool( mcpServer.RegisterTool(tool, handler) logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name)) - - // 注册读取完整知识项的工具 - RegisterReadKnowledgeItemTool(mcpServer, manager, logger) -} - -// RegisterReadKnowledgeItemTool 注册读取完整知识项工具到MCP服务器 -func RegisterReadKnowledgeItemTool( - mcpServer *mcp.Server, - manager *Manager, - logger *zap.Logger, -) { - tool := mcp.Tool{ - Name: "read_knowledge_item", - Description: "根据知识项ID读取完整的知识文档内容。**重要:此工具应谨慎使用,只在检索到的片段信息明显不足时才调用。** 使用场景:1) 检索片段缺少关键上下文导致无法理解;2) 需要查看文档的完整结构或流程;3) 片段信息不完整,必须查看完整文档才能回答用户问题。**不要**仅为了获取更多信息而盲目读取完整文档,因为检索工具已经返回了最相关的片段。传入知识项ID(从search_knowledge_base的检索结果中获取)即可获取该知识项的完整内容(包括标题、分类、完整文档内容等)。", - ShortDescription: "读取完整知识项文档(仅在片段信息不足时使用)", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "item_id": map[string]interface{}{ - "type": "string", - "description": "知识项ID(可以从 search_knowledge_base 的检索结果中获取)", - }, - }, - "required": []string{"item_id"}, - }, - } - - handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - itemID, ok := args["item_id"].(string) - if !ok || itemID == "" { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: "错误: item_id 参数不能为空", - }, - }, - IsError: true, - }, nil - } - - logger.Info("读取知识项", zap.String("itemId", itemID)) - - // 获取完整知识项 - item, err := manager.GetItem(itemID) - if err != nil { - logger.Error("读取知识项失败", zap.String("itemId", itemID), zap.Error(err)) - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: fmt.Sprintf("读取知识项失败: %v", err), - }, - }, - IsError: true, - }, nil - } - - // 格式化结果 - var resultText strings.Builder - resultText.WriteString("=== 完整知识项内容 ===\n\n") - resultText.WriteString(fmt.Sprintf("ID: %s\n", item.ID)) - resultText.WriteString(fmt.Sprintf("分类: %s\n", item.Category)) - resultText.WriteString(fmt.Sprintf("标题: %s\n", item.Title)) - if item.FilePath != "" { - resultText.WriteString(fmt.Sprintf("文件路径: %s\n", item.FilePath)) - } - resultText.WriteString("\n--- 完整内容 ---\n\n") - resultText.WriteString(item.Content) - resultText.WriteString("\n\n") - - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: resultText.String(), - }, - }, - }, nil - } - - mcpServer.RegisterTool(tool, handler) - logger.Info("读取知识项工具已注册", zap.String("toolName", tool.Name)) } // contains 检查切片是否包含元素