mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
Add files via upload
This commit is contained in:
@@ -32,7 +32,7 @@ func NewIndexer(db *sql.DB, embedder *Embedder, logger *zap.Logger) *Indexer {
|
||||
}
|
||||
}
|
||||
|
||||
// ChunkText 将文本分块
|
||||
// ChunkText 将文本分块(支持重叠)
|
||||
func (idx *Indexer) ChunkText(text string) []string {
|
||||
// 按Markdown标题分割
|
||||
chunks := idx.splitByMarkdownHeaders(text)
|
||||
@@ -49,26 +49,9 @@ func (idx *Indexer) ChunkText(text string) []string {
|
||||
if idx.estimateTokens(subChunk) <= idx.chunkSize {
|
||||
result = append(result, subChunk)
|
||||
} else {
|
||||
// 按句子分割
|
||||
sentences := idx.splitBySentences(subChunk)
|
||||
currentChunk := ""
|
||||
for _, sentence := range sentences {
|
||||
testChunk := currentChunk
|
||||
if testChunk != "" {
|
||||
testChunk += "\n"
|
||||
}
|
||||
testChunk += sentence
|
||||
|
||||
if idx.estimateTokens(testChunk) > idx.chunkSize && currentChunk != "" {
|
||||
result = append(result, currentChunk)
|
||||
currentChunk = sentence
|
||||
} else {
|
||||
currentChunk = testChunk
|
||||
}
|
||||
}
|
||||
if currentChunk != "" {
|
||||
result = append(result, currentChunk)
|
||||
}
|
||||
// 按句子分割(支持重叠)
|
||||
chunksWithOverlap := idx.splitBySentencesWithOverlap(subChunk)
|
||||
result = append(result, chunksWithOverlap...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -131,7 +114,7 @@ func (idx *Indexer) splitByParagraphs(text string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// splitBySentences 按句子分割
|
||||
// splitBySentences 按句子分割(用于内部,不包含重叠逻辑)
|
||||
func (idx *Indexer) splitBySentences(text string) []string {
|
||||
// 简单的句子分割(按句号、问号、感叹号)
|
||||
sentenceRegex := regexp.MustCompile(`[.!?]+\s+`)
|
||||
@@ -145,6 +128,121 @@ func (idx *Indexer) splitBySentences(text string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// splitBySentencesWithOverlap 按句子分割并应用重叠策略
|
||||
func (idx *Indexer) splitBySentencesWithOverlap(text string) []string {
|
||||
if idx.overlap <= 0 {
|
||||
// 如果没有重叠,使用简单分割
|
||||
return idx.splitBySentencesSimple(text)
|
||||
}
|
||||
|
||||
sentences := idx.splitBySentences(text)
|
||||
if len(sentences) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
result := make([]string, 0)
|
||||
currentChunk := ""
|
||||
|
||||
for _, sentence := range sentences {
|
||||
testChunk := currentChunk
|
||||
if testChunk != "" {
|
||||
testChunk += "\n"
|
||||
}
|
||||
testChunk += sentence
|
||||
|
||||
testTokens := idx.estimateTokens(testChunk)
|
||||
|
||||
if testTokens > idx.chunkSize && currentChunk != "" {
|
||||
// 当前块已达到大小限制,保存它
|
||||
result = append(result, currentChunk)
|
||||
|
||||
// 从当前块的末尾提取重叠部分
|
||||
overlapText := idx.extractLastTokens(currentChunk, idx.overlap)
|
||||
if overlapText != "" {
|
||||
// 如果有重叠内容,作为下一个块的起始
|
||||
currentChunk = overlapText + "\n" + sentence
|
||||
} else {
|
||||
// 如果无法提取足够的重叠内容,直接使用当前句子
|
||||
currentChunk = sentence
|
||||
}
|
||||
} else {
|
||||
currentChunk = testChunk
|
||||
}
|
||||
}
|
||||
|
||||
// 添加最后一个块
|
||||
if strings.TrimSpace(currentChunk) != "" {
|
||||
result = append(result, currentChunk)
|
||||
}
|
||||
|
||||
// 过滤空块
|
||||
filtered := make([]string, 0)
|
||||
for _, chunk := range result {
|
||||
if strings.TrimSpace(chunk) != "" {
|
||||
filtered = append(filtered, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// splitBySentencesSimple 按句子分割(简单版本,无重叠)
|
||||
func (idx *Indexer) splitBySentencesSimple(text string) []string {
|
||||
sentences := idx.splitBySentences(text)
|
||||
result := make([]string, 0)
|
||||
currentChunk := ""
|
||||
|
||||
for _, sentence := range sentences {
|
||||
testChunk := currentChunk
|
||||
if testChunk != "" {
|
||||
testChunk += "\n"
|
||||
}
|
||||
testChunk += sentence
|
||||
|
||||
if idx.estimateTokens(testChunk) > idx.chunkSize && currentChunk != "" {
|
||||
result = append(result, currentChunk)
|
||||
currentChunk = sentence
|
||||
} else {
|
||||
currentChunk = testChunk
|
||||
}
|
||||
}
|
||||
if currentChunk != "" {
|
||||
result = append(result, currentChunk)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractLastTokens 从文本末尾提取指定token数量的内容
|
||||
func (idx *Indexer) extractLastTokens(text string, tokenCount int) string {
|
||||
if tokenCount <= 0 || text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 估算字符数(1 token ≈ 4字符)
|
||||
charCount := tokenCount * 4
|
||||
runes := []rune(text)
|
||||
|
||||
if len(runes) <= charCount {
|
||||
return text
|
||||
}
|
||||
|
||||
// 从末尾提取指定数量的字符
|
||||
// 尝试在句子边界处截断,避免截断句子中间
|
||||
startPos := len(runes) - charCount
|
||||
extracted := string(runes[startPos:])
|
||||
|
||||
// 尝试找到第一个句子边界(句号、问号、感叹号后的空格)
|
||||
sentenceBoundary := regexp.MustCompile(`[.!?]+\s+`)
|
||||
matches := sentenceBoundary.FindStringIndex(extracted)
|
||||
if len(matches) > 0 && matches[0] > 0 {
|
||||
// 在句子边界处截断,保留完整句子
|
||||
extracted = extracted[matches[0]:]
|
||||
}
|
||||
|
||||
return strings.TrimSpace(extracted)
|
||||
}
|
||||
|
||||
// estimateTokens 估算token数(简单估算:1 token ≈ 4字符)
|
||||
func (idx *Indexer) estimateTokens(text string) int {
|
||||
return len([]rune(text)) / 4
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -130,10 +131,11 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
|
||||
// 计算相似度
|
||||
type candidate struct {
|
||||
chunk *KnowledgeChunk
|
||||
item *KnowledgeItem
|
||||
similarity float64
|
||||
bm25Score float64
|
||||
chunk *KnowledgeChunk
|
||||
item *KnowledgeItem
|
||||
similarity float64
|
||||
bm25Score float64
|
||||
hasStrongKeywordMatch bool
|
||||
}
|
||||
|
||||
candidates := make([]candidate, 0)
|
||||
@@ -169,22 +171,10 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
// 综合BM25分数(用于后续排序)
|
||||
bm25Score := math.Max(math.Max(chunkBM25, categoryBM25), titleBM25)
|
||||
|
||||
// 过滤策略:
|
||||
// 1. 如果向量相似度达到阈值,通过
|
||||
// 2. 如果category/title有显著匹配,适当放宽相似度要求(因为它们更可靠)
|
||||
// 这样既保持了原有的过滤严格性,又能处理结构化字段匹配的情况
|
||||
if similarity < threshold {
|
||||
// 只有当category或title有明显匹配时,才适当放宽阈值
|
||||
if hasStrongKeywordMatch {
|
||||
// 放宽到原阈值的75%,但至少要有0.35的相似度
|
||||
// 这确保了即使关键词匹配,向量相似度也不能太低
|
||||
relaxedThreshold := math.Max(threshold*0.75, 0.35)
|
||||
if similarity < relaxedThreshold {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
// 收集所有候选(先不严格过滤,以便后续智能处理跨语言情况)
|
||||
// 只过滤掉相似度极低的结果(< 0.1),避免噪音
|
||||
if similarity < 0.1 {
|
||||
continue
|
||||
}
|
||||
|
||||
chunk := &KnowledgeChunk{
|
||||
@@ -202,36 +192,107 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
}
|
||||
|
||||
candidates = append(candidates, candidate{
|
||||
chunk: chunk,
|
||||
item: item,
|
||||
similarity: similarity,
|
||||
bm25Score: bm25Score,
|
||||
chunk: chunk,
|
||||
item: item,
|
||||
similarity: similarity,
|
||||
bm25Score: bm25Score,
|
||||
hasStrongKeywordMatch: hasStrongKeywordMatch,
|
||||
})
|
||||
}
|
||||
|
||||
// 先按相似度排序(使用更高效的排序)
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].similarity > candidates[j].similarity
|
||||
})
|
||||
|
||||
// 智能过滤策略:优先保留关键词匹配的结果,对跨语言查询使用更宽松的阈值
|
||||
filteredCandidates := make([]candidate, 0)
|
||||
|
||||
// 检查是否有任何关键词匹配(用于判断是否是跨语言查询)
|
||||
hasAnyKeywordMatch := false
|
||||
for _, cand := range candidates {
|
||||
if cand.hasStrongKeywordMatch {
|
||||
hasAnyKeywordMatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 根据是否有关键词匹配,采用不同的阈值策略
|
||||
effectiveThreshold := threshold
|
||||
if !hasAnyKeywordMatch {
|
||||
// 没有关键词匹配,可能是跨语言查询,适度放宽阈值
|
||||
// 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性
|
||||
// 跨语言阈值设为0.6,确保返回的结果至少有一定相关性
|
||||
effectiveThreshold = math.Max(threshold*0.85, 0.6)
|
||||
r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值",
|
||||
zap.Float64("originalThreshold", threshold),
|
||||
zap.Float64("effectiveThreshold", effectiveThreshold),
|
||||
)
|
||||
}
|
||||
|
||||
// 检查最高相似度,用于判断是否确实有相关内容
|
||||
maxSimilarity := 0.0
|
||||
if len(candidates) > 0 {
|
||||
maxSimilarity = candidates[0].similarity
|
||||
}
|
||||
|
||||
// 应用智能过滤
|
||||
for _, cand := range candidates {
|
||||
if cand.similarity >= effectiveThreshold {
|
||||
// 达到阈值,直接通过
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
} else if cand.hasStrongKeywordMatch {
|
||||
// 有关键词匹配但相似度略低于阈值,适当放宽
|
||||
relaxedThreshold := math.Max(effectiveThreshold*0.85, 0.55)
|
||||
if cand.similarity >= relaxedThreshold {
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
}
|
||||
}
|
||||
// 如果既没有关键词匹配,相似度又低于阈值,则过滤掉
|
||||
}
|
||||
|
||||
// 智能兜底策略:只有在最高相似度达到合理水平时,才考虑返回结果
|
||||
// 如果最高相似度都很低(<0.55),说明确实没有相关内容,应该返回空
|
||||
if len(filteredCandidates) == 0 && len(candidates) > 0 {
|
||||
// 即使没有通过阈值过滤,如果最高相似度还可以(>=0.55),可以考虑返回Top-K
|
||||
// 但这是最后的兜底,只在确实有一定相关性时才使用
|
||||
minAcceptableSimilarity := 0.55
|
||||
if maxSimilarity >= minAcceptableSimilarity {
|
||||
r.logger.Debug("过滤后无结果,但最高相似度可接受,返回Top-K结果",
|
||||
zap.Int("totalCandidates", len(candidates)),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
zap.Float64("effectiveThreshold", effectiveThreshold),
|
||||
)
|
||||
maxResults := topK
|
||||
if len(candidates) < maxResults {
|
||||
maxResults = len(candidates)
|
||||
}
|
||||
// 只返回相似度 >= 0.55 的结果
|
||||
for _, cand := range candidates {
|
||||
if cand.similarity >= minAcceptableSimilarity && len(filteredCandidates) < maxResults {
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
r.logger.Debug("过滤后无结果,且最高相似度过低,返回空结果",
|
||||
zap.Int("totalCandidates", len(candidates)),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
zap.Float64("minAcceptableSimilarity", minAcceptableSimilarity),
|
||||
)
|
||||
}
|
||||
} else if len(filteredCandidates) > topK {
|
||||
// 如果过滤后结果太多,只取Top-K
|
||||
filteredCandidates = filteredCandidates[:topK]
|
||||
}
|
||||
|
||||
candidates = filteredCandidates
|
||||
|
||||
// 混合排序(向量相似度 + BM25)
|
||||
hybridWeight := r.config.HybridWeight
|
||||
if hybridWeight == 0 {
|
||||
hybridWeight = 0.7
|
||||
}
|
||||
|
||||
// 按混合分数排序(简化:主要按相似度,BM25作为次要因素)
|
||||
// 这里我们主要使用相似度,因为BM25分数可能不稳定
|
||||
// 实际可以使用更复杂的混合策略
|
||||
|
||||
// 选择Top-K
|
||||
if len(candidates) > topK {
|
||||
// 简单排序(按相似度)
|
||||
for i := 0; i < len(candidates)-1; i++ {
|
||||
for j := i + 1; j < len(candidates); j++ {
|
||||
if candidates[i].similarity < candidates[j].similarity {
|
||||
candidates[i], candidates[j] = candidates[j], candidates[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
candidates = candidates[:topK]
|
||||
}
|
||||
|
||||
// 转换为结果
|
||||
results := make([]*RetrievalResult, len(candidates))
|
||||
for i, cand := range candidates {
|
||||
|
||||
@@ -105,8 +105,8 @@ func RegisterKnowledgeTool(
|
||||
|
||||
for i, result := range results {
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", i+1, result.Similarity*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s\n", result.Item.Category, result.Item.Title))
|
||||
resultText.WriteString(fmt.Sprintf("内容:\n%s\n\n", result.Chunk.ChunkText))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", result.Item.Category, result.Item.Title, result.Item.ID))
|
||||
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n\n", result.Chunk.ChunkText))
|
||||
|
||||
if !contains(retrievedItemIDs, result.Item.ID) {
|
||||
retrievedItemIDs = append(retrievedItemIDs, result.Item.ID)
|
||||
@@ -140,6 +140,89 @@ func RegisterKnowledgeTool(
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
logger.Info("知识检索工具已注册", zap.String("toolName", tool.Name))
|
||||
|
||||
// 注册读取完整知识项的工具
|
||||
RegisterReadKnowledgeItemTool(mcpServer, manager, logger)
|
||||
}
|
||||
|
||||
// RegisterReadKnowledgeItemTool 注册读取完整知识项工具到MCP服务器
|
||||
func RegisterReadKnowledgeItemTool(
|
||||
mcpServer *mcp.Server,
|
||||
manager *Manager,
|
||||
logger *zap.Logger,
|
||||
) {
|
||||
tool := mcp.Tool{
|
||||
Name: "read_knowledge_item",
|
||||
Description: "根据知识项ID读取完整的知识文档内容。**重要:此工具应谨慎使用,只在检索到的片段信息明显不足时才调用。** 使用场景:1) 检索片段缺少关键上下文导致无法理解;2) 需要查看文档的完整结构或流程;3) 片段信息不完整,必须查看完整文档才能回答用户问题。**不要**仅为了获取更多信息而盲目读取完整文档,因为检索工具已经返回了最相关的片段。传入知识项ID(从search_knowledge_base的检索结果中获取)即可获取该知识项的完整内容(包括标题、分类、完整文档内容等)。",
|
||||
ShortDescription: "读取完整知识项文档(仅在片段信息不足时使用)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"item_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "知识项ID(可以从 search_knowledge_base 的检索结果中获取)",
|
||||
},
|
||||
},
|
||||
"required": []string{"item_id"},
|
||||
},
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
itemID, ok := args["item_id"].(string)
|
||||
if !ok || itemID == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: item_id 参数不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
logger.Info("读取知识项", zap.String("itemId", itemID))
|
||||
|
||||
// 获取完整知识项
|
||||
item, err := manager.GetItem(itemID)
|
||||
if err != nil {
|
||||
logger.Error("读取知识项失败", zap.String("itemId", itemID), zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("读取知识项失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
resultText.WriteString("=== 完整知识项内容 ===\n\n")
|
||||
resultText.WriteString(fmt.Sprintf("ID: %s\n", item.ID))
|
||||
resultText.WriteString(fmt.Sprintf("分类: %s\n", item.Category))
|
||||
resultText.WriteString(fmt.Sprintf("标题: %s\n", item.Title))
|
||||
if item.FilePath != "" {
|
||||
resultText.WriteString(fmt.Sprintf("文件路径: %s\n", item.FilePath))
|
||||
}
|
||||
resultText.WriteString("\n--- 完整内容 ---\n\n")
|
||||
resultText.WriteString(item.Content)
|
||||
resultText.WriteString("\n\n")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: resultText.String(),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
logger.Info("读取知识项工具已注册", zap.String("toolName", tool.Name))
|
||||
}
|
||||
|
||||
// contains 检查切片是否包含元素
|
||||
|
||||
Reference in New Issue
Block a user