mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
Add files via upload
This commit is contained in:
@@ -308,5 +308,243 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
}
|
||||
}
|
||||
|
||||
// 上下文扩展:为每个匹配的chunk添加同一文档中的相关chunk
|
||||
// 这可以防止文本描述和payload被分开切分时,只返回描述而丢失payload的问题
|
||||
results = r.expandContext(ctx, results)
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// expandContext 扩展检索结果的上下文
|
||||
// 对于每个匹配的chunk,自动包含同一文档中的相关chunk(特别是包含代码块、payload的chunk)
|
||||
func (r *Retriever) expandContext(ctx context.Context, results []*RetrievalResult) []*RetrievalResult {
|
||||
if len(results) == 0 {
|
||||
return results
|
||||
}
|
||||
|
||||
// 收集所有匹配到的文档ID
|
||||
itemIDs := make(map[string]bool)
|
||||
for _, result := range results {
|
||||
itemIDs[result.Item.ID] = true
|
||||
}
|
||||
|
||||
// 为每个文档加载所有chunk
|
||||
itemChunksMap := make(map[string][]*KnowledgeChunk)
|
||||
for itemID := range itemIDs {
|
||||
chunks, err := r.loadAllChunksForItem(itemID)
|
||||
if err != nil {
|
||||
r.logger.Warn("加载文档chunk失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
itemChunksMap[itemID] = chunks
|
||||
}
|
||||
|
||||
// 按文档分组结果,每个文档只扩展一次
|
||||
resultsByItem := make(map[string][]*RetrievalResult)
|
||||
for _, result := range results {
|
||||
itemID := result.Item.ID
|
||||
resultsByItem[itemID] = append(resultsByItem[itemID], result)
|
||||
}
|
||||
|
||||
// 扩展每个文档的结果
|
||||
expandedResults := make([]*RetrievalResult, 0, len(results))
|
||||
processedChunkIDs := make(map[string]bool) // 避免重复添加
|
||||
|
||||
for itemID, itemResults := range resultsByItem {
|
||||
// 获取该文档的所有chunk
|
||||
allChunks, exists := itemChunksMap[itemID]
|
||||
if !exists {
|
||||
// 如果无法加载chunk,直接添加原始结果
|
||||
for _, result := range itemResults {
|
||||
if !processedChunkIDs[result.Chunk.ID] {
|
||||
expandedResults = append(expandedResults, result)
|
||||
processedChunkIDs[result.Chunk.ID] = true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 添加原始结果
|
||||
for _, result := range itemResults {
|
||||
if !processedChunkIDs[result.Chunk.ID] {
|
||||
expandedResults = append(expandedResults, result)
|
||||
processedChunkIDs[result.Chunk.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 为该文档的匹配chunk收集需要扩展的相邻chunk
|
||||
// 策略:只对相似度最高的前3个匹配chunk进行扩展,避免扩展过多
|
||||
// 先按相似度排序,只扩展前3个
|
||||
sortedItemResults := make([]*RetrievalResult, len(itemResults))
|
||||
copy(sortedItemResults, itemResults)
|
||||
sort.Slice(sortedItemResults, func(i, j int) bool {
|
||||
return sortedItemResults[i].Similarity > sortedItemResults[j].Similarity
|
||||
})
|
||||
|
||||
// 只扩展前3个(或所有,如果少于3个)
|
||||
maxExpandFrom := 3
|
||||
if len(sortedItemResults) < maxExpandFrom {
|
||||
maxExpandFrom = len(sortedItemResults)
|
||||
}
|
||||
|
||||
// 使用map去重,避免同一个chunk被多次添加
|
||||
relatedChunksMap := make(map[string]*KnowledgeChunk)
|
||||
|
||||
for i := 0; i < maxExpandFrom; i++ {
|
||||
result := sortedItemResults[i]
|
||||
// 查找相关chunk(上下各2个,排除已处理的chunk)
|
||||
relatedChunks := r.findRelatedChunks(result.Chunk, allChunks, processedChunkIDs)
|
||||
for _, relatedChunk := range relatedChunks {
|
||||
// 使用chunk ID作为key去重
|
||||
if !processedChunkIDs[relatedChunk.ID] {
|
||||
relatedChunksMap[relatedChunk.ID] = relatedChunk
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 限制每个文档最多扩展的chunk数量(避免扩展过多)
|
||||
// 策略:最多扩展8个chunk,无论匹配了多少个chunk
|
||||
// 这样可以避免当多个匹配chunk分散在文档不同位置时,扩展出过多chunk
|
||||
maxExpandPerItem := 8
|
||||
|
||||
// 将相关chunk转换为切片并按索引排序,优先选择距离匹配chunk最近的
|
||||
relatedChunksList := make([]*KnowledgeChunk, 0, len(relatedChunksMap))
|
||||
for _, chunk := range relatedChunksMap {
|
||||
relatedChunksList = append(relatedChunksList, chunk)
|
||||
}
|
||||
|
||||
// 计算每个相关chunk到最近匹配chunk的距离,按距离排序
|
||||
sort.Slice(relatedChunksList, func(i, j int) bool {
|
||||
// 计算到最近匹配chunk的距离
|
||||
minDistI := len(allChunks)
|
||||
minDistJ := len(allChunks)
|
||||
for _, result := range itemResults {
|
||||
distI := abs(relatedChunksList[i].ChunkIndex - result.Chunk.ChunkIndex)
|
||||
distJ := abs(relatedChunksList[j].ChunkIndex - result.Chunk.ChunkIndex)
|
||||
if distI < minDistI {
|
||||
minDistI = distI
|
||||
}
|
||||
if distJ < minDistJ {
|
||||
minDistJ = distJ
|
||||
}
|
||||
}
|
||||
return minDistI < minDistJ
|
||||
})
|
||||
|
||||
// 限制数量
|
||||
if len(relatedChunksList) > maxExpandPerItem {
|
||||
relatedChunksList = relatedChunksList[:maxExpandPerItem]
|
||||
}
|
||||
|
||||
// 添加去重后的相关chunk
|
||||
// 使用该文档中相似度最高的结果作为参考
|
||||
maxSimilarity := 0.0
|
||||
for _, result := range itemResults {
|
||||
if result.Similarity > maxSimilarity {
|
||||
maxSimilarity = result.Similarity
|
||||
}
|
||||
}
|
||||
|
||||
for _, relatedChunk := range relatedChunksList {
|
||||
expandedResult := &RetrievalResult{
|
||||
Chunk: relatedChunk,
|
||||
Item: itemResults[0].Item, // 使用第一个结果的Item信息
|
||||
Similarity: maxSimilarity * 0.8, // 相关chunk的相似度略低
|
||||
Score: maxSimilarity * 0.8,
|
||||
}
|
||||
expandedResults = append(expandedResults, expandedResult)
|
||||
processedChunkIDs[relatedChunk.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
return expandedResults
|
||||
}
|
||||
|
||||
// loadAllChunksForItem 加载文档的所有chunk
|
||||
func (r *Retriever) loadAllChunksForItem(itemID string) ([]*KnowledgeChunk, error) {
|
||||
rows, err := r.db.Query(`
|
||||
SELECT id, item_id, chunk_index, chunk_text, embedding
|
||||
FROM knowledge_embeddings
|
||||
WHERE item_id = ?
|
||||
ORDER BY chunk_index
|
||||
`, itemID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询chunk失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var chunks []*KnowledgeChunk
|
||||
for rows.Next() {
|
||||
var chunkID, itemID, chunkText, embeddingJSON string
|
||||
var chunkIndex int
|
||||
|
||||
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON); err != nil {
|
||||
r.logger.Warn("扫描chunk失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析向量(可选,这里不需要)
|
||||
var embedding []float32
|
||||
if embeddingJSON != "" {
|
||||
json.Unmarshal([]byte(embeddingJSON), &embedding)
|
||||
}
|
||||
|
||||
chunk := &KnowledgeChunk{
|
||||
ID: chunkID,
|
||||
ItemID: itemID,
|
||||
ChunkIndex: chunkIndex,
|
||||
ChunkText: chunkText,
|
||||
Embedding: embedding,
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// findRelatedChunks 查找与给定chunk相关的其他chunk
|
||||
// 策略:只返回上下各2个相邻的chunk(共最多4个)
|
||||
// 排除已处理的chunk,避免重复添加
|
||||
func (r *Retriever) findRelatedChunks(targetChunk *KnowledgeChunk, allChunks []*KnowledgeChunk, processedChunkIDs map[string]bool) []*KnowledgeChunk {
|
||||
related := make([]*KnowledgeChunk, 0)
|
||||
|
||||
// 查找上下各2个相邻chunk
|
||||
for _, chunk := range allChunks {
|
||||
if chunk.ID == targetChunk.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否已经被处理过(可能已经在检索结果中)
|
||||
if processedChunkIDs[chunk.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否是相邻chunk(索引相差不超过2,且不为0)
|
||||
indexDiff := chunk.ChunkIndex - targetChunk.ChunkIndex
|
||||
if indexDiff >= -2 && indexDiff <= 2 && indexDiff != 0 {
|
||||
related = append(related, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// 按索引距离排序,优先选择最近的
|
||||
sort.Slice(related, func(i, j int) bool {
|
||||
diffI := abs(related[i].ChunkIndex - targetChunk.ChunkIndex)
|
||||
diffJ := abs(related[j].ChunkIndex - targetChunk.ChunkIndex)
|
||||
return diffI < diffJ
|
||||
})
|
||||
|
||||
// 限制最多返回4个(上下各2个)
|
||||
if len(related) > 4 {
|
||||
related = related[:4]
|
||||
}
|
||||
|
||||
return related
|
||||
}
|
||||
|
||||
// abs returns the absolute value of x.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -98,19 +99,62 @@ func RegisterKnowledgeTool(
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识:\n\n", len(results)))
|
||||
|
||||
// 按文档分组结果,以便更好地展示上下文
|
||||
resultsByItem := make(map[string][]*RetrievalResult)
|
||||
for _, result := range results {
|
||||
itemID := result.Item.ID
|
||||
resultsByItem[itemID] = append(resultsByItem[itemID], result)
|
||||
}
|
||||
|
||||
// 收集检索到的知识项ID(用于日志)
|
||||
retrievedItemIDs := make([]string, 0, len(results))
|
||||
retrievedItemIDs := make([]string, 0, len(resultsByItem))
|
||||
|
||||
for i, result := range results {
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", i+1, result.Similarity*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", result.Item.Category, result.Item.Title, result.Item.ID))
|
||||
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n\n", result.Chunk.ChunkText))
|
||||
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results)))
|
||||
|
||||
if !contains(retrievedItemIDs, result.Item.ID) {
|
||||
retrievedItemIDs = append(retrievedItemIDs, result.Item.ID)
|
||||
resultIndex := 1
|
||||
for itemID, itemResults := range resultsByItem {
|
||||
// 找到相似度最高的作为主结果
|
||||
mainResult := itemResults[0]
|
||||
maxSimilarity := mainResult.Similarity
|
||||
for _, result := range itemResults {
|
||||
if result.Similarity > maxSimilarity {
|
||||
maxSimilarity = result.Similarity
|
||||
mainResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序)
|
||||
sort.Slice(itemResults, func(i, j int) bool {
|
||||
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
|
||||
})
|
||||
|
||||
// 显示主结果(相似度最高的)
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", resultIndex, mainResult.Similarity*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
|
||||
|
||||
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
|
||||
if len(itemResults) == 1 {
|
||||
// 只有一个chunk,直接显示
|
||||
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText))
|
||||
} else {
|
||||
// 多个chunk,按逻辑顺序显示
|
||||
resultText.WriteString("内容片段(按文档顺序):\n")
|
||||
for i, result := range itemResults {
|
||||
// 标记主结果
|
||||
marker := ""
|
||||
if result.Chunk.ID == mainResult.Chunk.ID {
|
||||
marker = " [主匹配]"
|
||||
}
|
||||
resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText))
|
||||
}
|
||||
}
|
||||
resultText.WriteString("\n")
|
||||
|
||||
if !contains(retrievedItemIDs, itemID) {
|
||||
retrievedItemIDs = append(retrievedItemIDs, itemID)
|
||||
}
|
||||
resultIndex++
|
||||
}
|
||||
|
||||
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
|
||||
@@ -140,89 +184,6 @@ func RegisterKnowledgeTool(
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name))
|
||||
|
||||
// 注册读取完整知识项的工具
|
||||
RegisterReadKnowledgeItemTool(mcpServer, manager, logger)
|
||||
}
|
||||
|
||||
// RegisterReadKnowledgeItemTool 注册读取完整知识项工具到MCP服务器
|
||||
func RegisterReadKnowledgeItemTool(
|
||||
mcpServer *mcp.Server,
|
||||
manager *Manager,
|
||||
logger *zap.Logger,
|
||||
) {
|
||||
tool := mcp.Tool{
|
||||
Name: "read_knowledge_item",
|
||||
Description: "根据知识项ID读取完整的知识文档内容。**重要:此工具应谨慎使用,只在检索到的片段信息明显不足时才调用。** 使用场景:1) 检索片段缺少关键上下文导致无法理解;2) 需要查看文档的完整结构或流程;3) 片段信息不完整,必须查看完整文档才能回答用户问题。**不要**仅为了获取更多信息而盲目读取完整文档,因为检索工具已经返回了最相关的片段。传入知识项ID(从search_knowledge_base的检索结果中获取)即可获取该知识项的完整内容(包括标题、分类、完整文档内容等)。",
|
||||
ShortDescription: "读取完整知识项文档(仅在片段信息不足时使用)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"item_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "知识项ID(可以从 search_knowledge_base 的检索结果中获取)",
|
||||
},
|
||||
},
|
||||
"required": []string{"item_id"},
|
||||
},
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
itemID, ok := args["item_id"].(string)
|
||||
if !ok || itemID == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: item_id 参数不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
logger.Info("读取知识项", zap.String("itemId", itemID))
|
||||
|
||||
// 获取完整知识项
|
||||
item, err := manager.GetItem(itemID)
|
||||
if err != nil {
|
||||
logger.Error("读取知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("读取知识项失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
resultText.WriteString("=== 完整知识项内容 ===\n\n")
|
||||
resultText.WriteString(fmt.Sprintf("ID: %s\n", item.ID))
|
||||
resultText.WriteString(fmt.Sprintf("分类: %s\n", item.Category))
|
||||
resultText.WriteString(fmt.Sprintf("标题: %s\n", item.Title))
|
||||
if item.FilePath != "" {
|
||||
resultText.WriteString(fmt.Sprintf("文件路径: %s\n", item.FilePath))
|
||||
}
|
||||
resultText.WriteString("\n--- 完整内容 ---\n\n")
|
||||
resultText.WriteString(item.Content)
|
||||
resultText.WriteString("\n\n")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: resultText.String(),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
logger.Info("读取知识项工具已注册", zap.String("toolName", tool.Name))
|
||||
}
|
||||
|
||||
// contains 检查切片是否包含元素
|
||||
|
||||
Reference in New Issue
Block a user