mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
Add files via upload
This commit is contained in:
@@ -242,7 +242,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
// 如果知识库已启用,设置知识库工具注册器,以便在ApplyConfig时重新注册知识库工具
|
||||
// 如果知识库已启用,设置知识库工具注册器和检索器更新器
|
||||
if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
|
||||
// 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用
|
||||
registrar := func() error {
|
||||
@@ -250,6 +250,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
return nil
|
||||
}
|
||||
configHandler.SetKnowledgeToolRegistrar(registrar)
|
||||
// 设置检索器更新器,以便在ApplyConfig时更新检索器配置
|
||||
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
||||
}
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -23,6 +24,11 @@ import (
|
||||
// KnowledgeToolRegistrar 知识库工具注册器接口
|
||||
type KnowledgeToolRegistrar func() error
|
||||
|
||||
// RetrieverUpdater 检索器更新接口
|
||||
type RetrieverUpdater interface {
|
||||
UpdateConfig(config *knowledge.RetrievalConfig)
|
||||
}
|
||||
|
||||
// ConfigHandler 配置处理器
|
||||
type ConfigHandler struct {
|
||||
configPath string
|
||||
@@ -33,6 +39,7 @@ type ConfigHandler struct {
|
||||
attackChainHandler AttackChainUpdater // 攻击链处理器接口,用于更新配置
|
||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||
knowledgeToolRegistrar KnowledgeToolRegistrar // 知识库工具注册器(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -69,6 +76,13 @@ func (h *ConfigHandler) SetKnowledgeToolRegistrar(registrar KnowledgeToolRegistr
|
||||
h.knowledgeToolRegistrar = registrar
|
||||
}
|
||||
|
||||
// SetRetrieverUpdater 设置检索器更新器
|
||||
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.retrieverUpdater = updater
|
||||
}
|
||||
|
||||
// GetConfigResponse 获取配置响应
|
||||
type GetConfigResponse struct {
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
@@ -639,6 +653,21 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("AttackChainHandler配置已更新")
|
||||
}
|
||||
|
||||
// 更新检索器配置(如果知识库启用)
|
||||
if h.config.Knowledge.Enabled && h.retrieverUpdater != nil {
|
||||
retrievalConfig := &knowledge.RetrievalConfig{
|
||||
TopK: h.config.Knowledge.Retrieval.TopK,
|
||||
SimilarityThreshold: h.config.Knowledge.Retrieval.SimilarityThreshold,
|
||||
HybridWeight: h.config.Knowledge.Retrieval.HybridWeight,
|
||||
}
|
||||
h.retrieverUpdater.UpdateConfig(retrievalConfig)
|
||||
h.logger.Info("检索器配置已更新",
|
||||
zap.Int("top_k", retrievalConfig.TopK),
|
||||
zap.Float64("similarity_threshold", retrievalConfig.SimilarityThreshold),
|
||||
zap.Float64("hybrid_weight", retrievalConfig.HybridWeight),
|
||||
)
|
||||
}
|
||||
|
||||
h.logger.Info("配置已应用",
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
)
|
||||
@@ -952,7 +981,13 @@ func setFloatInMap(mapNode *yaml.Node, key string, value float64) {
|
||||
valueNode.Kind = yaml.ScalarNode
|
||||
valueNode.Tag = "!!float"
|
||||
valueNode.Style = 0
|
||||
valueNode.Value = fmt.Sprintf("%g", value)
|
||||
// 对于0.0到1.0之间的值(如hybrid_weight),使用%.1f确保0.0被明确序列化为"0.0"
|
||||
// 对于其他值,使用%g自动选择最合适的格式
|
||||
if value >= 0.0 && value <= 1.0 {
|
||||
valueNode.Value = fmt.Sprintf("%.1f", value)
|
||||
} else {
|
||||
valueNode.Value = fmt.Sprintf("%g", value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -37,6 +37,18 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 更新检索配置
|
||||
func (r *Retriever) UpdateConfig(config *RetrievalConfig) {
|
||||
if config != nil {
|
||||
r.config = config
|
||||
r.logger.Info("检索器配置已更新",
|
||||
zap.Int("top_k", config.TopK),
|
||||
zap.Float64("similarity_threshold", config.SimilarityThreshold),
|
||||
zap.Float64("hybrid_weight", config.HybridWeight),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// cosineSimilarity 计算余弦相似度
|
||||
func cosineSimilarity(a, b []float32) float64 {
|
||||
if len(a) != len(b) {
|
||||
@@ -57,27 +69,61 @@ func cosineSimilarity(a, b []float32) float64 {
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// bm25Score 计算BM25分数(简化版)
|
||||
// bm25Score 计算BM25分数(改进版,更接近标准BM25)
|
||||
// 注意:这是单文档版本的BM25,缺少全局IDF,但比之前的简化版本更准确
|
||||
func (r *Retriever) bm25Score(query, text string) float64 {
|
||||
queryTerms := strings.Fields(strings.ToLower(query))
|
||||
if len(queryTerms) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
textLower := strings.ToLower(text)
|
||||
textTerms := strings.Fields(textLower)
|
||||
if len(textTerms) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// BM25参数
|
||||
k1 := 1.5 // 词频饱和度参数
|
||||
b := 0.75 // 长度归一化参数
|
||||
avgDocLength := 100.0 // 估算的平均文档长度(用于归一化)
|
||||
docLength := float64(len(textTerms))
|
||||
|
||||
score := 0.0
|
||||
for _, term := range queryTerms {
|
||||
// 计算词频(TF)
|
||||
termFreq := 0
|
||||
for _, textTerm := range textTerms {
|
||||
if textTerm == term {
|
||||
termFreq++
|
||||
}
|
||||
}
|
||||
|
||||
if termFreq > 0 {
|
||||
// 简化的BM25公式
|
||||
score += float64(termFreq) / float64(len(textTerms))
|
||||
// BM25公式的核心部分
|
||||
// TF部分:termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength)))
|
||||
tf := float64(termFreq)
|
||||
lengthNorm := 1 - b + b*(docLength/avgDocLength)
|
||||
tfScore := tf / (tf + k1*lengthNorm)
|
||||
|
||||
// 简化IDF:使用词长度作为权重(短词通常更重要)
|
||||
// 实际BM25需要全局文档统计,这里用简化版本
|
||||
idfWeight := 1.0
|
||||
if len(term) > 2 {
|
||||
// 长词稍微降低权重(但实际BM25中,罕见词IDF更高)
|
||||
idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0)
|
||||
}
|
||||
|
||||
score += tfScore * idfWeight
|
||||
}
|
||||
}
|
||||
|
||||
return score / float64(len(queryTerms))
|
||||
// 归一化到0-1范围
|
||||
if len(queryTerms) > 0 {
|
||||
score = score / float64(len(queryTerms))
|
||||
}
|
||||
|
||||
return math.Min(score, 1.0)
|
||||
}
|
||||
|
||||
// Search 搜索知识库
|
||||
@@ -148,6 +194,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
similarity float64
|
||||
bm25Score float64
|
||||
hasStrongKeywordMatch bool
|
||||
hybridScore float64 // 混合分数,用于最终排序
|
||||
}
|
||||
|
||||
candidates := make([]candidate, 0)
|
||||
@@ -229,19 +276,6 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
}
|
||||
}
|
||||
|
||||
// 根据是否有关键词匹配,采用不同的阈值策略
|
||||
effectiveThreshold := threshold
|
||||
if !hasAnyKeywordMatch {
|
||||
// 没有关键词匹配,可能是跨语言查询,适度放宽阈值
|
||||
// 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性
|
||||
// 跨语言阈值设为0.6,确保返回的结果至少有一定相关性
|
||||
effectiveThreshold = math.Max(threshold*0.85, 0.6)
|
||||
r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值",
|
||||
zap.Float64("originalThreshold", threshold),
|
||||
zap.Float64("effectiveThreshold", effectiveThreshold),
|
||||
)
|
||||
}
|
||||
|
||||
// 检查最高相似度,用于判断是否确实有相关内容
|
||||
maxSimilarity := 0.0
|
||||
if len(candidates) > 0 {
|
||||
@@ -249,12 +283,35 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
}
|
||||
|
||||
// 应用智能过滤
|
||||
// 如果用户设置了高阈值(>=0.8),更严格地遵守阈值,减少自动放宽
|
||||
strictMode := threshold >= 0.8
|
||||
|
||||
// 根据是否有关键词匹配,采用不同的阈值策略
|
||||
// 严格模式下,禁用跨语言放宽策略,严格遵守用户设置的阈值
|
||||
effectiveThreshold := threshold
|
||||
if !strictMode && !hasAnyKeywordMatch {
|
||||
// 非严格模式下,没有关键词匹配,可能是跨语言查询,适度放宽阈值
|
||||
// 但即使跨语言,也不能无脑降低阈值,需要保证最低相关性
|
||||
// 跨语言阈值设为0.6,确保返回的结果至少有一定相关性
|
||||
effectiveThreshold = math.Max(threshold*0.85, 0.6)
|
||||
r.logger.Debug("检测到可能的跨语言查询,使用放宽的阈值",
|
||||
zap.Float64("originalThreshold", threshold),
|
||||
zap.Float64("effectiveThreshold", effectiveThreshold),
|
||||
)
|
||||
} else if strictMode {
|
||||
// 严格模式下,即使没有关键词匹配,也严格遵守阈值
|
||||
r.logger.Debug("严格模式:严格遵守用户设置的阈值",
|
||||
zap.Float64("threshold", threshold),
|
||||
zap.Bool("hasKeywordMatch", hasAnyKeywordMatch),
|
||||
)
|
||||
}
|
||||
for _, cand := range candidates {
|
||||
if cand.similarity >= effectiveThreshold {
|
||||
// 达到阈值,直接通过
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
} else if cand.hasStrongKeywordMatch {
|
||||
// 有关键词匹配但相似度略低于阈值,适当放宽
|
||||
} else if !strictMode && cand.hasStrongKeywordMatch {
|
||||
// 非严格模式下,有关键词匹配但相似度略低于阈值,适当放宽
|
||||
// 严格模式下,即使有关键词匹配,也严格遵守阈值
|
||||
relaxedThreshold := math.Max(effectiveThreshold*0.85, 0.55)
|
||||
if cand.similarity >= relaxedThreshold {
|
||||
filteredCandidates = append(filteredCandidates, cand)
|
||||
@@ -265,9 +322,11 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
|
||||
// 智能兜底策略:只有在最高相似度达到合理水平时,才考虑返回结果
|
||||
// 如果最高相似度都很低(<0.55),说明确实没有相关内容,应该返回空
|
||||
if len(filteredCandidates) == 0 && len(candidates) > 0 {
|
||||
// 严格模式下(阈值>=0.8),禁用兜底策略,严格遵守用户设置的阈值
|
||||
if len(filteredCandidates) == 0 && len(candidates) > 0 && !strictMode {
|
||||
// 即使没有通过阈值过滤,如果最高相似度还可以(>=0.55),可以考虑返回Top-K
|
||||
// 但这是最后的兜底,只在确实有一定相关性时才使用
|
||||
// 严格模式下不使用兜底策略
|
||||
minAcceptableSimilarity := 0.55
|
||||
if maxSimilarity >= minAcceptableSimilarity {
|
||||
r.logger.Debug("过滤后无结果,但最高相似度可接受,返回Top-K结果",
|
||||
@@ -292,6 +351,12 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
zap.Float64("minAcceptableSimilarity", minAcceptableSimilarity),
|
||||
)
|
||||
}
|
||||
} else if len(filteredCandidates) == 0 && strictMode {
|
||||
// 严格模式下,如果过滤后无结果,直接返回空,不使用兜底策略
|
||||
r.logger.Debug("严格模式:过滤后无结果,严格遵守阈值,返回空结果",
|
||||
zap.Float64("threshold", threshold),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
)
|
||||
} else if len(filteredCandidates) > topK {
|
||||
// 如果过滤后结果太多,只取Top-K
|
||||
filteredCandidates = filteredCandidates[:topK]
|
||||
@@ -300,23 +365,29 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
candidates = filteredCandidates
|
||||
|
||||
// 混合排序(向量相似度 + BM25)
|
||||
// 注意:hybridWeight可以是0.0(纯关键词检索),所以不设置默认值
|
||||
// 如果配置文件中未设置,应该在配置加载时使用默认值
|
||||
hybridWeight := r.config.HybridWeight
|
||||
if hybridWeight == 0 {
|
||||
hybridWeight = 0.7
|
||||
|
||||
// 先计算混合分数并存储在candidate中,用于排序
|
||||
for i := range candidates {
|
||||
normalizedBM25 := math.Min(candidates[i].bm25Score, 1.0)
|
||||
candidates[i].hybridScore = hybridWeight*candidates[i].similarity + (1-hybridWeight)*normalizedBM25
|
||||
}
|
||||
|
||||
// 根据混合分数重新排序(这才是真正的混合检索)
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].hybridScore > candidates[j].hybridScore
|
||||
})
|
||||
|
||||
// 转换为结果
|
||||
results := make([]*RetrievalResult, len(candidates))
|
||||
for i, cand := range candidates {
|
||||
// 计算混合分数
|
||||
normalizedBM25 := math.Min(cand.bm25Score, 1.0)
|
||||
hybridScore := hybridWeight*cand.similarity + (1-hybridWeight)*normalizedBM25
|
||||
|
||||
results[i] = &RetrievalResult{
|
||||
Chunk: cand.chunk,
|
||||
Item: cand.item,
|
||||
Similarity: cand.similarity,
|
||||
Score: hybridScore,
|
||||
Score: cand.hybridScore,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,12 +456,12 @@ func (r *Retriever) expandContext(ctx context.Context, results []*RetrievalResul
|
||||
}
|
||||
|
||||
// 为该文档的匹配chunk收集需要扩展的相邻chunk
|
||||
// 策略:只对相似度最高的前3个匹配chunk进行扩展,避免扩展过多
|
||||
// 先按相似度排序,只扩展前3个
|
||||
// 策略:只对混合分数最高的前3个匹配chunk进行扩展,避免扩展过多
|
||||
// 先按混合分数排序,只扩展前3个(使用混合分数而不是相似度)
|
||||
sortedItemResults := make([]*RetrievalResult, len(itemResults))
|
||||
copy(sortedItemResults, itemResults)
|
||||
sort.Slice(sortedItemResults, func(i, j int) bool {
|
||||
return sortedItemResults[i].Similarity > sortedItemResults[j].Similarity
|
||||
return sortedItemResults[i].Score > sortedItemResults[j].Score
|
||||
})
|
||||
|
||||
// 只扩展前3个(或所有,如果少于3个)
|
||||
|
||||
@@ -157,26 +157,58 @@ func RegisterKnowledgeTool(
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
|
||||
// 先按混合分数排序,确保文档顺序是按混合分数的(混合检索的核心)
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].Score > results[j].Score
|
||||
})
|
||||
|
||||
// 按文档分组结果,以便更好地展示上下文
|
||||
resultsByItem := make(map[string][]*RetrievalResult)
|
||||
// 使用有序的slice来保持文档顺序(按最高混合分数)
|
||||
type itemGroup struct {
|
||||
itemID string
|
||||
results []*RetrievalResult
|
||||
maxScore float64 // 该文档的最高混合分数
|
||||
}
|
||||
itemGroups := make([]*itemGroup, 0)
|
||||
itemMap := make(map[string]*itemGroup)
|
||||
|
||||
for _, result := range results {
|
||||
itemID := result.Item.ID
|
||||
resultsByItem[itemID] = append(resultsByItem[itemID], result)
|
||||
group, exists := itemMap[itemID]
|
||||
if !exists {
|
||||
group = &itemGroup{
|
||||
itemID: itemID,
|
||||
results: make([]*RetrievalResult, 0),
|
||||
maxScore: result.Score,
|
||||
}
|
||||
itemMap[itemID] = group
|
||||
itemGroups = append(itemGroups, group)
|
||||
}
|
||||
group.results = append(group.results, result)
|
||||
if result.Score > group.maxScore {
|
||||
group.maxScore = result.Score
|
||||
}
|
||||
}
|
||||
|
||||
// 按最高混合分数排序文档组
|
||||
sort.Slice(itemGroups, func(i, j int) bool {
|
||||
return itemGroups[i].maxScore > itemGroups[j].maxScore
|
||||
})
|
||||
|
||||
// 收集检索到的知识项ID(用于日志)
|
||||
retrievedItemIDs := make([]string, 0, len(resultsByItem))
|
||||
retrievedItemIDs := make([]string, 0, len(itemGroups))
|
||||
|
||||
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识(包含上下文扩展):\n\n", len(results)))
|
||||
|
||||
resultIndex := 1
|
||||
for itemID, itemResults := range resultsByItem {
|
||||
// 找到相似度最高的作为主结果
|
||||
for _, group := range itemGroups {
|
||||
itemResults := group.results
|
||||
// 找到混合分数最高的作为主结果(使用混合分数,而不是相似度)
|
||||
mainResult := itemResults[0]
|
||||
maxSimilarity := mainResult.Similarity
|
||||
maxScore := mainResult.Score
|
||||
for _, result := range itemResults {
|
||||
if result.Similarity > maxSimilarity {
|
||||
maxSimilarity = result.Similarity
|
||||
if result.Score > maxScore {
|
||||
maxScore = result.Score
|
||||
mainResult = result
|
||||
}
|
||||
}
|
||||
@@ -186,8 +218,9 @@ func RegisterKnowledgeTool(
|
||||
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
|
||||
})
|
||||
|
||||
// 显示主结果(相似度最高的)
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", resultIndex, mainResult.Similarity*100))
|
||||
// 显示主结果(混合分数最高的,同时显示相似度和混合分数)
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
|
||||
resultIndex, mainResult.Similarity*100, mainResult.Score*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
|
||||
|
||||
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
|
||||
@@ -208,8 +241,8 @@ func RegisterKnowledgeTool(
|
||||
}
|
||||
resultText.WriteString("\n")
|
||||
|
||||
if !contains(retrievedItemIDs, itemID) {
|
||||
retrievedItemIDs = append(retrievedItemIDs, itemID)
|
||||
if !contains(retrievedItemIDs, group.itemID) {
|
||||
retrievedItemIDs = append(retrievedItemIDs, group.itemID)
|
||||
}
|
||||
resultIndex++
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user