mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 18:26:38 +02:00
214 lines
5.1 KiB
Go
214 lines
5.1 KiB
Go
package knowledge
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"unicode"
|
|
"unicode/utf8"
|
|
|
|
"cyberstrike-ai/internal/config"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/pkoukk/tiktoken-go"
|
|
)
|
|
|
|
// postRetrieveMaxPrefetchCap is the hard upper bound on the number of vector
// candidates fetched for a single query. It protects the store from the
// near-full-scan load that a misconfigured prefetch size would otherwise cause.
const postRetrieveMaxPrefetchCap = 200
|
|
|
|
// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。
|
|
type DocumentReranker interface {
|
|
Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
|
|
}
|
|
|
|
// NopDocumentReranker is a placeholder [DocumentReranker] that can be injected
// explicitly in tests or when re-ranking is disabled.
type NopDocumentReranker struct{}
|
|
|
|
// Rerank implements [DocumentReranker] as no-op.
|
|
func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) {
|
|
return docs, nil
|
|
}
|
|
|
|
var tiktokenEncMu sync.Mutex
|
|
var tiktokenEncCache = map[string]*tiktoken.Tiktoken{}
|
|
|
|
func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) {
|
|
m := strings.TrimSpace(model)
|
|
if m == "" {
|
|
m = "gpt-4"
|
|
}
|
|
tiktokenEncMu.Lock()
|
|
defer tiktokenEncMu.Unlock()
|
|
if enc, ok := tiktokenEncCache[m]; ok {
|
|
return enc, nil
|
|
}
|
|
enc, err := tiktoken.EncodingForModel(m)
|
|
if err != nil {
|
|
enc, err = tiktoken.GetEncoding("cl100k_base")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
tiktokenEncCache[m] = enc
|
|
return enc, nil
|
|
}
|
|
|
|
func countDocTokens(text, model string) (int, error) {
|
|
enc, err := encodingForTokenizerModel(model)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
toks := enc.Encode(text, nil, nil)
|
|
return len(toks), nil
|
|
}
|
|
|
|
// normalizeContentFingerprintKey builds the de-duplication key for a piece of
// content: surrounding whitespace is trimmed and every run of Unicode
// whitespace is folded into a single ASCII space. Letter case is left
// untouched so that code snippets differing only in case are not merged.
//
// Bytes that are not valid UTF-8 are copied through unchanged. Ranging over a
// string reports each such byte as U+FFFD, so writing the decoded rune would
// make contents that differ only in those bytes share one key.
func normalizeContentFingerprintKey(s string) string {
	s = strings.TrimSpace(s)

	var b strings.Builder
	b.Grow(len(s)) // the output is never longer than the trimmed input

	prevSpace := false
	for i, r := range s {
		if unicode.IsSpace(r) {
			if !prevSpace {
				b.WriteByte(' ')
				prevSpace = true
			}
			continue
		}
		prevSpace = false

		if r == utf8.RuneError {
			// A width of 1 means an invalid byte; a genuine, correctly
			// encoded U+FFFD has width 3 and is written normally below.
			if _, size := utf8.DecodeRuneInString(s[i:]); size == 1 {
				b.WriteByte(s[i])
				continue
			}
		}
		b.WriteRune(r)
	}
	return b.String()
}
|
|
|
|
func contentNormKey(d *schema.Document) string {
|
|
if d == nil {
|
|
return ""
|
|
}
|
|
n := normalizeContentFingerprintKey(d.Content)
|
|
if n == "" {
|
|
return ""
|
|
}
|
|
sum := sha256.Sum256([]byte(n))
|
|
return hex.EncodeToString(sum[:])
|
|
}
|
|
|
|
// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。
|
|
func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document {
|
|
if len(docs) < 2 {
|
|
return docs
|
|
}
|
|
seen := make(map[string]struct{}, len(docs))
|
|
out := make([]*schema.Document, 0, len(docs))
|
|
for _, d := range docs {
|
|
if d == nil {
|
|
continue
|
|
}
|
|
k := contentNormKey(d)
|
|
if k == "" {
|
|
out = append(out, d)
|
|
continue
|
|
}
|
|
if _, ok := seen[k]; ok {
|
|
continue
|
|
}
|
|
seen[k] = struct{}{}
|
|
out = append(out, d)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。
|
|
func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) {
|
|
if len(docs) == 0 {
|
|
return docs, nil
|
|
}
|
|
unlimitedChars := maxRunes <= 0
|
|
unlimitedTok := maxTokens <= 0
|
|
if unlimitedChars && unlimitedTok {
|
|
return docs, nil
|
|
}
|
|
|
|
remRunes := maxRunes
|
|
remTok := maxTokens
|
|
out := make([]*schema.Document, 0, len(docs))
|
|
|
|
for _, d := range docs {
|
|
if d == nil || strings.TrimSpace(d.Content) == "" {
|
|
continue
|
|
}
|
|
runes := utf8.RuneCountInString(d.Content)
|
|
if !unlimitedChars && runes > remRunes {
|
|
break
|
|
}
|
|
var tok int
|
|
var err error
|
|
if !unlimitedTok {
|
|
tok, err = countDocTokens(d.Content, tokenModel)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("token count: %w", err)
|
|
}
|
|
if tok > remTok {
|
|
break
|
|
}
|
|
}
|
|
out = append(out, d)
|
|
if !unlimitedChars {
|
|
remRunes -= runes
|
|
}
|
|
if !unlimitedTok {
|
|
remTok -= tok
|
|
}
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
|
|
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
|
|
if topK < 1 {
|
|
topK = 5
|
|
}
|
|
fetch := topK
|
|
if po != nil && po.PrefetchTopK > fetch {
|
|
fetch = po.PrefetchTopK
|
|
}
|
|
if fetch > postRetrieveMaxPrefetchCap {
|
|
fetch = postRetrieveMaxPrefetchCap
|
|
}
|
|
return fetch
|
|
}
|
|
|
|
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
|
|
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
|
|
if finalTopK < 1 {
|
|
finalTopK = 5
|
|
}
|
|
if len(docs) == 0 {
|
|
return docs, nil
|
|
}
|
|
|
|
maxChars := 0
|
|
maxTok := 0
|
|
if po != nil {
|
|
maxChars = po.MaxContextChars
|
|
maxTok = po.MaxContextTokens
|
|
}
|
|
|
|
out := dedupeByNormalizedContent(docs)
|
|
|
|
var err error
|
|
out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(out) > finalTopK {
|
|
out = out[:finalTopK]
|
|
}
|
|
return out, nil
|
|
}
|