diff --git a/internal/app/app.go b/internal/app/app.go index e2dd405f..c6568db6 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -242,7 +242,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) - // 如果知识库已启用,设置知识库工具注册器,以便在ApplyConfig时重新注册知识库工具 + // 如果知识库已启用,设置知识库工具注册器和检索器更新器 if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil { // 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用 registrar := func() error { @@ -250,6 +250,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { return nil } configHandler.SetKnowledgeToolRegistrar(registrar) + // 设置检索器更新器,以便在ApplyConfig时更新检索器配置 + configHandler.SetRetrieverUpdater(knowledgeRetriever) } externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) diff --git a/internal/handler/config.go b/internal/handler/config.go index 0d3de996..74b64f14 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -13,6 +13,7 @@ import ( "time" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/security" "github.com/gin-gonic/gin" @@ -23,6 +24,11 @@ import ( // KnowledgeToolRegistrar 知识库工具注册器接口 type KnowledgeToolRegistrar func() error +// RetrieverUpdater 检索器更新接口 +type RetrieverUpdater interface { + UpdateConfig(config *knowledge.RetrievalConfig) +} + // ConfigHandler 配置处理器 type ConfigHandler struct { configPath string @@ -33,6 +39,7 @@ type ConfigHandler struct { attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置 externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选) + retrieverUpdater RetrieverUpdater // 检索器更新器(可选) logger *zap.Logger mu sync.RWMutex } @@ -69,6 +76,13 @@ func (h *ConfigHandler) SetKnowledgeToolRegistrar(registrar KnowledgeToolRegistr h.knowledgeToolRegistrar = registrar } +// SetRetrieverUpdater 设置检索器更新器 +func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) { + h.mu.Lock() + defer h.mu.Unlock() + h.retrieverUpdater = updater +} + // GetConfigResponse 获取配置响应 type GetConfigResponse struct { OpenAI config.OpenAIConfig `json:"openai"` @@ -639,6 +653,21 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { h.logger.Info("AttackChainHandler配置已更新") } + // 更新检索器配置(如果知识库启用) + if h.config.Knowledge.Enabled && h.retrieverUpdater != nil { + retrievalConfig := &knowledge.RetrievalConfig{ + TopK: h.config.Knowledge.Retrieval.TopK, + SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold, + HybridWeight: h.config.Knowledge.Retrieval.HybridWeight, + } + h.retrieverUpdater.UpdateConfig(retrievalConfig) + h.logger.Info("检索器配置已更新", + zap.Int("top_k", retrievalConfig.TopK), + zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold), + zap.Float64("hybrid_weight", retrievalConfig.HybridWeight), + ) + } + h.logger.Info("配置已应用", zap.Int("tools_count", len(h.config.Security.Tools)), ) @@ -952,7 +981,13 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) { valueNode.Kind = yaml.ScalarNode valueNode.Tag = "!!float" valueNode.Style = 0 - valueNode.Value = fmt.Sprintf("%g", value) + // 对于0.0到1.0之间的值(如hybrid_weight),使用%.1f确保0.0被明确序列化为"0.0" + // 对于其他值,使用%g自动选择最合适的格式 + if value >= 0.0 && value <= 1.0 { + valueNode.Value = fmt.Sprintf("%.1f", value) + } else { + valueNode.Value = fmt.Sprintf("%g", value) + } } diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go index d859fcaf..2a87c450 100644 --- a/internal/knowledge/retriever.go +++ b/internal/knowledge/retriever.go @@ -37,6 +37,18 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge } } +// UpdateConfig 更新检索配置 +func (r *Retriever) UpdateConfig(config *RetrievalConfig) { + if config != nil { + r.config = config + r.logger.Info("检索器配置已更新", + zap.Int("top_k", config.TopK), + zap.Float64("similarity_threshold", config.SimilarityThreshold), + zap.Float64("hybrid_weight", config.HybridWeight), + ) + } +} + // cosineSimilarity 计算余弦相似度 func cosineSimilarity(a, b []float32) float64 { if len(a) != len(b) { @@ -57,27 +69,61 @@ func cosineSimilarity(a, b []float32) float64 { return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) } -// bm25Score 计算BM25分数(简化版) +// bm25Score 计算BM25分数(改进版,更接近标准BM25) +// 注意:这是单文档版本的BM25,缺少全局IDF,但比之前的简化版本更准确 func (r *Retriever) bm25Score(query, text string) float64 { queryTerms := strings.Fields(strings.ToLower(query)) + if len(queryTerms) == 0 { + return 0.0 + } + textLower := strings.ToLower(text) textTerms := strings.Fields(textLower) + if len(textTerms) == 0 { + return 0.0 + } + + // BM25参数 + k1 := 1.5 // 词频饱和度参数 + b := 0.75 // 长度归一化参数 + avgDocLength := 100.0 // 估算的平均文档长度(用于归一化) + docLength := float64(len(textTerms)) score := 0.0 for _, term := range queryTerms { + // 计算词频(TF) termFreq := 0 for _, textTerm := range textTerms { if textTerm == term { termFreq++ } } + if termFreq > 0 { - // 简化的BM25公式 - score += float64(termFreq) / float64(len(textTerms)) + // BM25公式的核心部分 + // TF部分:termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength))) + tf := float64(termFreq) + lengthNorm := 1 - b + b*(docLength/avgDocLength) + tfScore := tf / (tf + k1*lengthNorm) + + // 简化IDF:使用词长度作为权重(短词通常更重要) + // 实际BM25需要全局文档统计,这里用简化版本 + idfWeight := 1.0 + if len(term) > 2 { + // 长词稍微降低权重(但实际BM25中,罕见词IDF更高) + idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0) + } + + score += tfScore * idfWeight } } - return score / float64(len(queryTerms)) + // 归一化到0-1范围 + if len(queryTerms) > 0 { + score = score / float64(len(queryTerms)) + } + + return math.Min(score, 1.0) } // Search 搜索知识库 @@ -148,6 +194,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva similarity float64 bm25Score float64 hasStrongKeywordMatch bool + hybridScore float64 // 混合分数,用于最终排序 } candidates := make([]candidate, 0) @@ -229,19 +276,6 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva } } - // 根据是否有关键词匹配,采用不同的阈值策略 - effectiveThreshold := threshold - if !hasAnyKeywordMatch { - // 没有关键词匹配,可能是跨语言查询,适度放宽阈值 - // 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性 - // 跨语言阈值设为0.6,确保返回的结果至少有一定相关性 - effectiveThreshold = math.Max(threshold*0.85, 0.6) - r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值", - zap.Float64("originalThreshold", threshold), - zap.Float64("effectiveThreshold", effectiveThreshold), - ) - } - // 检查最高相似度,用于判断是否确实有相关内容 maxSimilarity := 0.0 if len(candidates) > 0 { @@ -249,12 +283,35 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva } // 应用智能过滤 + // 如果用户设置了高阈值(>=0.8),更严格地遵守阈值,减少自动放宽 + strictMode := threshold >= 0.8 + + // 根据是否有关键词匹配,采用不同的阈值策略 + // 严格模式下,禁用跨语言放宽策略,严格遵守用户设置的阈值 + effectiveThreshold := threshold + if !strictMode && !hasAnyKeywordMatch { + // 非严格模式下,没有关键词匹配,可能是跨语言查询,适度放宽阈值 + // 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性 + // 跨语言阈值设为0.6,确保返回的结果至少有一定相关性 + effectiveThreshold = math.Max(threshold*0.85, 0.6) + r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值", + zap.Float64("originalThreshold", threshold), + zap.Float64("effectiveThreshold", effectiveThreshold), + ) + } else if strictMode { + // 严格模式下,即使没有关键词匹配,也严格遵守阈值 + r.logger.Debug("严格模式:严格遵守用户设置的阈值", + zap.Float64("threshold", threshold), + zap.Bool("hasKeywordMatch", hasAnyKeywordMatch), + ) + } for _, cand := range candidates { if cand.similarity >= effectiveThreshold { // 达到阈值,直接通过 filteredCandidates = append(filteredCandidates, cand) - } else if cand.hasStrongKeywordMatch { - // 有关键词匹配但相似度略低于阈值,适当放宽 + } else if !strictMode && cand.hasStrongKeywordMatch { + // 非严格模式下,有关键词匹配但相似度略低于阈值,适当放宽 + // 严格模式下,即使有关键词匹配,也严格遵守阈值 relaxedThreshold := math.Max(effectiveThreshold*0.85, 0.55) if cand.similarity >= relaxedThreshold { filteredCandidates = append(filteredCandidates, cand) @@ -265,9 +322,11 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva // 智能兜底策略:只有在最高相似度达到合理水平时,才考虑返回结果 // 如果最高相似度都很低(<0.55),说明确实没有相关内容,应该返回空 - if len(filteredCandidates) == 0 && len(candidates) > 0 { + // 严格模式下(阈值>=0.8),禁用兜底策略,严格遵守用户设置的阈值 + if len(filteredCandidates) == 0 && len(candidates) > 0 && !strictMode { // 即使没有通过阈值过滤,如果最高相似度还可以(>=0.55),可以考虑返回Top-K // 但这是最后的兜底,只在确实有一定相关性时才使用 + // 严格模式下不使用兜底策略 minAcceptableSimilarity := 0.55 if maxSimilarity >= minAcceptableSimilarity { r.logger.Debug("过滤后无结果,但最高相似度可接受,返回Top-K结果", @@ -292,6 +351,12 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva zap.Float64("minAcceptableSimilarity", minAcceptableSimilarity), ) } + } else if len(filteredCandidates) == 0 && strictMode { + // 严格模式下,如果过滤后无结果,直接返回空,不使用兜底策略 + r.logger.Debug("严格模式:过滤后无结果,严格遵守阈值,返回空结果", + zap.Float64("threshold", threshold), + zap.Float64("maxSimilarity", maxSimilarity), + ) } else if len(filteredCandidates) > topK { // 如果过滤后结果太多,只取Top-K filteredCandidates = filteredCandidates[:topK] @@ -300,23 +365,29 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva candidates = filteredCandidates // 混合排序(向量相似度 + BM25) + // 注意:hybridWeight可以是0.0(纯关键词检索),所以不设置默认值 + // 如果配置文件中未设置,应该在配置加载时使用默认值 hybridWeight := r.config.HybridWeight - if hybridWeight == 0 { - hybridWeight = 0.7 + + // 先计算混合分数并存储在candidate中,用于排序 + for i := range candidates { + normalizedBM25 := math.Min(candidates[i].bm25Score, 1.0) + candidates[i].hybridScore = hybridWeight*candidates[i].similarity + (1-hybridWeight)*normalizedBM25 } + // 根据混合分数重新排序(这才是真正的混合检索) + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].hybridScore > candidates[j].hybridScore + }) + // 转换为结果 results := make([]*RetrievalResult, len(candidates)) for i, cand := range candidates { - // 计算混合分数 - normalizedBM25 := math.Min(cand.bm25Score, 1.0) - hybridScore := hybridWeight*cand.similarity + (1-hybridWeight)*normalizedBM25 - results[i] = &RetrievalResult{ Chunk: cand.chunk, Item: cand.item, Similarity: cand.similarity, - Score: hybridScore, + Score: cand.hybridScore, } } @@ -385,12 +456,12 @@ func (r *Retriever) expandContext(ctx context.Context, results []*RetrievalResul } // 为该文档的匹配chunk收集需要扩展的相邻chunk - // 策略:只对相似度最高的前3个匹配chunk进行扩展,避免扩展过多 - // 先按相似度排序,只扩展前3个 + // 策略:只对混合分数最高的前3个匹配chunk进行扩展,避免扩展过多 + // 先按混合分数排序,只扩展前3个(使用混合分数而不是相似度) sortedItemResults := make([]*RetrievalResult, len(itemResults)) copy(sortedItemResults, itemResults) sort.Slice(sortedItemResults, func(i, j int) bool { - return sortedItemResults[i].Similarity > sortedItemResults[j].Similarity + return sortedItemResults[i].Score > sortedItemResults[j].Score }) // 只扩展前3个(或所有,如果少于3个) diff --git a/internal/knowledge/tool.go b/internal/knowledge/tool.go index b6eb2c52..a1001bf6 100644 --- a/internal/knowledge/tool.go +++ b/internal/knowledge/tool.go @@ -157,26 +157,58 @@ func RegisterKnowledgeTool( // 格式化结果 var resultText strings.Builder + // 先按混合分数排序,确保文档顺序是按混合分数的(混合检索的核心) + sort.Slice(results, func(i, j int) bool { + return results[i].Score > results[j].Score + }) + // 按文档分组结果,以便更好地展示上下文 - resultsByItem := make(map[string][]*RetrievalResult) + // 使用有序的slice来保持文档顺序(按最高混合分数) + type itemGroup struct { + itemID string + results []*RetrievalResult + maxScore float64 // 该文档的最高混合分数 + } + itemGroups := make([]*itemGroup, 0) + itemMap := make(map[string]*itemGroup) + for _, result := range results { itemID := result.Item.ID - resultsByItem[itemID] = append(resultsByItem[itemID], result) + group, exists := itemMap[itemID] + if !exists { + group = &itemGroup{ + itemID: itemID, + results: make([]*RetrievalResult, 0), + maxScore: result.Score, + } + itemMap[itemID] = group + itemGroups = append(itemGroups, group) + } + group.results = append(group.results, result) + if result.Score > group.maxScore { + group.maxScore = result.Score + } } + // 按最高混合分数排序文档组 + sort.Slice(itemGroups, func(i, j int) bool { + return itemGroups[i].maxScore > itemGroups[j].maxScore + }) + // 收集检索到的知识项ID(用于日志) - retrievedItemIDs := make([]string, 0, len(resultsByItem)) + retrievedItemIDs := make([]string, 0, len(itemGroups)) resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results))) resultIndex := 1 - for itemID, itemResults := range resultsByItem { - // 找到相似度最高的作为主结果 + for _, group := range itemGroups { + itemResults := group.results + // 找到混合分数最高的作为主结果(使用混合分数,而不是相似度) mainResult := itemResults[0] - maxSimilarity := mainResult.Similarity + maxScore := mainResult.Score for _, result := range itemResults { - if result.Similarity > maxSimilarity { - maxSimilarity = result.Similarity + if result.Score > maxScore { + maxScore = result.Score mainResult = result } } @@ -186,8 +218,9 @@ func RegisterKnowledgeTool( return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex }) - // 显示主结果(相似度最高的) - resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", resultIndex, mainResult.Similarity*100)) + // 显示主结果(混合分数最高的,同时显示相似度和混合分数) + resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n", + resultIndex, mainResult.Similarity*100, mainResult.Score*100)) resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID)) // 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk) @@ -208,8 +241,8 @@ func RegisterKnowledgeTool( } resultText.WriteString("\n") - if !contains(retrievedItemIDs, itemID) { - retrievedItemIDs = append(retrievedItemIDs, itemID) + if !contains(retrievedItemIDs, group.itemID) { + retrievedItemIDs = append(retrievedItemIDs, group.itemID) } resultIndex++ } diff --git a/web/static/js/settings.js b/web/static/js/settings.js index 62fc17e5..f063729f 100644 --- a/web/static/js/settings.js +++ b/web/static/js/settings.js @@ -146,7 +146,9 @@ async function loadConfig(loadTools = true) { const retrievalWeightInput = document.getElementById('knowledge-retrieval-hybrid-weight'); if (retrievalWeightInput) { - retrievalWeightInput.value = knowledge.retrieval?.hybrid_weight || 0.7; + const hybridWeight = knowledge.retrieval?.hybrid_weight; + // 允许0.0值,只有undefined/null时才使用默认值 + retrievalWeightInput.value = (hybridWeight !== undefined && hybridWeight !== null) ? hybridWeight : 0.7; } } @@ -613,8 +615,14 @@ async function applySettings() { }, retrieval: { top_k: parseInt(document.getElementById('knowledge-retrieval-top-k')?.value) || 5, - similarity_threshold: parseFloat(document.getElementById('knowledge-retrieval-similarity-threshold')?.value) || 0.7, - hybrid_weight: parseFloat(document.getElementById('knowledge-retrieval-hybrid-weight')?.value) || 0.7 + similarity_threshold: (() => { + const val = parseFloat(document.getElementById('knowledge-retrieval-similarity-threshold')?.value); + return isNaN(val) ? 0.7 : val; + })(), + hybrid_weight: (() => { + const val = parseFloat(document.getElementById('knowledge-retrieval-hybrid-weight')?.value); + return isNaN(val) ? 0.7 : val; // 允许0.0值,只有NaN时才使用默认值 + })() } };