Add files via upload

This commit is contained in:
公明
2025-12-27 03:57:01 +08:00
committed by GitHub
parent 3e0867d459
commit 604e31d247
5 changed files with 196 additions and 47 deletions

View File

@@ -242,7 +242,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
// 如果知识库已启用,设置知识库工具注册器以便在ApplyConfig时重新注册知识库工具
// 如果知识库已启用,设置知识库工具注册器和检索器更新器
if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
// 创建闭包捕获knowledgeRetriever和knowledgeManager的引用
registrar := func() error {
@@ -250,6 +250,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
return nil
}
configHandler.SetKnowledgeToolRegistrar(registrar)
// 设置检索器更新器以便在ApplyConfig时更新检索器配置
configHandler.SetRetrieverUpdater(knowledgeRetriever)
}
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)

View File

@@ -13,6 +13,7 @@ import (
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
@@ -23,6 +24,11 @@ import (
// KnowledgeToolRegistrar 知识库工具注册器接口
type KnowledgeToolRegistrar func() error
// RetrieverUpdater 检索器更新接口
type RetrieverUpdater interface {
UpdateConfig(config *knowledge.RetrievalConfig)
}
// ConfigHandler 配置处理器
type ConfigHandler struct {
configPath string
@@ -33,6 +39,7 @@ type ConfigHandler struct {
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
logger *zap.Logger
mu sync.RWMutex
}
@@ -69,6 +76,13 @@ func (h *ConfigHandler) SetKnowledgeToolRegistrar(registrar KnowledgeToolRegistr
h.knowledgeToolRegistrar = registrar
}
// SetRetrieverUpdater 设置检索器更新器
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
h.mu.Lock()
defer h.mu.Unlock()
h.retrieverUpdater = updater
}
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
@@ -639,6 +653,21 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("AttackChainHandler配置已更新")
}
// 更新检索器配置(如果知识库启用)
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
retrievalConfig := &knowledge.RetrievalConfig{
TopK: h.config.Knowledge.Retrieval.TopK,
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: h.config.Knowledge.Retrieval.HybridWeight,
}
h.retrieverUpdater.UpdateConfig(retrievalConfig)
h.logger.Info("检索器配置已更新",
zap.Int("top_k", retrievalConfig.TopK),
zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
)
}
h.logger.Info("配置已应用",
zap.Int("tools_count", len(h.config.Security.Tools)),
)
@@ -952,7 +981,13 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
valueNode.Kind = yaml.ScalarNode
valueNode.Tag = "!!float"
valueNode.Style = 0
valueNode.Value = fmt.Sprintf("%g", value)
// 对于0.0到1.0之间的值如hybrid_weight使用%.1f确保0.0被明确序列化为"0.0"
// 对于其他值,使用%g自动选择最合适的格式
if value >= 0.0 && value <= 1.0 {
valueNode.Value = fmt.Sprintf("%.1f", value)
} else {
valueNode.Value = fmt.Sprintf("%g", value)
}
}

View File

