mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
Add files via upload
This commit is contained in:
@@ -582,9 +582,18 @@ func Default() *Config {
|
||||
},
|
||||
Retrieval: RetrievalConfig{
|
||||
TopK: 5,
|
||||
SimilarityThreshold: 0.7,
|
||||
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
|
||||
HybridWeight: 0.7,
|
||||
},
|
||||
Indexing: IndexingConfig{
|
||||
ChunkSize: 768, // 增加到 768,更好的上下文保持
|
||||
ChunkOverlap: 50,
|
||||
MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额
|
||||
MaxRPM: 100, // 默认 100 RPM,避免 429 错误
|
||||
RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM
|
||||
MaxRetries: 3,
|
||||
RetryDelayMs: 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -544,9 +544,21 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
idx.mu.Unlock()
|
||||
}
|
||||
|
||||
// 如果连续失败 2 个块,立即停止处理该知识项(降低阈值,更快停止)
|
||||
// 如果连续失败 5 个块,立即停止处理该知识项
|
||||
// 这样可以避免继续浪费 API 调用,同时也能更快地检测到配置问题
|
||||
if itemErrorCount >= 2 {
|
||||
// 对于大文档(超过 10 个块),允许失败比例不超过 50%
|
||||
maxConsecutiveFailures := 5
|
||||
if len(chunks) > 10 && itemErrorCount > len(chunks)/2 {
|
||||
idx.logger.Error("知识项向量化失败比例过高,停止处理",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalChunks", len(chunks)),
|
||||
zap.Int("failedChunks", itemErrorCount),
|
||||
zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
|
||||
zap.Error(firstError),
|
||||
)
|
||||
return fmt.Errorf("知识项向量化失败比例过高 (%d/%d个块失败): %v", itemErrorCount, len(chunks), firstError)
|
||||
}
|
||||
if itemErrorCount >= maxConsecutiveFailures {
|
||||
idx.logger.Error("知识项连续向量化失败,停止处理",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalChunks", len(chunks)),
|
||||
@@ -649,7 +661,7 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
|
||||
|
||||
failedCount := 0
|
||||
consecutiveFailures := 0
|
||||
maxConsecutiveFailures := 2 // 连续失败 2 次后立即停止(降低阈值,更快停止)
|
||||
maxConsecutiveFailures := 5 // 连续失败 5 次后立即停止(允许偶尔的临时错误)
|
||||
firstFailureItemID := ""
|
||||
var firstFailureError error
|
||||
|
||||
|
||||
@@ -657,7 +657,7 @@ func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeIte
|
||||
|
||||
// 删除旧目录(如果为空)
|
||||
oldDir := filepath.Dir(item.FilePath)
|
||||
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
|
||||
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if oldDir != m.basePath {
|
||||
if err := os.Remove(oldDir); err != nil {
|
||||
@@ -712,7 +712,7 @@ func (m *Manager) DeleteItem(id string) error {
|
||||
|
||||
// 删除空目录(如果为空)
|
||||
dir := filepath.Dir(filePath)
|
||||
if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 {
|
||||
if isEmpty, _ := isEmptyDir(dir); isEmpty {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if dir != m.basePath {
|
||||
if err := os.Remove(dir); err != nil {
|
||||
@@ -724,6 +724,21 @@ func (m *Manager) DeleteItem(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
|
||||
func isEmptyDir(dir string) (bool, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
// 忽略隐藏文件(以 . 开头)
|
||||
if !strings.HasPrefix(entry.Name(), ".") {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// LogRetrieval 记录检索日志
|
||||
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -69,8 +69,8 @@ func cosineSimilarity(a, b []float32) float64 {
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// bm25Score 计算BM25分数(改进版,更接近标准BM25)
|
||||
// 注意:这是单文档版本的BM25,缺少全局IDF,但比之前的简化版本更准确
|
||||
// bm25Score 计算 BM25 分数(带缓存的改进版本)
|
||||
// 注意:由于缺少全局文档统计,使用简化 IDF 计算
|
||||
func (r *Retriever) bm25Score(query, text string) float64 {
|
||||
queryTerms := strings.Fields(strings.ToLower(query))
|
||||
if len(queryTerms) == 0 {
|
||||
@@ -83,44 +83,56 @@ func (r *Retriever) bm25Score(query, text string) float64 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// BM25参数
|
||||
k1 := 1.5 // 词频饱和度参数
|
||||
b := 0.75 // 长度归一化参数
|
||||
avgDocLength := 100.0 // 估算的平均文档长度(用于归一化)
|
||||
// BM25 参数(标准值)
|
||||
k1 := 1.2 // 词频饱和度参数(标准范围 1.2-2.0)
|
||||
b := 0.75 // 长度归一化参数(标准值)
|
||||
avgDocLength := 150.0 // 估算的平均文档长度(基于典型知识块大小)
|
||||
docLength := float64(len(textTerms))
|
||||
|
||||
score := 0.0
|
||||
for _, term := range queryTerms {
|
||||
// 计算词频(TF)
|
||||
termFreq := 0
|
||||
for _, textTerm := range textTerms {
|
||||
if textTerm == term {
|
||||
termFreq++
|
||||
}
|
||||
}
|
||||
|
||||
if termFreq > 0 {
|
||||
// BM25公式的核心部分
|
||||
// TF部分:termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength)))
|
||||
tf := float64(termFreq)
|
||||
lengthNorm := 1 - b + b*(docLength/avgDocLength)
|
||||
tfScore := tf / (tf + k1*lengthNorm)
|
||||
|
||||
// 简化IDF:使用词长度作为权重(短词通常更重要)
|
||||
// 实际BM25需要全局文档统计,这里用简化版本
|
||||
idfWeight := 1.0
|
||||
if len(term) > 2 {
|
||||
// 长词稍微降低权重(但实际BM25中,罕见词IDF更高)
|
||||
idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0)
|
||||
}
|
||||
|
||||
score += tfScore * idfWeight
|
||||
}
|
||||
// 计算词频映射
|
||||
textTermFreq := make(map[string]int, len(textTerms))
|
||||
for _, term := range textTerms {
|
||||
textTermFreq[term]++
|
||||
}
|
||||
|
||||
// 归一化到0-1范围
|
||||
score := 0.0
|
||||
matchedQueryTerms := 0
|
||||
|
||||
for _, term := range queryTerms {
|
||||
termFreq, exists := textTermFreq[term]
|
||||
if !exists || termFreq == 0 {
|
||||
continue
|
||||
}
|
||||
matchedQueryTerms++
|
||||
|
||||
// BM25 TF 计算公式
|
||||
tf := float64(termFreq)
|
||||
lengthNorm := 1 - b + b*(docLength/avgDocLength)
|
||||
tfScore := tf / (tf + k1*lengthNorm)
|
||||
|
||||
// 改进的 IDF 计算:使用词长度和出现频率估算
|
||||
// 短词(2-3 字符)通常更重要,长词 IDF 略低
|
||||
idfWeight := 1.0
|
||||
termLen := len(term)
|
||||
if termLen <= 2 {
|
||||
// 极短词(如 go, js)给予更高权重
|
||||
idfWeight = 1.2 + math.Log(1.0+float64(termFreq)/20.0)
|
||||
} else if termLen <= 4 {
|
||||
// 短词(4 字符)标准权重
|
||||
idfWeight = 1.0 + math.Log(1.0+float64(termFreq)/15.0)
|
||||
} else {
|
||||
// 长词稍微降低权重
|
||||
idfWeight = 0.9 + math.Log(1.0+float64(termFreq)/10.0)
|
||||
}
|
||||
|
||||
score += tfScore * idfWeight
|
||||
}
|
||||
|
||||
// 归一化:考虑匹配的查询词比例
|
||||
if len(queryTerms) > 0 {
|
||||
score = score / float64(len(queryTerms))
|
||||
// 使用匹配比例作为额外因子
|
||||
matchRatio := float64(matchedQueryTerms) / float64(len(queryTerms))
|
||||
score = (score / float64(len(queryTerms))) * (1 + matchRatio) / 2
|
||||
}
|
||||
|
||||
return math.Min(score, 1.0)
|
||||
@@ -173,7 +185,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
JOIN knowledge_base_items i ON e.item_id = i.id
|
||||
WHERE i.category = ? COLLATE NOCASE
|
||||
WHERE TRIM(i.category) = TRIM(?) COLLATE NOCASE
|
||||
`, req.RiskType)
|
||||
} else {
|
||||
rows, err = r.db.Query(`
|
||||
@@ -357,7 +369,10 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
zap.Float64("threshold", threshold),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
)
|
||||
} else if len(filteredCandidates) > topK {
|
||||
}
|
||||
|
||||
// 统一在最终返回前严格限制 Top-K 数量
|
||||
if len(filteredCandidates) > topK {
|
||||
// 如果过滤后结果太多,只取Top-K
|
||||
filteredCandidates = filteredCandidates[:topK]
|
||||
}
|
||||
|
||||
@@ -5,6 +5,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
|
||||
func formatTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// KnowledgeItem 知识库项
|
||||
type KnowledgeItem struct {
|
||||
ID string `json:"id"`
|
||||
@@ -22,12 +30,12 @@ type KnowledgeItemSummary struct {
|
||||
Category string `json:"category"`
|
||||
Title string `json:"title"`
|
||||
FilePath string `json:"filePath"`
|
||||
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前150字符)
|
||||
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,确保时间格式正确
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItemSummary
|
||||
aux := &struct {
|
||||
@@ -37,25 +45,12 @@ func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
|
||||
// 格式化创建时间
|
||||
if k.CreatedAt.IsZero() {
|
||||
aux.CreatedAt = ""
|
||||
} else {
|
||||
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 格式化更新时间
|
||||
if k.UpdatedAt.IsZero() {
|
||||
aux.UpdatedAt = ""
|
||||
} else {
|
||||
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,确保时间格式正确
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItem
|
||||
aux := &struct {
|
||||
@@ -65,21 +60,8 @@ func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
|
||||
// 格式化创建时间
|
||||
if k.CreatedAt.IsZero() {
|
||||
aux.CreatedAt = ""
|
||||
} else {
|
||||
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 格式化更新时间
|
||||
if k.UpdatedAt.IsZero() {
|
||||
aux.UpdatedAt = ""
|
||||
} else {
|
||||
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
@@ -89,7 +71,7 @@ type KnowledgeChunk struct {
|
||||
ItemID string `json:"itemId"`
|
||||
ChunkIndex int `json:"chunkIndex"`
|
||||
ChunkText string `json:"chunkText"`
|
||||
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到JSON
|
||||
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
@@ -108,11 +90,11 @@ type RetrievalLog struct {
|
||||
MessageID string `json:"messageId,omitempty"`
|
||||
Query string `json:"query"`
|
||||
RiskType string `json:"riskType,omitempty"`
|
||||
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项ID列表
|
||||
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,确保时间格式正确
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
||||
type Alias RetrievalLog
|
||||
return json.Marshal(&struct {
|
||||
@@ -120,21 +102,21 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}{
|
||||
Alias: (*Alias)(r),
|
||||
CreatedAt: r.CreatedAt.Format(time.RFC3339),
|
||||
CreatedAt: formatTime(r.CreatedAt),
|
||||
})
|
||||
}
|
||||
|
||||
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
|
||||
type CategoryWithItems struct {
|
||||
Category string `json:"category"` // 分类名称
|
||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||
Category string `json:"category"` // 分类名称
|
||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||
}
|
||||
|
||||
// SearchRequest 搜索请求
|
||||
type SearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
|
||||
TopK int `json:"topK,omitempty"` // 返回Top-K结果,默认5
|
||||
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认0.7
|
||||
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
|
||||
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user