@@ -37,6 +37,18 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge
}
}
// UpdateConfig 更新检索配置
func (r *Retriever) UpdateConfig(config *RetrievalConfig) {
if config != nil {
r.config = config
r.logger.Info("检索器配置已更新",
zap.Int("top_k", config.TopK),
zap.Float64("similarity_threshold", config.SimilarityThreshold),
zap.Float64("hybrid_weight", config.HybridWeight),
)
}
}
// cosineSimilarity 计算余弦相似度
func cosineSimilarity(a, b []float32) float64 {
if len(a) != len(b) {
@@ -57,27 +69,61 @@ func cosineSimilarity(a, b []float32) float64 {
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
// bm25Score 计算BM25分数简化版
// bm25Score 计算BM25分数改进版更接近标准BM25
// 注意这是单文档版本的BM25缺少全局IDF但比之前的简化版本更准确
func (r *Retriever) bm25Score(query, text string) float64 {
queryTerms := strings.Fields(strings.ToLower(query))
if len(queryTerms) == 0 {
return 0.0
}
textLower := strings.ToLower(text)
textTerms := strings.Fields(textLower)
if len(textTerms) == 0 {
return 0.0
}
// BM25参数
k1 := 1.5 // 词频饱和度参数
b := 0.75 // 长度归一化参数
avgDocLength := 100.0 // 估算的平均文档长度(用于归一化)
docLength := float64(len(textTerms))
score := 0.0
for _, term := range queryTerms {
// 计算词频TF
termFreq := 0
for _, textTerm := range textTerms {
if textTerm == term {
termFreq++
}
}
if termFreq > 0 {
// 简化的BM25公式
score += float64(termFreq) / float64(len(textTerms))
// BM25公式的核心部分
// TF部分termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength)))
tf := float64(termFreq)
lengthNorm := 1 - b + b*(docLength/avgDocLength)
tfScore := tf / (tf + k1*lengthNorm)
// 简化IDF使用词长度作为权重短词通常更重要
// 实际BM25需要全局文档统计这里用简化版本
idfWeight := 1.0
if len(term) > 2 {
// 长词稍微降低权重但实际BM25中罕见词IDF更高
idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0)
}
score += tfScore * idfWeight
}
}
return score / float64(len(queryTerms))
// 归一化到0-1范围
if len(queryTerms) > 0 {
score = score / float64(len(queryTerms))
}
return math.Min(score, 1.0)
}
// Search 搜索知识库
@@ -148,6 +194,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
similarity float64
bm25Score float64
hasStrongKeywordMatch bool
hybridScore float64 // 混合分数,用于最终排序
}
candidates := make([]candidate, 0)
@@ -229,19 +276,6 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
}
}
// 根据是否有关键词匹配,采用不同的阈值策略
effectiveThreshold := threshold
if !hasAnyKeywordMatch {
// 没有关键词匹配,可能是跨语言查询,适度放宽阈值
// 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性
// 跨语言阈值设为0.6,确保返回的结果至少有一定相关性
effectiveThreshold = math.Max(threshold*0.85, 0.6)
r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值",
zap.Float64("originalThreshold", threshold),
zap.Float64("effectiveThreshold", effectiveThreshold),
)
}
// 检查最高相似度,用于判断是否确实有相关内容
maxSimilarity := 0.0
if len(candidates) > 0 {
@@ -249,12 +283,35 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
}
// 应用智能过滤
// 如果用户设置了高阈值(>=0.8),更严格地遵守阈值,减少自动放宽
strictMode := threshold >= 0.8
// 根据是否有关键词匹配,采用不同的阈值策略
// 严格模式下,禁用跨语言放宽策略,严格遵守用户设置的阈值
effectiveThreshold := threshold
if !strictMode && !hasAnyKeywordMatch {
// 非严格模式下,没有关键词匹配,可能是跨语言查询,适度放宽阈值
// 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性
// 跨语言阈值设为0.6,确保返回的结果至少有一定相关性
effectiveThreshold = math.Max(threshold*0.85, 0.6)
r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值",
zap.Float64("originalThreshold", threshold),
zap.Float64("effectiveThreshold", effectiveThreshold),
)
} else if strictMode {
// 严格模式下,即使没有关键词匹配,也严格遵守阈值
r.logger.Debug("严格模式:严格遵守用户设置的阈值",
zap.Float64("threshold", threshold),
zap.Bool("hasKeywordMatch", hasAnyKeywordMatch),
)
}
for _, cand := range candidates {
if cand.similarity >= effectiveThreshold {
// 达到阈值,直接通过
filteredCandidates = append(filteredCandidates, cand)
} else if cand.hasStrongKeywordMatch {
// 有关键词匹配但相似度略低于阈值,适当放宽
} else if !strictMode && cand.hasStrongKeywordMatch {
// 非严格模式下,有关键词匹配但相似度略低于阈值,适当放宽
// 严格模式下,即使有关键词匹配,也严格遵守阈值
relaxedThreshold := math.Max(effectiveThreshold*0.85, 0.55)
if cand.similarity >= relaxedThreshold {
filteredCandidates = append(filteredCandidates, cand)
@@ -265,9 +322,11 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
// 智能兜底策略:只有在最高相似度达到合理水平时,才考虑返回结果
// 如果最高相似度都很低(<0.55),说明确实没有相关内容,应该返回空
if len(filteredCandidates) == 0 && len(candidates) > 0 {
// 严格模式下(阈值>=0.8),禁用兜底策略,严格遵守用户设置的阈值
if len(filteredCandidates) == 0 && len(candidates) > 0 && !strictMode {
// 即使没有通过阈值过滤,如果最高相似度还可以(>=0.55可以考虑返回Top-K
// 但这是最后的兜底,只在确实有一定相关性时才使用
// 严格模式下不使用兜底策略
minAcceptableSimilarity := 0.55
if maxSimilarity >= minAcceptableSimilarity {
r.logger.Debug("过滤后无结果但最高相似度可接受返回Top-K结果",
@@ -292,6 +351,12 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
zap.Float64("minAcceptableSimilarity", minAcceptableSimilarity),
)
}
} else if len(filteredCandidates) == 0 && strictMode {
// 严格模式下,如果过滤后无结果,直接返回空,不使用兜底策略
r.logger.Debug("严格模式:过滤后无结果,严格遵守阈值,返回空结果",
zap.Float64("threshold", threshold),
zap.Float64("maxSimilarity", maxSimilarity),
)
} else if len(filteredCandidates) > topK {
// 如果过滤后结果太多只取Top-K
filteredCandidates = filteredCandidates[:topK]
@@ -300,23 +365,29 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
candidates = filteredCandidates
// 混合排序(向量相似度 + BM25
// 注意hybridWeight可以是0.0(纯关键词检索),所以不设置默认值
// 如果配置文件中未设置,应该在配置加载时使用默认值
hybridWeight := r.config.HybridWeight
if hybridWeight == 0 {
hybridWeight = 0.7
// 先计算混合分数并存储在candidate中用于排序
for i := range candidates {
normalizedBM25 := math.Min(candidates[i].bm25Score, 1.0)
candidates[i].hybridScore = hybridWeight*candidates[i].similarity + (1-hybridWeight)*normalizedBM25
}
// 根据混合分数重新排序(这才是真正的混合检索)
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].hybridScore > candidates[j].hybridScore
})
// 转换为结果
results := make([]*RetrievalResult, len(candidates))
for i, cand := range candidates {
// 计算混合分数
normalizedBM25 := math.Min(cand.bm25Score, 1.0)
hybridScore := hybridWeight*cand.similarity + (1-hybridWeight)*normalizedBM25
results[i] = &RetrievalResult{
Chunk: cand.chunk,
Item: cand.item,
Similarity: cand.similarity,
Score: hybridScore,
Score: cand.hybridScore,
}
}
@@ -385,12 +456,12 @@ func (r *Retriever) expandContext(ctx context.Context, results []*RetrievalResul
}
// 为该文档的匹配chunk收集需要扩展的相邻chunk
// 策略:只对相似度最高的前3个匹配chunk进行扩展避免扩展过多
// 先按相似度排序只扩展前3个
// 策略:只对混合分数最高的前3个匹配chunk进行扩展避免扩展过多
// 先按混合分数排序只扩展前3个(使用混合分数而不是相似度)
sortedItemResults := make([]*RetrievalResult, len(itemResults))
copy(sortedItemResults, itemResults)
sort.Slice(sortedItemResults, func(i, j int) bool {
return sortedItemResults[i].Similarity > sortedItemResults[j].Similarity
return sortedItemResults[i].Score > sortedItemResults[j].Score
})
// 只扩展前3个或所有如果少于3个

View File

@@ -157,26 +157,58 @@ func RegisterKnowledgeTool(
// 格式化结果
var resultText strings.Builder
// 先按混合分数排序,确保文档顺序是按混合分数的(混合检索的核心)
sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})
// 按文档分组结果,以便更好地展示上下文
resultsByItem := make(map[string][]*RetrievalResult)
// 使用有序的slice来保持文档顺序按最高混合分数
type itemGroup struct {
itemID string
results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数
}
itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup)
for _, result := range results {
itemID := result.Item.ID
resultsByItem[itemID] = append(resultsByItem[itemID], result)
group, exists := itemMap[itemID]
if !exists {
group = &itemGroup{
itemID: itemID,
results: make([]*RetrievalResult, 0),
maxScore: result.Score,
}
itemMap[itemID] = group
itemGroups = append(itemGroups, group)
}
group.results = append(group.results, result)
if result.Score > group.maxScore {
group.maxScore = result.Score
}
}
// 按最高混合分数排序文档组
sort.Slice(itemGroups, func(i, j int) bool {
return itemGroups[i].maxScore > itemGroups[j].maxScore
})
// 收集检索到的知识项ID用于日志
retrievedItemIDs := make([]string, 0, len(resultsByItem))
retrievedItemIDs := make([]string, 0, len(itemGroups))
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results)))
resultIndex := 1
for itemID, itemResults := range resultsByItem {
// 找到相似度最高的作为主结果
for _, group := range itemGroups {
itemResults := group.results
// 找到混合分数最高的作为主结果(使用混合分数,而不是相似度)
mainResult := itemResults[0]
maxSimilarity := mainResult.Similarity
maxScore := mainResult.Score
for _, result := range itemResults {
if result.Similarity > maxSimilarity {
maxSimilarity = result.Similarity
if result.Score > maxScore {
maxScore = result.Score
mainResult = result
}
}
@@ -186,8 +218,9 @@ func RegisterKnowledgeTool(
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
})
// 显示主结果(相似度最高的
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", resultIndex, mainResult.Similarity*100))
// 显示主结果(混合分数最高的,同时显示相似度和混合分数
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100, mainResult.Score*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
// 按逻辑顺序显示所有chunk包括主结果和扩展的chunk
@@ -208,8 +241,8 @@ func RegisterKnowledgeTool(
}
resultText.WriteString("\n")
if !contains(retrievedItemIDs, itemID) {
retrievedItemIDs = append(retrievedItemIDs, itemID)
if !contains(retrievedItemIDs, group.itemID) {
retrievedItemIDs = append(retrievedItemIDs, group.itemID)
}
resultIndex++
}