mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 18:26:38 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
func tokenizerLenFunc(embeddingModel string) func(string) int {
|
||||
fallback := func(s string) int {
|
||||
r := []rune(s)
|
||||
if len(r) == 0 {
|
||||
return 0
|
||||
}
|
||||
return (len(r) + 3) / 4
|
||||
}
|
||||
m := strings.TrimSpace(embeddingModel)
|
||||
if m == "" {
|
||||
return fallback
|
||||
}
|
||||
tok, err := tiktoken.EncodingForModel(m)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return func(s string) int {
|
||||
return len(tok.Encode(s, nil, nil))
|
||||
}
|
||||
}
|
||||
|
||||
// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for
// embeddingModel when available, else rune/4 approximation.
//
// chunkSize must be positive; a negative overlap is clamped to 0.
// Separators are tried in order: paragraph breaks, Markdown headings,
// newlines, CJK and Latin sentence terminators, then plain spaces.
func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) {
	if chunkSize <= 0 {
		return nil, fmt.Errorf("chunk size must be positive")
	}
	if overlap < 0 {
		overlap = 0
	}
	return recursive.NewSplitter(context.Background(), &recursive.Config{
		ChunkSize:   chunkSize,
		OverlapSize: overlap,
		LenFunc:     tokenizerLenFunc(embeddingModel),
		Separators: []string{
			"\n\n", "\n## ", "\n### ", "\n#### ", "\n",
			"。", "!", "?", ". ", "? ", "! ",
			" ",
		},
	})
}
|
||||
|
||||
// newMarkdownHeaderSplitter returns an Eino-ext Markdown splitter that chunks
// by heading level (# through ####), suited to technical/Markdown knowledge bases.
// TrimHeaders=false keeps the heading line inside the resulting chunk text.
func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) {
	return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
		Headers: map[string]string{
			"#":    "h1",
			"##":   "h2",
			"###":  "h3",
			"####": "h4",
		},
		TrimHeaders: false,
	})
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Document metadata keys for Eino schema.Document flowing through the RAG pipeline.
const (
	metaKBCategory   = "kb_category"    // risk category of the knowledge item
	metaKBTitle      = "kb_title"       // knowledge item title
	metaKBItemID     = "kb_item_id"     // owning knowledge item ID
	metaKBChunkIndex = "kb_chunk_index" // 0-based chunk position within the item
	metaSimilarity   = "similarity"     // cosine similarity recorded at retrieval time
)

// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const (
	DSLRiskType            = "risk_type"
	DSLSimilarityThreshold = "similarity_threshold"
	DSLSubIndexFilter      = "sub_index_filter"
)
|
||||
|
||||
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
// stay comparable if users skip reindex; new indexes use the same string shape.
func FormatEmbeddingInput(category, title, chunkText string) string {
	var b strings.Builder
	b.WriteString("[风险类型:")
	b.WriteString(category)
	b.WriteString("] [标题:")
	b.WriteString(title)
	b.WriteString("]\n")
	b.WriteString(chunkText)
	return b.String()
}
|
||||
|
||||
// FormatQueryEmbeddingText builds the string embedded at query time so it matches
|
||||
// [FormatEmbeddingInput] for the same risk category (title left empty for queries).
|
||||
func FormatQueryEmbeddingText(riskType, query string) string {
|
||||
q := strings.TrimSpace(query)
|
||||
rt := strings.TrimSpace(riskType)
|
||||
if rt != "" {
|
||||
return FormatEmbeddingInput(rt, "", q)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// MetaLookupString returns metadata string value or "" if absent.
// String values are returned verbatim (no trimming); any other non-nil
// value is rendered via fmt.Sprint and then trimmed.
func MetaLookupString(md map[string]any, key string) string {
	if md == nil {
		return ""
	}
	val, present := md[key]
	if !present || val == nil {
		return ""
	}
	if s, isStr := val.(string); isStr {
		return s
	}
	return strings.TrimSpace(fmt.Sprint(val))
}
|
||||
|
||||
// MetaStringOK returns trimmed non-empty string and true if present and non-empty.
|
||||
func MetaStringOK(md map[string]any, key string) (string, bool) {
|
||||
s := strings.TrimSpace(MetaLookupString(md, key))
|
||||
if s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// RequireMetaString requires a non-empty string metadata field.
|
||||
func RequireMetaString(md map[string]any, key string) (string, error) {
|
||||
s, ok := MetaStringOK(md, key)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing or empty metadata %q", key)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RequireMetaInt requires an integer metadata field.
// Accepted dynamic types: int, int32, int64, and float64 (truncated toward
// zero, matching JSON-decoded numbers). Anything else is an error.
func RequireMetaInt(md map[string]any, key string) (int, error) {
	// Indexing a nil map is safe in Go and yields ok == false.
	v, ok := md[key]
	if !ok {
		return 0, fmt.Errorf("missing metadata key %q", key)
	}
	switch n := v.(type) {
	case int:
		return n, nil
	case int32:
		return int(n), nil
	case int64:
		return int(n), nil
	case float64:
		return int(n), nil
	}
	return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v)
}
|
||||
|
||||
// DSLNumeric coerces DSL map values (e.g. from JSON or YAML decoding) to float64.
// It reports false for non-numeric values.
//
// The accepted set covers float32/64, int/int32/int64 (kept in sync with
// [RequireMetaInt]'s integer widths), and uint/uint32/uint64. int32 and uint
// were added for consistency with the rest of the metadata helpers; all
// previously-accepted types behave exactly as before.
func DSLNumeric(v any) (float64, bool) {
	switch t := v.(type) {
	case float64:
		return t, true
	case float32:
		return float64(t), true
	case int:
		return float64(t), true
	case int32:
		return float64(t), true
	case int64:
		return float64(t), true
	case uint:
		return float64(t), true
	case uint32:
		return float64(t), true
	case uint64:
		return float64(t), true
	default:
		return 0, false
	}
}
|
||||
|
||||
// MetaFloat64OK reads a float metadata value.
|
||||
func MetaFloat64OK(md map[string]any, key string) (float64, bool) {
|
||||
if md == nil {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := md[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return DSLNumeric(v)
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package knowledge
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestFormatQueryEmbeddingText_AlignsWithIndexPrefix verifies that query-time
// embedding text uses the same prefix shape as index-time text (empty title),
// and that an empty risk type yields the bare query.
func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) {
	q := FormatQueryEmbeddingText("XSS", "payload")
	want := FormatEmbeddingInput("XSS", "", "payload")
	if q != want {
		t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want)
	}
	if FormatQueryEmbeddingText("", "hello") != "hello" {
		t.Fatalf("expected bare query without risk type")
	}
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// BuildKnowledgeRetrieveChain compiles an Eino chain mapping a query string to a
// document list, backed by SQLite vector retrieval ([VectorEinoRetriever]).
// Dedupe, context-budget truncation and the final Top-K all happen inside
// [VectorEinoRetriever.Retrieve], matching the HTTP/MCP retrieval paths.
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
	if r == nil {
		return nil, fmt.Errorf("retriever is nil")
	}
	ch := compose.NewChain[string, []*schema.Document]()
	ch.AppendRetriever(r.AsEinoRetriever())
	return ch.Compile(ctx)
}
|
||||
|
||||
// CompileRetrieveChain is equivalent to [BuildKnowledgeRetrieveChain](ctx, r).
func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) {
	return BuildKnowledgeRetrieveChain(ctx, r)
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestBuildKnowledgeRetrieveChain_Compile checks the chain compiles with a
// minimally-configured retriever (nil db/embedder are tolerated at compile time).
func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) {
	r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop())
	_, err := BuildKnowledgeRetrieveChain(context.Background(), r)
	if err != nil {
		t.Fatal(err)
	}
}
|
||||
|
||||
// TestBuildKnowledgeRetrieveChain_NilRetriever checks the explicit nil guard.
func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) {
	_, err := BuildKnowledgeRetrieveChain(context.Background(), nil)
	if err == nil {
		t.Fatal("expected error for nil retriever")
	}
}
|
||||
@@ -0,0 +1,202 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
//
// Options:
//   - [retriever.WithTopK]
//   - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string)
//
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
//
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
type VectorEinoRetriever struct {
	inner *Retriever // underlying SQLite-backed retriever that performs the actual search
}
|
||||
|
||||
// NewVectorEinoRetriever wraps r for Eino compose / tooling.
// A nil input yields a nil wrapper rather than a wrapper around nil.
func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever {
	if r == nil {
		return nil
	}
	return &VectorEinoRetriever{inner: r}
}
|
||||
|
||||
// GetType identifies this retriever for Eino callbacks.
func (h *VectorEinoRetriever) GetType() string {
	return "SQLiteVectorKnowledgeRetriever"
}
|
||||
|
||||
// Retrieve runs vector search and returns [schema.Document] rows.
//
// Resolution order for each knob:
//   - TopK: option > config > default 5
//   - Threshold: DSL value (>0) > config SimilarityThreshold (>0) > default 0.7
//   - SubIndexFilter: DSL value > config value
//
// The actual vector search fetches more than finalTopK (EffectivePrefetchTopK)
// so that dedupe/budget in ApplyPostRetrieve still leaves enough results.
func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
	if h == nil || h.inner == nil {
		return nil, fmt.Errorf("VectorEinoRetriever: nil retriever")
	}
	q := strings.TrimSpace(query)
	if q == "" {
		return nil, fmt.Errorf("查询不能为空")
	}

	ro := retriever.GetCommonOptions(nil, opts...)
	cfg := h.inner.config

	req := &SearchRequest{Query: q}

	// TopK: explicit option wins, then configured value, then default 5.
	if ro.TopK != nil && *ro.TopK > 0 {
		req.TopK = *ro.TopK
	} else if cfg != nil && cfg.TopK > 0 {
		req.TopK = cfg.TopK
	} else {
		req.TopK = 5
	}

	// DSL overrides: risk type, similarity threshold, sub-index filter.
	req.Threshold = 0
	if ro.DSLInfo != nil {
		if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok {
			req.RiskType = strings.TrimSpace(rt)
		}
		if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok {
			if f, ok2 := DSLNumeric(v); ok2 && f > 0 {
				req.Threshold = f
			}
		}
		if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok {
			req.SubIndexFilter = strings.TrimSpace(sf)
		}
	}
	if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" {
		req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter)
	}
	if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 {
		req.Threshold = cfg.SimilarityThreshold
	}
	if req.Threshold <= 0 {
		req.Threshold = 0.7
	}

	// Fetch more than the final Top-K so post-retrieve dedupe/budget has slack.
	finalTopK := req.TopK
	var postPO *config.PostRetrieveConfig
	if cfg != nil {
		postPO = &cfg.PostRetrieve
	}
	fetchK := EffectivePrefetchTopK(finalTopK, postPO)
	searchReq := *req
	searchReq.TopK = fetchK

	// Eino callback bookkeeping: OnStart before the search, OnEnd/OnError on return.
	ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever)
	th := req.Threshold
	st := &th
	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query:          q,
		TopK:           finalTopK,
		ScoreThreshold: st,
		Extra:          ro.DSLInfo,
	})
	defer func() {
		if err != nil {
			_ = callbacks.OnError(ctx, err)
			return
		}
		_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
	}()

	results, err := h.inner.vectorSearch(ctx, &searchReq)
	if err != nil {
		return nil, err
	}
	out = retrievalResultsToDocuments(results)

	// Optional rerank; a rerank failure is logged and the vector order is kept.
	if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
		reranked, rerr := rr.Rerank(ctx, q, out)
		if rerr != nil {
			if h.inner.logger != nil {
				h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
			}
		} else if len(reranked) > 0 {
			out = reranked
		}
	}

	// Post-retrieve: dedupe, context budget (token-counted with the embedding
	// model when known), then final Top-K.
	tokenModel := ""
	if h.inner.embedder != nil {
		tokenModel = h.inner.embedder.EmbeddingModelName()
	}
	out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
	if err != nil {
		return nil, err
	}
	return out, nil
}
|
||||
|
||||
// retrievalResultsToDocuments converts internal retrieval results into Eino
// documents, copying item/chunk identity and similarity into metadata.
// Results with a nil chunk or item are skipped.
func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document {
	out := make([]*schema.Document, 0, len(results))
	for _, res := range results {
		if res == nil || res.Chunk == nil || res.Item == nil {
			continue
		}
		d := &schema.Document{
			ID:      res.Chunk.ID,
			Content: res.Chunk.ChunkText,
			MetaData: map[string]any{
				metaKBItemID:     res.Item.ID,
				metaKBCategory:   res.Item.Category,
				metaKBTitle:      res.Item.Title,
				metaKBChunkIndex: res.Chunk.ChunkIndex,
				metaSimilarity:   res.Similarity,
			},
		}
		d.WithScore(res.Score)
		out = append(out, d)
	}
	return out
}
|
||||
|
||||
// documentsToRetrievalResults is the inverse of retrievalResultsToDocuments.
// Item ID and chunk index are required metadata; category/title/similarity
// are best-effort. Nil documents are skipped.
func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) {
	out := make([]*RetrievalResult, 0, len(docs))
	for i, d := range docs {
		if d == nil {
			continue
		}
		itemID, err := RequireMetaString(d.MetaData, metaKBItemID)
		if err != nil {
			return nil, fmt.Errorf("document %d: %w", i, err)
		}
		cat := MetaLookupString(d.MetaData, metaKBCategory)
		title := MetaLookupString(d.MetaData, metaKBTitle)
		chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex)
		if err != nil {
			return nil, fmt.Errorf("document %d: %w", i, err)
		}
		sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity) // optional; 0 when absent
		item := &KnowledgeItem{ID: itemID, Category: cat, Title: title}
		chunk := &KnowledgeChunk{
			ID:         d.ID,
			ItemID:     itemID,
			ChunkIndex: chunkIdx,
			ChunkText:  d.Content,
		}
		out = append(out, &RetrievalResult{
			Chunk:      chunk,
			Item:       item,
			Similarity: sim,
			Score:      d.Score(),
		})
	}
	return out, nil
}
|
||||
|
||||
var _ retriever.Retriever = (*VectorEinoRetriever)(nil)
|
||||
@@ -0,0 +1,142 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema.
type SQLiteIndexer struct {
	db             *sql.DB // destination database holding knowledge_embeddings
	batchSize      int     // embedding batch size; <=0 falls back to 64 at Store time
	embeddingModel string  // model name persisted per row (may be empty)
}
|
||||
|
||||
// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call.
// batchSize is the embedding batch size; if <= 0, default 64 is used.
// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty).
func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer {
	return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)}
}
|
||||
|
||||
// GetType implements eino callback run info.
func (s *SQLiteIndexer) GetType() string {
	return "SQLiteKnowledgeIndexer"
}
|
||||
|
||||
// Store embeds documents and inserts rows. Each doc must carry MetaData:
// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only.
//
// Embedding text uses FormatEmbeddingInput (category/title prefix + chunk text).
// All inserts happen in one transaction; any failure rolls back the whole batch.
// Vectors arrive as float64 from the Eino embedding interface and are stored as
// JSON-encoded float32, matching the historical on-disk format.
func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
	options := indexer.GetCommonOptions(nil, opts...)
	if options.Embedding == nil {
		return nil, fmt.Errorf("sqlite indexer: embedding is required")
	}
	if len(docs) == 0 {
		return nil, nil
	}

	// Callback bookkeeping: OnStart before work, OnEnd/OnError on return.
	ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer)
	ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
	defer func() {
		if err != nil {
			_ = callbacks.OnError(ctx, err)
			return
		}
		_ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})
	}()

	// Sub-indexes are stored denormalized as a comma-joined string.
	subIdxStr := strings.Join(options.SubIndexes, ",")

	// Build the embedding inputs (category/title prefix + chunk content).
	texts := make([]string, len(docs))
	for i, d := range docs {
		if d == nil {
			return nil, fmt.Errorf("sqlite indexer: nil document at %d", i)
		}
		cat := MetaLookupString(d.MetaData, metaKBCategory)
		title := MetaLookupString(d.MetaData, metaKBTitle)
		texts[i] = FormatEmbeddingInput(cat, title, d.Content)
	}

	bs := s.batchSize
	if bs <= 0 {
		bs = 64
	}

	// Embed in batches; a count mismatch means the provider misbehaved.
	var allVecs [][]float64
	for start := 0; start < len(texts); start += bs {
		end := start + bs
		if end > len(texts) {
			end = len(texts)
		}
		batch := texts[start:end]
		vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch)
		if embedErr != nil {
			return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr)
		}
		if len(vecs) != len(batch) {
			return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch))
		}
		allVecs = append(allVecs, vecs...)
	}

	// Dimension of the first vector is the reference; all rows must match.
	embedDim := 0
	if len(allVecs) > 0 {
		embedDim = len(allVecs[0])
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err)
	}
	// Rollback is a no-op after a successful Commit.
	defer tx.Rollback()

	ids = make([]string, 0, len(docs))
	for i, d := range docs {
		chunkID := uuid.New().String()
		itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID)
		if metaErr != nil {
			return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
		}
		chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex)
		if metaErr != nil {
			return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
		}
		vec := allVecs[i]
		if embedDim > 0 && len(vec) != embedDim {
			return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim)
		}
		// Downcast to float32 for the historical JSON storage format.
		vec32 := make([]float32, len(vec))
		for j, v := range vec {
			vec32[j] = float32(v)
		}
		embeddingJSON, jsonErr := json.Marshal(vec32)
		if jsonErr != nil {
			return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr)
		}
		_, err = tx.ExecContext(ctx,
			`INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at)
		 VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`,
			chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim,
		)
		if err != nil {
			return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err)
		}
		ids = append(ids, chunkID)
	}

	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("sqlite indexer: commit: %w", err)
	}
	return ids, nil
}
|
||||
|
||||
var _ indexer.Indexer = (*SQLiteIndexer)(nil)
|
||||
@@ -0,0 +1,251 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Embedder uses CloudWeGo Eino's OpenAI embedding component while keeping
// rate limiting and retry behavior.
type Embedder struct {
	eino   embedding.Embedder      // underlying eino-ext OpenAI embedder
	config *config.KnowledgeConfig // knowledge-base configuration (model, indexing knobs)
	logger *zap.Logger             // optional; nil-checked before every use

	rateLimiter    *rate.Limiter // token-bucket limiter when maxRPM is configured
	rateLimitDelay time.Duration // fixed-delay fallback when only a delay is configured
	maxRetries     int           // max embedding attempts (default 3)
	retryDelay     time.Duration // base backoff; multiplied by attempt number
	mu             sync.Mutex    // serializes waitRateLimiter calls
}
|
||||
|
||||
// NewEmbedder builds an Embedder on the eino-ext OpenAI component.
// openAIConfig supplies a fallback API key when the knowledge config does not
// set one. Defaults: model "text-embedding-3-small", base URL
// "https://api.openai.com/v1", 3 retries with 1s base delay, 120s HTTP timeout.
func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) {
	if cfg == nil {
		return nil, fmt.Errorf("knowledge config is nil")
	}

	// Rate limiting: MaxRPM takes precedence over a fixed per-call delay.
	var rateLimiter *rate.Limiter
	var rateLimitDelay time.Duration
	if cfg.Indexing.MaxRPM > 0 {
		rpm := cfg.Indexing.MaxRPM
		rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
		if logger != nil {
			logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
		}
	} else if cfg.Indexing.RateLimitDelayMs > 0 {
		rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
		if logger != nil {
			logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
		}
	}

	// Retry policy with defaults.
	maxRetries := 3
	retryDelay := 1000 * time.Millisecond
	if cfg.Indexing.MaxRetries > 0 {
		maxRetries = cfg.Indexing.MaxRetries
	}
	if cfg.Indexing.RetryDelayMs > 0 {
		retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
	}

	model := strings.TrimSpace(cfg.Embedding.Model)
	if model == "" {
		model = "text-embedding-3-small"
	}

	baseURL := strings.TrimSpace(cfg.Embedding.BaseURL)
	baseURL = strings.TrimSuffix(baseURL, "/")
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}

	// API key: knowledge config first, then the global OpenAI config.
	apiKey := strings.TrimSpace(cfg.Embedding.APIKey)
	if apiKey == "" && openAIConfig != nil {
		apiKey = strings.TrimSpace(openAIConfig.APIKey)
	}
	if apiKey == "" {
		return nil, fmt.Errorf("embedding API key 未配置")
	}

	timeout := 120 * time.Second
	if cfg.Indexing.RequestTimeoutSeconds > 0 {
		timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second
	}
	httpClient := &http.Client{Timeout: timeout}

	inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{
		APIKey:     apiKey,
		BaseURL:    baseURL,
		ByAzure:    false,
		Model:      model,
		HTTPClient: httpClient,
	})
	if err != nil {
		return nil, fmt.Errorf("eino OpenAI embedder: %w", err)
	}

	return &Embedder{
		eino:           inner,
		config:         cfg,
		logger:         logger,
		rateLimiter:    rateLimiter,
		rateLimitDelay: rateLimitDelay,
		maxRetries:     maxRetries,
		retryDelay:     retryDelay,
	}, nil
}
|
||||
|
||||
// EmbeddingModelName returns the configured embedding model name (used for
// tiktoken chunking and per-row vector metadata); defaults to
// "text-embedding-3-small" when unset.
func (e *Embedder) EmbeddingModelName() string {
	if e == nil || e.config == nil {
		return ""
	}
	s := strings.TrimSpace(e.config.Embedding.Model)
	if s != "" {
		return s
	}
	return "text-embedding-3-small"
}
|
||||
|
||||
// waitRateLimiter blocks until the configured rate limit permits the next
// embedding request: token-bucket wait when rateLimiter is set, then an
// optional fixed delay.
//
// NOTE(review): the wait uses context.Background(), so a caller's cancellation
// does not interrupt it — confirm whether the request ctx should be threaded
// through. Also, e.mu is held for the full Wait/Sleep, which serializes all
// callers for the duration; presumably intentional to enforce pacing — verify.
func (e *Embedder) waitRateLimiter() {
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.rateLimiter != nil {
		ctx := context.Background()
		if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil {
			e.logger.Warn("速率限制器等待失败", zap.Error(err))
		}
	}
	if e.rateLimitDelay > 0 {
		time.Sleep(e.rateLimitDelay)
	}
}
|
||||
|
||||
// EmbedText embeds a single string (float32, matching the historical storage
// format). It is a convenience wrapper around EmbedStrings.
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
	vecs, err := e.EmbedStrings(ctx, []string{text})
	if err != nil {
		return nil, err
	}
	if len(vecs) != 1 {
		return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs))
	}
	return vecs[0], nil
}
|
||||
|
||||
// EmbedStrings embeds texts in one request with retries; it implements
// [embedding.Embedder] so the Eino indexer can use it.
//
// Retry behavior: up to maxRetries attempts with linear backoff
// (retryDelay × attempt). Only errors classified by isRetryableError are
// retried; ctx cancellation aborts the backoff wait. Note that
// waitRateLimiter runs only before the first attempt, not before retries.
// The eino component returns float64 vectors; they are downcast to float32
// for consistency with the stored format.
func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) {
	if e == nil || e.eino == nil {
		return nil, fmt.Errorf("embedder not initialized")
	}
	if len(texts) == 0 {
		return nil, nil
	}

	var lastErr error
	for attempt := 0; attempt < e.maxRetries; attempt++ {
		if attempt > 0 {
			// Linear backoff, interruptible by ctx cancellation.
			wait := e.retryDelay * time.Duration(attempt)
			if e.logger != nil {
				e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait))
			}
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(wait):
			}
		} else {
			e.waitRateLimiter()
		}

		raw, err := e.eino.EmbedStrings(ctx, texts, opts...)
		if err == nil {
			// Downcast float64 → float32 to match the stored vector format.
			out := make([][]float32, len(raw))
			for i, row := range raw {
				out[i] = make([]float32, len(row))
				for j, v := range row {
					out[i][j] = float32(v)
				}
			}
			return out, nil
		}
		lastErr = err
		if !e.isRetryableError(err) {
			return nil, err
		}
		if e.logger != nil {
			e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err))
		}
	}
	return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
}
|
||||
|
||||
// EmbedTexts embeds a batch as float32 (kept for older callers; a single
// batched request reduces latency). Equivalent to EmbedStrings.
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
	return e.EmbedStrings(ctx, texts)
}
|
||||
|
||||
func (e *Embedder) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
|
||||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
|
||||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store.
type einoFloatEmbedder struct {
	inner *Embedder // retry/rate-limited float32 embedder being adapted
}
|
||||
|
||||
// EmbedStrings delegates to the wrapped Embedder and widens each vector
// element from float32 to float64, as the Eino interface requires.
func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
	vec32, err := w.inner.EmbedStrings(ctx, texts, opts...)
	if err != nil {
		return nil, err
	}
	out := make([][]float64, len(vec32))
	for i, row := range vec32 {
		out[i] = make([]float64, len(row))
		for j, v := range row {
			out[i][j] = float64(v)
		}
	}
	return out, nil
}
|
||||
|
||||
// GetType identifies this embedder for Eino callbacks.
func (w *einoFloatEmbedder) GetType() string {
	return "CyberStrikeKnowledgeEmbedder"
}
|
||||
|
||||
// IsCallbacksEnabled reports that this adapter does not emit Eino callbacks itself.
func (w *einoFloatEmbedder) IsCallbacksEnabled() bool {
	return false
}
|
||||
|
||||
// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path
// and produces float64 vectors expected by generic Eino indexer helpers.
func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder {
	return &einoFloatEmbedder{inner: e}
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive".
// Anything other than an explicit "recursive" (case-insensitive, whitespace
// ignored) — including empty and unknown values — falls back to
// "markdown_then_recursive".
func normalizeChunkStrategy(s string) string {
	if strings.EqualFold(strings.TrimSpace(s), "recursive") {
		return "recursive"
	}
	return "markdown_then_recursive"
}
|
||||
|
||||
// buildKnowledgeIndexChain compiles the Eino indexing chain: optional Markdown
// header splitting (unless the strategy is plain "recursive"), recursive
// splitting, a Lambda that drops empty chunks / caps the count / assigns chunk
// indexes, and finally the SQLite indexer.
//
// indexingCfg may be nil; then defaults apply (markdown_then_recursive,
// batch 64, unlimited chunks per item).
func buildKnowledgeIndexChain(
	ctx context.Context,
	indexingCfg *config.IndexingConfig,
	db *sql.DB,
	recursive document.Transformer,
	embeddingModel string,
) (compose.Runnable[[]*schema.Document, []string], error) {
	if recursive == nil {
		return nil, fmt.Errorf("recursive transformer is nil")
	}
	if db == nil {
		return nil, fmt.Errorf("db is nil")
	}
	// Defaults, overridden by the indexing config when present.
	strategy := normalizeChunkStrategy("markdown_then_recursive")
	batch := 64
	maxChunks := 0
	if indexingCfg != nil {
		strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy)
		if indexingCfg.BatchSize > 0 {
			batch = indexingCfg.BatchSize
		}
		maxChunks = indexingCfg.MaxChunksPerItem
	}

	si := NewSQLiteIndexer(db, batch, embeddingModel)
	ch := compose.NewChain[[]*schema.Document, []string]()
	// Markdown header pass runs first unless the strategy is plain recursive.
	if strategy != "recursive" {
		md, err := newMarkdownHeaderSplitter(ctx)
		if err != nil {
			return nil, fmt.Errorf("markdown splitter: %w", err)
		}
		ch.AppendDocumentTransformer(md)
	}
	ch.AppendDocumentTransformer(recursive)
	ch.AppendLambda(newChunkEnrichLambda(maxChunks))
	ch.AppendIndexer(si)
	return ch.Compile(ctx)
}
|
||||
|
||||
// newChunkEnrichLambda returns an Eino Lambda that post-processes split chunks:
// drops nil/blank documents, truncates to maxChunks when positive, and writes
// each surviving chunk's 0-based position into kb_chunk_index metadata
// (required later by SQLiteIndexer.Store).
func newChunkEnrichLambda(maxChunks int) *compose.Lambda {
	return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) {
		_ = ctx // no context-dependent work in this stage
		out := make([]*schema.Document, 0, len(docs))
		for _, d := range docs {
			if d == nil || strings.TrimSpace(d.Content) == "" {
				continue
			}
			out = append(out, d)
		}
		if maxChunks > 0 && len(out) > maxChunks {
			out = out[:maxChunks]
		}
		// Re-number chunk indexes after filtering/truncation.
		for i, d := range out {
			if d.MetaData == nil {
				d.MetaData = make(map[string]any)
			}
			d.MetaData[metaKBChunkIndex] = i
		}
		return out, nil
	})
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package knowledge
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestNormalizeChunkStrategy exercises the strategy normalization table:
// only "recursive" (any case) maps to itself; everything else — aliases,
// empty, unknown — maps to "markdown_then_recursive".
func TestNormalizeChunkStrategy(t *testing.T) {
	cases := []struct {
		in, want string
	}{
		{"", "markdown_then_recursive"},
		{"recursive", "recursive"},
		{"RECURSIVE", "recursive"},
		{"markdown_then_recursive", "markdown_then_recursive"},
		{"markdown", "markdown_then_recursive"},
		{"unknown", "markdown_then_recursive"},
	}
	for _, tc := range cases {
		if got := normalizeChunkStrategy(tc.in); got != tc.want {
			t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want)
		}
	}
}
|
||||
@@ -0,0 +1,352 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Indexer drives the Eino Compose indexing chain (Markdown/recursive
// chunking, a Lambda enrich step, SQLite persistence) and writes embeddings.
type Indexer struct {
	db          *sql.DB
	embedder    *Embedder
	logger      *zap.Logger
	chunkSize   int // chunk size passed to the splitter (defaulted in NewIndexer)
	overlap     int // chunk overlap passed to the splitter
	indexingCfg *config.IndexingConfig

	indexChain compose.Runnable[[]*schema.Document, []string]
	fileLoader *fileloader.FileLoader // optional; nil means "always use DB content"

	// mu guards the error-reporting fields below.
	mu            sync.RWMutex
	lastError     string
	lastErrorTime time.Time
	errorCount    int

	// rebuildMu guards the rebuild-progress fields below.
	rebuildMu         sync.RWMutex
	isRebuilding      bool
	rebuildTotalItems int
	rebuildCurrent    int
	rebuildFailed     int
	rebuildStartTime  time.Time
	rebuildLastItemID string
	rebuildLastChunks int
}
||||
|
||||
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
|
||||
func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) {
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("db is nil")
|
||||
}
|
||||
if embedder == nil {
|
||||
return nil, fmt.Errorf("embedder is nil")
|
||||
}
|
||||
if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil {
|
||||
return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err)
|
||||
}
|
||||
if kcfg == nil {
|
||||
kcfg = &config.KnowledgeConfig{}
|
||||
}
|
||||
indexingCfg := &kcfg.Indexing
|
||||
|
||||
chunkSize := 512
|
||||
overlap := 50
|
||||
if indexingCfg.ChunkSize > 0 {
|
||||
chunkSize = indexingCfg.ChunkSize
|
||||
}
|
||||
if indexingCfg.ChunkOverlap >= 0 {
|
||||
overlap = indexingCfg.ChunkOverlap
|
||||
}
|
||||
|
||||
embedModel := embedder.EmbeddingModelName()
|
||||
splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("eino recursive splitter: %w", err)
|
||||
}
|
||||
|
||||
chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("knowledge index chain: %w", err)
|
||||
}
|
||||
|
||||
var fl *fileloader.FileLoader
|
||||
fl, err = fileloader.NewFileLoader(ctx, nil)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err))
|
||||
}
|
||||
fl = nil
|
||||
err = nil
|
||||
}
|
||||
|
||||
return &Indexer{
|
||||
db: db,
|
||||
embedder: embedder,
|
||||
logger: logger,
|
||||
chunkSize: chunkSize,
|
||||
overlap: overlap,
|
||||
indexingCfg: indexingCfg,
|
||||
indexChain: chain,
|
||||
fileLoader: fl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。
|
||||
func (idx *Indexer) RecompileIndexChain(ctx context.Context) error {
|
||||
if idx == nil || idx.db == nil || idx.embedder == nil {
|
||||
return fmt.Errorf("indexer 未初始化")
|
||||
}
|
||||
if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil {
|
||||
return err
|
||||
}
|
||||
embedModel := idx.embedder.EmbeddingModelName()
|
||||
splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("eino recursive splitter: %w", err)
|
||||
}
|
||||
chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("knowledge index chain: %w", err)
|
||||
}
|
||||
idx.indexChain = chain
|
||||
return nil
|
||||
}
|
||||
|
||||
// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。
|
||||
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
if idx.indexChain == nil {
|
||||
return fmt.Errorf("索引链未初始化")
|
||||
}
|
||||
if idx.embedder == nil {
|
||||
return fmt.Errorf("嵌入器未初始化")
|
||||
}
|
||||
|
||||
var content, category, title, filePath string
|
||||
err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取知识项失败:%w", err)
|
||||
}
|
||||
|
||||
if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil {
|
||||
return fmt.Errorf("删除旧向量失败:%w", err)
|
||||
}
|
||||
|
||||
body := strings.TrimSpace(content)
|
||||
if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil {
|
||||
docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)})
|
||||
if lerr == nil && len(docs) > 0 {
|
||||
var b strings.Builder
|
||||
for i, d := range docs {
|
||||
if d == nil {
|
||||
continue
|
||||
}
|
||||
if i > 0 {
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
b.WriteString(d.Content)
|
||||
}
|
||||
if s := strings.TrimSpace(b.String()); s != "" {
|
||||
body = s
|
||||
}
|
||||
} else if idx.logger != nil {
|
||||
idx.logger.Warn("优先源文件读取失败,使用数据库正文",
|
||||
zap.String("itemId", itemID),
|
||||
zap.String("path", filePath),
|
||||
zap.Error(lerr))
|
||||
}
|
||||
}
|
||||
|
||||
root := &schema.Document{
|
||||
ID: itemID,
|
||||
Content: body,
|
||||
MetaData: map[string]any{
|
||||
metaKBCategory: category,
|
||||
metaKBTitle: title,
|
||||
metaKBItemID: itemID,
|
||||
},
|
||||
}
|
||||
|
||||
idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())}
|
||||
if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 {
|
||||
idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes))
|
||||
}
|
||||
|
||||
ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...))
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = msg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
if idx.logger != nil {
|
||||
idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids)))
|
||||
}
|
||||
idx.rebuildMu.Lock()
|
||||
idx.rebuildLastItemID = itemID
|
||||
idx.rebuildLastChunks = len(ids)
|
||||
idx.rebuildMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasIndex 检查是否存在索引
|
||||
func (idx *Indexer) HasIndex() (bool, error) {
|
||||
var count int
|
||||
err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查索引失败:%w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// RebuildIndex 重建所有索引
|
||||
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = true
|
||||
idx.rebuildTotalItems = 0
|
||||
idx.rebuildCurrent = 0
|
||||
idx.rebuildFailed = 0
|
||||
idx.rebuildStartTime = time.Now()
|
||||
idx.rebuildLastItemID = ""
|
||||
idx.rebuildLastChunks = 0
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
idx.mu.Lock()
|
||||
idx.lastError = ""
|
||||
idx.lastErrorTime = time.Time{}
|
||||
idx.errorCount = 0
|
||||
idx.mu.Unlock()
|
||||
|
||||
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
|
||||
if err != nil {
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = false
|
||||
idx.rebuildMu.Unlock()
|
||||
return fmt.Errorf("查询知识项失败:%w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var itemIDs []string
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = false
|
||||
idx.rebuildMu.Unlock()
|
||||
return fmt.Errorf("扫描知识项 ID 失败:%w", err)
|
||||
}
|
||||
itemIDs = append(itemIDs, id)
|
||||
}
|
||||
|
||||
idx.rebuildMu.Lock()
|
||||
idx.rebuildTotalItems = len(itemIDs)
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
|
||||
|
||||
failedCount := 0
|
||||
consecutiveFailures := 0
|
||||
maxConsecutiveFailures := 5
|
||||
firstFailureItemID := ""
|
||||
var firstFailureError error
|
||||
|
||||
for i, itemID := range itemIDs {
|
||||
if err := idx.IndexItem(ctx, itemID); err != nil {
|
||||
failedCount++
|
||||
consecutiveFailures++
|
||||
|
||||
if consecutiveFailures == 1 {
|
||||
firstFailureItemID = itemID
|
||||
firstFailureError = err
|
||||
idx.logger.Warn("索引知识项失败",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
if consecutiveFailures >= maxConsecutiveFailures {
|
||||
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = errorMsg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
|
||||
idx.logger.Error("连续索引失败次数过多,立即停止索引",
|
||||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.Int("processedItems", i+1),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError)
|
||||
}
|
||||
|
||||
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
|
||||
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = errorMsg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
|
||||
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
|
||||
zap.Int("failedCount", failedCount),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if consecutiveFailures > 0 {
|
||||
consecutiveFailures = 0
|
||||
firstFailureItemID = ""
|
||||
firstFailureError = nil
|
||||
}
|
||||
|
||||
idx.rebuildMu.Lock()
|
||||
idx.rebuildCurrent = i + 1
|
||||
idx.rebuildFailed = failedCount
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
|
||||
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
|
||||
}
|
||||
}
|
||||
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = false
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastError 获取最近一次错误信息
|
||||
func (idx *Indexer) GetLastError() (string, time.Time) {
|
||||
idx.mu.RLock()
|
||||
defer idx.mu.RUnlock()
|
||||
return idx.lastError, idx.lastErrorTime
|
||||
}
|
||||
|
||||
// GetRebuildStatus 获取重建索引状态
|
||||
func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) {
|
||||
idx.rebuildMu.RLock()
|
||||
defer idx.rebuildMu.RUnlock()
|
||||
return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime
|
||||
}
|
||||
@@ -0,0 +1,885 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager manages the on-disk knowledge base (markdown files under basePath)
// and its mirrored rows in the database.
type Manager struct {
	db       *sql.DB     // holds knowledge_base_items / knowledge_embeddings / retrieval logs
	basePath string      // root directory of the markdown knowledge base
	logger   *zap.Logger
}
||||
|
||||
// NewManager 创建新的知识库管理器
|
||||
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
|
||||
return &Manager{
|
||||
db: db,
|
||||
basePath: basePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ScanKnowledgeBase 扫描知识库目录,更新数据库
|
||||
// 返回需要索引的知识项ID列表(新添加的或更新的)
|
||||
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
|
||||
if m.basePath == "" {
|
||||
return nil, fmt.Errorf("知识库路径未配置")
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(m.basePath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
|
||||
}
|
||||
|
||||
var itemsToIndex []string
|
||||
|
||||
// 遍历知识库目录
|
||||
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 跳过目录和非markdown文件
|
||||
if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 计算相对路径和分类
|
||||
relPath, err := filepath.Rel(m.basePath, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 第一个目录名作为分类(风险类型)
|
||||
parts := strings.Split(relPath, string(filepath.Separator))
|
||||
category := "未分类"
|
||||
if len(parts) > 1 {
|
||||
category = parts[0]
|
||||
}
|
||||
|
||||
// 文件名为标题
|
||||
title := strings.TrimSuffix(filepath.Base(path), ".md")
|
||||
|
||||
// 读取文件内容
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
|
||||
return nil // 继续处理其他文件
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
var existingID string
|
||||
var existingContent string
|
||||
var existingUpdatedAt time.Time
|
||||
err = m.db.QueryRow(
|
||||
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
|
||||
path,
|
||||
).Scan(&existingID, &existingContent, &existingUpdatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// 创建新项
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
_, err = m.db.Exec(
|
||||
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, category, title, path, string(content), now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入知识项失败: %w", err)
|
||||
}
|
||||
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
|
||||
// 新添加的项需要索引
|
||||
itemsToIndex = append(itemsToIndex, id)
|
||||
} else if err == nil {
|
||||
// 检查内容是否有变化
|
||||
contentChanged := existingContent != string(content)
|
||||
if contentChanged {
|
||||
// 更新现有项
|
||||
_, err = m.db.Exec(
|
||||
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
|
||||
category, title, string(content), time.Now(), existingID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新知识项失败: %w", err)
|
||||
}
|
||||
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
|
||||
// 内容已更新的项需要重新索引
|
||||
itemsToIndex = append(itemsToIndex, existingID)
|
||||
} else {
|
||||
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return itemsToIndex, nil
|
||||
}
|
||||
|
||||
// GetCategories 获取所有分类(风险类型)
|
||||
func (m *Manager) GetCategories() ([]string, error) {
|
||||
rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分类失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var categories []string
|
||||
for rows.Next() {
|
||||
var category string
|
||||
if err := rows.Scan(&category); err != nil {
|
||||
return nil, fmt.Errorf("扫描分类失败: %w", err)
|
||||
}
|
||||
categories = append(categories, category)
|
||||
}
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// GetStats 获取知识库统计信息
|
||||
func (m *Manager) GetStats() (int, int, error) {
|
||||
// 获取分类总数
|
||||
categories, err := m.GetCategories()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("获取分类失败: %w", err)
|
||||
}
|
||||
totalCategories := len(categories)
|
||||
|
||||
// 获取知识项总数
|
||||
var totalItems int
|
||||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
|
||||
if err != nil {
|
||||
return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err)
|
||||
}
|
||||
|
||||
return totalCategories, totalItems, nil
|
||||
}
|
||||
|
||||
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
|
||||
// limit: 每页分类数量(0表示不限制)
|
||||
// offset: 偏移量(按分类偏移)
|
||||
func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) {
|
||||
// 首先获取所有分类(带数量统计)
|
||||
rows, err := m.db.Query(`
|
||||
SELECT category, COUNT(*) as item_count
|
||||
FROM knowledge_base_items
|
||||
GROUP BY category
|
||||
ORDER BY category
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询分类失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 收集所有分类信息
|
||||
type categoryInfo struct {
|
||||
name string
|
||||
itemCount int
|
||||
}
|
||||
var allCategories []categoryInfo
|
||||
for rows.Next() {
|
||||
var info categoryInfo
|
||||
if err := rows.Scan(&info.name, &info.itemCount); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描分类失败: %w", err)
|
||||
}
|
||||
allCategories = append(allCategories, info)
|
||||
}
|
||||
|
||||
totalCategories := len(allCategories)
|
||||
|
||||
// 应用分页(按分类分页)
|
||||
var paginatedCategories []categoryInfo
|
||||
if limit > 0 {
|
||||
start := offset
|
||||
end := offset + limit
|
||||
if start >= totalCategories {
|
||||
paginatedCategories = []categoryInfo{}
|
||||
} else {
|
||||
if end > totalCategories {
|
||||
end = totalCategories
|
||||
}
|
||||
paginatedCategories = allCategories[start:end]
|
||||
}
|
||||
} else {
|
||||
paginatedCategories = allCategories
|
||||
}
|
||||
|
||||
// 为每个分类获取其下的知识项(只返回摘要,不包含完整内容)
|
||||
result := make([]*CategoryWithItems, 0, len(paginatedCategories))
|
||||
for _, catInfo := range paginatedCategories {
|
||||
// 获取该分类下的所有知识项
|
||||
items, _, err := m.GetItemsSummary(catInfo.name, 0, 0)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err)
|
||||
}
|
||||
|
||||
result = append(result, &CategoryWithItems{
|
||||
Category: catInfo.name,
|
||||
ItemCount: catInfo.itemCount,
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
|
||||
return result, totalCategories, nil
|
||||
}
|
||||
|
||||
// GetItems 获取知识项列表(完整内容,用于向后兼容)
|
||||
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
|
||||
return m.GetItemsWithOptions(category, 0, 0, true)
|
||||
}
|
||||
|
||||
// GetItemsWithOptions 获取知识项列表(支持分页和可选内容)
|
||||
// category: 分类筛选(空字符串表示所有分类)
|
||||
// limit: 每页数量(0表示不限制)
|
||||
// offset: 偏移量
|
||||
// includeContent: 是否包含完整内容(false时只返回摘要)
|
||||
func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
// 构建SQL查询
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if includeContent {
|
||||
query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items"
|
||||
} else {
|
||||
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
||||
}
|
||||
|
||||
if category != "" {
|
||||
query += " WHERE category = ?"
|
||||
args = append(args, category)
|
||||
}
|
||||
|
||||
query += " ORDER BY category, title"
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
args = append(args, limit)
|
||||
if offset > 0 {
|
||||
query += " OFFSET ?"
|
||||
args = append(args, offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err = m.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*KnowledgeItem
|
||||
for rows.Next() {
|
||||
item := &KnowledgeItem{}
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if includeContent {
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
// 不包含内容时,Content为空字符串
|
||||
item.Content = ""
|
||||
}
|
||||
|
||||
// 解析时间 - 支持多种格式
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
// 解析创建时间
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析更新时间
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果更新时间为空,使用创建时间
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// GetItemsCount 获取知识项总数
|
||||
func (m *Manager) GetItemsCount(category string) (int, error) {
|
||||
var count int
|
||||
var err error
|
||||
|
||||
if category != "" {
|
||||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count)
|
||||
} else {
|
||||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("查询知识项总数失败: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配)
|
||||
func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) {
|
||||
if keyword == "" {
|
||||
return nil, fmt.Errorf("搜索关键字不能为空")
|
||||
}
|
||||
|
||||
// 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写)
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
|
||||
// 使用%keyword%进行模糊匹配
|
||||
searchPattern := "%" + keyword + "%"
|
||||
|
||||
query = `
|
||||
SELECT id, category, title, file_path, created_at, updated_at
|
||||
FROM knowledge_base_items
|
||||
WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?))
|
||||
`
|
||||
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern)
|
||||
|
||||
// 如果指定了分类,添加分类过滤
|
||||
if category != "" {
|
||||
query += " AND category = ?"
|
||||
args = append(args, category)
|
||||
}
|
||||
|
||||
query += " ORDER BY category, title"
|
||||
|
||||
rows, err := m.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索知识项失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*KnowledgeItemSummary
|
||||
for rows.Next() {
|
||||
item := &KnowledgeItemSummary{}
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页)
|
||||
func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) {
|
||||
// 获取总数
|
||||
total, err := m.GetItemsCount(category)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表数据(不包含内容)
|
||||
var rows *sql.Rows
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
||||
|
||||
if category != "" {
|
||||
query += " WHERE category = ?"
|
||||
args = append(args, category)
|
||||
}
|
||||
|
||||
query += " ORDER BY category, title"
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
args = append(args, limit)
|
||||
if offset > 0 {
|
||||
query += " OFFSET ?"
|
||||
args = append(args, offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err = m.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*KnowledgeItemSummary
|
||||
for rows.Next() {
|
||||
item := &KnowledgeItemSummary{}
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
// GetItem 获取单个知识项
|
||||
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
|
||||
item := &KnowledgeItem{}
|
||||
var createdAt, updatedAt string
|
||||
err := m.db.QueryRow(
|
||||
"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
|
||||
id,
|
||||
).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("知识项不存在")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间 - 支持多种格式
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
// 解析创建时间
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析更新时间
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果更新时间为空,使用创建时间
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// CreateItem 创建知识项
|
||||
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
// 构建文件路径
|
||||
filePath := filepath.Join(m.basePath, category, title+".md")
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||
return nil, fmt.Errorf("写入文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 插入数据库
|
||||
_, err := m.db.Exec(
|
||||
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, category, title, filePath, content, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("插入知识项失败: %w", err)
|
||||
}
|
||||
|
||||
return &KnowledgeItem{
|
||||
ID: id,
|
||||
Category: category,
|
||||
Title: title,
|
||||
FilePath: filePath,
|
||||
Content: content,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateItem 更新知识项
|
||||
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
|
||||
// 获取现有项
|
||||
item, err := m.GetItem(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建新文件路径
|
||||
newFilePath := filepath.Join(m.basePath, category, title+".md")
|
||||
|
||||
// 如果路径改变,需要移动文件
|
||||
if item.FilePath != newFilePath {
|
||||
// 确保新目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 移动文件
|
||||
if err := os.Rename(item.FilePath, newFilePath); err != nil {
|
||||
return nil, fmt.Errorf("移动文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除旧目录(如果为空)
|
||||
oldDir := filepath.Dir(item.FilePath)
|
||||
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if oldDir != m.basePath {
|
||||
if err := os.Remove(oldDir); err != nil {
|
||||
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
|
||||
return nil, fmt.Errorf("写入文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
_, err = m.db.Exec(
|
||||
"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
|
||||
category, title, newFilePath, content, time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除旧的向量嵌入(需要重新索引)
|
||||
_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
|
||||
if err != nil {
|
||||
m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
|
||||
}
|
||||
|
||||
return m.GetItem(id)
|
||||
}
|
||||
|
||||
// DeleteItem removes a knowledge item: its markdown file, its database row
// (embeddings go with it via cascading delete), and — when it becomes
// empty — the item's category directory.
func (m *Manager) DeleteItem(id string) error {
	// Look up the file path before deleting the row.
	var filePath string
	err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
	if err != nil {
		return fmt.Errorf("查询知识项失败: %w", err)
	}

	// Best-effort file removal; a missing file is not an error.
	if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
		m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
	}

	// Delete the database record (embeddings are removed by cascade).
	_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
	if err != nil {
		return fmt.Errorf("删除知识项失败: %w", err)
	}

	// Remove the containing directory if it is now empty (best effort).
	dir := filepath.Dir(filePath)
	if isEmpty, _ := isEmptyDir(dir); isEmpty {
		// Never delete the knowledge-base root itself.
		if dir != m.basePath {
			if err := os.Remove(dir); err != nil {
				m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
			}
		}
	}

	return nil
}
|
||||
|
||||
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
|
||||
func isEmptyDir(dir string) (bool, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
// 忽略隐藏文件(以 . 开头)
|
||||
if !strings.HasPrefix(entry.Name(), ".") {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// LogRetrieval 记录检索日志
|
||||
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
|
||||
id := uuid.New().String()
|
||||
itemsJSON, _ := json.Marshal(retrievedItems)
|
||||
|
||||
_, err := m.db.Exec(
|
||||
"INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetIndexStatus reports indexing progress for the knowledge base: total
// item count, how many items have at least one embedding, a percentage,
// and a completion flag.
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
	// Total number of knowledge items.
	var totalItems int
	err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
	if err != nil {
		return nil, fmt.Errorf("查询总知识项数失败: %w", err)
	}

	// Items that already have vector embeddings (i.e. are indexed).
	var indexedItems int
	err = m.db.QueryRow(`
		SELECT COUNT(DISTINCT item_id)
		FROM knowledge_embeddings
	`).Scan(&indexedItems)
	if err != nil {
		return nil, fmt.Errorf("查询已索引项数失败: %w", err)
	}

	// Progress percentage; an empty knowledge base reports 100%.
	var progressPercent float64
	if totalItems > 0 {
		progressPercent = float64(indexedItems) / float64(totalItems) * 100
	} else {
		progressPercent = 100.0
	}

	// Complete only when there is at least one item and all are indexed.
	// NOTE(review): with zero items progress reads 100% but is_complete is
	// false — confirm the consumer expects this asymmetry.
	isComplete := indexedItems >= totalItems && totalItems > 0

	return map[string]interface{}{
		"total_items":      totalItems,
		"indexed_items":    indexedItems,
		"progress_percent": progressPercent,
		"is_complete":      isComplete,
	}, nil
}
|
||||
|
||||
// GetRetrievalLogs 获取检索日志
|
||||
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
if messageID != "" {
|
||||
rows, err = m.db.Query(
|
||||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
|
||||
messageID, limit,
|
||||
)
|
||||
} else if conversationID != "" {
|
||||
rows, err = m.db.Query(
|
||||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
|
||||
conversationID, limit,
|
||||
)
|
||||
} else {
|
||||
rows, err = m.db.Query(
|
||||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
|
||||
limit,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询检索日志失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []*RetrievalLog
|
||||
for rows.Next() {
|
||||
log := &RetrievalLog{}
|
||||
var createdAt string
|
||||
var itemsJSON sql.NullString
|
||||
if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描检索日志失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间 - 支持多种格式
|
||||
var err error
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
for _, format := range timeFormats {
|
||||
log.CreatedAt, err = time.Parse(format, createdAt)
|
||||
if err == nil && !log.CreatedAt.IsZero() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有格式都失败,记录警告但继续处理
|
||||
if log.CreatedAt.IsZero() {
|
||||
m.logger.Warn("解析检索日志时间失败",
|
||||
zap.String("timeStr", createdAt),
|
||||
zap.Error(err),
|
||||
)
|
||||
// 使用当前时间作为fallback
|
||||
log.CreatedAt = time.Now()
|
||||
}
|
||||
|
||||
// 解析检索项
|
||||
if itemsJSON.Valid {
|
||||
json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems)
|
||||
}
|
||||
|
||||
logs = append(logs, log)
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// DeleteRetrievalLog 删除检索日志
|
||||
func (m *Manager) DeleteRetrievalLog(id string) error {
|
||||
result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除检索日志失败: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取删除行数失败: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("检索日志不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
// postRetrieveMaxPrefetchCap bounds the number of vector candidates pulled
// per retrieval, so a misconfigured PrefetchTopK cannot trigger an
// oversized full scan.
const postRetrieveMaxPrefetchCap = 200

// DocumentReranker is an optional re-ranking stage (e.g. a cross-encoder or
// a third-party rerank API), injected via [Retriever.SetDocumentReranker];
// on failure the adapter layer falls back to the plain vector ordering.
type DocumentReranker interface {
	Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
}

// NopDocumentReranker is a placeholder implementation, useful in tests or
// when reranking is explicitly disabled.
type NopDocumentReranker struct{}

// Rerank implements [DocumentReranker] as no-op.
func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) {
	return docs, nil
}

// tiktokenEncMu guards tiktokenEncCache, a per-model cache of tiktoken
// encoders (constructing an encoder is comparatively expensive).
var tiktokenEncMu sync.Mutex
var tiktokenEncCache = map[string]*tiktoken.Tiktoken{}
|
||||
|
||||
func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) {
|
||||
m := strings.TrimSpace(model)
|
||||
if m == "" {
|
||||
m = "gpt-4"
|
||||
}
|
||||
tiktokenEncMu.Lock()
|
||||
defer tiktokenEncMu.Unlock()
|
||||
if enc, ok := tiktokenEncCache[m]; ok {
|
||||
return enc, nil
|
||||
}
|
||||
enc, err := tiktoken.EncodingForModel(m)
|
||||
if err != nil {
|
||||
enc, err = tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
tiktokenEncCache[m] = enc
|
||||
return enc, nil
|
||||
}
|
||||
|
||||
func countDocTokens(text, model string) (int, error) {
|
||||
enc, err := encodingForTokenizerModel(model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
toks := enc.Encode(text, nil, nil)
|
||||
return len(toks), nil
|
||||
}
|
||||
|
||||
// normalizeContentFingerprintKey produces the dedupe key for a chunk body:
// leading/trailing whitespace is trimmed and every internal whitespace run
// collapses to a single space. Case is preserved on purpose so code
// snippets differing only in case are not merged.
func normalizeContentFingerprintKey(s string) string {
	trimmed := strings.TrimSpace(s)
	var out strings.Builder
	out.Grow(len(trimmed))
	pendingSpace := false
	for _, r := range trimmed {
		if unicode.IsSpace(r) {
			// Defer emitting the separator until the next non-space rune.
			pendingSpace = true
			continue
		}
		if pendingSpace && out.Len() > 0 {
			out.WriteByte(' ')
		}
		pendingSpace = false
		out.WriteRune(r)
	}
	return out.String()
}
|
||||
|
||||
func contentNormKey(d *schema.Document) string {
|
||||
if d == nil {
|
||||
return ""
|
||||
}
|
||||
n := normalizeContentFingerprintKey(d.Content)
|
||||
if n == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(n))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。
|
||||
func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document {
|
||||
if len(docs) < 2 {
|
||||
return docs
|
||||
}
|
||||
seen := make(map[string]struct{}, len(docs))
|
||||
out := make([]*schema.Document, 0, len(docs))
|
||||
for _, d := range docs {
|
||||
if d == nil {
|
||||
continue
|
||||
}
|
||||
k := contentNormKey(d)
|
||||
if k == "" {
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[k]; ok {
|
||||
continue
|
||||
}
|
||||
seen[k] = struct{}{}
|
||||
out = append(out, d)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。
|
||||
func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) {
|
||||
if len(docs) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
unlimitedChars := maxRunes <= 0
|
||||
unlimitedTok := maxTokens <= 0
|
||||
if unlimitedChars && unlimitedTok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
remRunes := maxRunes
|
||||
remTok := maxTokens
|
||||
out := make([]*schema.Document, 0, len(docs))
|
||||
|
||||
for _, d := range docs {
|
||||
if d == nil || strings.TrimSpace(d.Content) == "" {
|
||||
continue
|
||||
}
|
||||
runes := utf8.RuneCountInString(d.Content)
|
||||
if !unlimitedChars && runes > remRunes {
|
||||
break
|
||||
}
|
||||
var tok int
|
||||
var err error
|
||||
if !unlimitedTok {
|
||||
tok, err = countDocTokens(d.Content, tokenModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token count: %w", err)
|
||||
}
|
||||
if tok > remTok {
|
||||
break
|
||||
}
|
||||
}
|
||||
out = append(out, d)
|
||||
if !unlimitedChars {
|
||||
remRunes -= runes
|
||||
}
|
||||
if !unlimitedTok {
|
||||
remTok -= tok
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
|
||||
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
|
||||
if topK < 1 {
|
||||
topK = 5
|
||||
}
|
||||
fetch := topK
|
||||
if po != nil && po.PrefetchTopK > fetch {
|
||||
fetch = po.PrefetchTopK
|
||||
}
|
||||
if fetch > postRetrieveMaxPrefetchCap {
|
||||
fetch = postRetrieveMaxPrefetchCap
|
||||
}
|
||||
return fetch
|
||||
}
|
||||
|
||||
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
|
||||
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
|
||||
if finalTopK < 1 {
|
||||
finalTopK = 5
|
||||
}
|
||||
if len(docs) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
maxChars := 0
|
||||
maxTok := 0
|
||||
if po != nil {
|
||||
maxChars = po.MaxContextChars
|
||||
maxTok = po.MaxContextTokens
|
||||
}
|
||||
|
||||
out := dedupeByNormalizedContent(docs)
|
||||
|
||||
var err error
|
||||
out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(out) > finalTopK {
|
||||
out = out[:finalTopK]
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// doc builds a scored test document tagged with the metaKBItemID metadata
// key, mirroring what the vector retriever attaches in production.
func doc(id, content string, score float64) *schema.Document {
	d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}}
	d.WithScore(score)
	return d
}

// TestDedupeByNormalizedContent verifies that duplicate content keeps only
// its first occurrence while preserving retrieval order.
func TestDedupeByNormalizedContent(t *testing.T) {
	a := doc("1", "hello world", 0.9)
	b := doc("2", "hello world", 0.8)
	c := doc("3", "other", 0.7)
	out := dedupeByNormalizedContent([]*schema.Document{a, b, c})
	if len(out) != 2 {
		t.Fatalf("len=%d want 2", len(out))
	}
	if out[0].ID != "1" || out[1].ID != "3" {
		t.Fatalf("order/ids wrong: %#v", out)
	}
}

// TestEffectivePrefetchTopK covers the nil-config default, a configured
// override, and the postRetrieveMaxPrefetchCap hard cap.
func TestEffectivePrefetchTopK(t *testing.T) {
	if g := EffectivePrefetchTopK(5, nil); g != 5 {
		t.Fatalf("got %d", g)
	}
	if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
		t.Fatalf("got %d", g)
	}
	if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap {
		t.Fatalf("cap: got %d", g)
	}
}

// TestApplyPostRetrieveTruncateAndTopK checks character-budget truncation
// (a 3-rune budget keeps only the first 2-rune doc) and the final TopK cut
// when no budget is configured.
func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) {
	d1 := doc("1", "ab", 0.9)
	d2 := doc("2", "cd", 0.8)
	d3 := doc("3", "ef", 0.7)
	po := &config.PostRetrieveConfig{MaxContextChars: 3}
	out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5)
	if err != nil {
		t.Fatal(err)
	}
	if len(out) != 1 || out[0].ID != "1" {
		t.Fatalf("got %#v", out)
	}

	out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2)
	if err != nil {
		t.Fatal(err)
	}
	if len(out2) != 2 {
		t.Fatalf("topk: len=%d", len(out2))
	}
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Retriever performs pure vector retrieval (cosine similarity, TopK,
// similarity threshold) over embeddings stored in SQLite, using an Eino
// embedder. Its semantics match the [retriever.Retriever] adapter layer
// [VectorEinoRetriever].
type Retriever struct {
	db       *sql.DB          // embeddings + knowledge items store
	embedder *Embedder        // produces the query embedding
	config   *RetrievalConfig // defaults for TopK / threshold / filters
	logger   *zap.Logger

	rerankMu sync.RWMutex     // guards reranker
	reranker DocumentReranker // optional post-retrieval reranker (may be nil)
}

// RetrievalConfig holds retrieval defaults.
type RetrievalConfig struct {
	TopK                int     // default number of results
	SimilarityThreshold float64 // minimum cosine similarity to keep a chunk
	// SubIndexFilter, when non-empty, restricts retrieval to rows whose
	// comma-separated sub_indexes contains the tag; legacy rows with empty
	// sub_indexes are still included for compatibility.
	SubIndexFilter string
	PostRetrieve   config.PostRetrieveConfig // prefetch / dedupe / budget settings
}
|
||||
|
||||
// NewRetriever 创建新的检索器
|
||||
func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever {
|
||||
return &Retriever{
|
||||
db: db,
|
||||
embedder: embedder,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig replaces the retriever's configuration and logs the new
// effective values. A nil cfg is ignored.
// NOTE(review): r.config is swapped without holding a lock while searches
// may read it concurrently — confirm callers serialize config updates.
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
	if cfg != nil {
		r.config = cfg
		if r.logger != nil {
			r.logger.Info("检索器配置已更新",
				zap.Int("top_k", cfg.TopK),
				zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
				zap.String("sub_index_filter", cfg.SubIndexFilter),
				zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
				zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
				zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
			)
		}
	}
}
|
||||
|
||||
// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。
|
||||
func (r *Retriever) SetDocumentReranker(rr DocumentReranker) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.rerankMu.Lock()
|
||||
defer r.rerankMu.Unlock()
|
||||
r.reranker = rr
|
||||
}
|
||||
|
||||
func (r *Retriever) documentReranker() DocumentReranker {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
r.rerankMu.RLock()
|
||||
defer r.rerankMu.RUnlock()
|
||||
return r.reranker
|
||||
}
|
||||
|
||||
func cosineSimilarity(a, b []float32) float64 {
|
||||
if len(a) != len(b) {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
var dotProduct, normA, normB float64
|
||||
for i := range a {
|
||||
dotProduct += float64(a[i] * b[i])
|
||||
normA += float64(a[i] * a[i])
|
||||
normB += float64(b[i] * b[i])
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。
|
||||
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("请求不能为空")
|
||||
}
|
||||
q := strings.TrimSpace(req.Query)
|
||||
if q == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
opts := r.einoRetrieverOptions(req)
|
||||
docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return documentsToRetrievalResults(docs)
|
||||
}
|
||||
|
||||
func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option {
|
||||
var opts []retriever.Option
|
||||
if req.TopK > 0 {
|
||||
opts = append(opts, retriever.WithTopK(req.TopK))
|
||||
}
|
||||
dsl := map[string]any{}
|
||||
if strings.TrimSpace(req.RiskType) != "" {
|
||||
dsl[DSLRiskType] = strings.TrimSpace(req.RiskType)
|
||||
}
|
||||
if req.Threshold > 0 {
|
||||
dsl[DSLSimilarityThreshold] = req.Threshold
|
||||
}
|
||||
if strings.TrimSpace(req.SubIndexFilter) != "" {
|
||||
dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter)
|
||||
}
|
||||
if len(dsl) > 0 {
|
||||
opts = append(opts, retriever.WithDSLInfo(dsl))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// EinoRetrieve returns raw [schema.Document] results, for direct use inside
// Eino Graphs / Chains.
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
}
|
||||
|
||||
// knowledgeEmbeddingSelectSQL builds the embeddings SELECT plus its bind
// args. riskType (when non-blank) matches the item category
// case-insensitively with trimming; subIndexFilter (when non-blank) keeps
// rows whose comma-separated sub_indexes contains the tag, and also keeps
// legacy rows with empty sub_indexes for backward compatibility.
func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
	q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title
		FROM knowledge_embeddings e
		JOIN knowledge_base_items i ON e.item_id = i.id
		WHERE 1=1`
	var args []interface{}
	if strings.TrimSpace(riskType) != "" {
		q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE`
		args = append(args, riskType)
	}
	if tag := strings.TrimSpace(subIndexFilter); tag != "" {
		// Normalize the tag the same way the stored column is normalized in
		// the predicate: lowercase, spaces removed, then matched as ",tag,"
		// inside ",<sub_indexes>,".
		tag = strings.ToLower(strings.ReplaceAll(tag, " ", ""))
		q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)`
		args = append(args, tag)
	}
	return q, args
}
|
||||
|
||||
// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。
|
||||
func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
|
||||
if req.Query == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
|
||||
topK := req.TopK
|
||||
if topK <= 0 && r.config != nil {
|
||||
topK = r.config.TopK
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 5
|
||||
}
|
||||
|
||||
threshold := req.Threshold
|
||||
if threshold <= 0 && r.config != nil {
|
||||
threshold = r.config.SimilarityThreshold
|
||||
}
|
||||
if threshold <= 0 {
|
||||
threshold = 0.7
|
||||
}
|
||||
|
||||
subIdxFilter := strings.TrimSpace(req.SubIndexFilter)
|
||||
if subIdxFilter == "" && r.config != nil {
|
||||
subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter)
|
||||
}
|
||||
|
||||
queryText := FormatQueryEmbeddingText(req.RiskType, req.Query)
|
||||
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量化查询失败: %w", err)
|
||||
}
|
||||
queryDim := len(queryEmbedding)
|
||||
expectedModel := ""
|
||||
if r.embedder != nil {
|
||||
expectedModel = r.embedder.EmbeddingModelName()
|
||||
}
|
||||
|
||||
sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter)
|
||||
rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询向量失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type candidate struct {
|
||||
chunk *KnowledgeChunk
|
||||
item *KnowledgeItem
|
||||
similarity float64
|
||||
}
|
||||
|
||||
candidates := make([]candidate, 0)
|
||||
rowNum := 0
|
||||
for rows.Next() {
|
||||
rowNum++
|
||||
if rowNum%48 == 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string
|
||||
var chunkIndex, rowDim int
|
||||
|
||||
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil {
|
||||
r.logger.Warn("扫描向量失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
var embedding []float32
|
||||
if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil {
|
||||
r.logger.Warn("解析向量失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if rowDim > 0 && len(embedding) != rowDim {
|
||||
r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding)))
|
||||
continue
|
||||
}
|
||||
if queryDim > 0 && len(embedding) != queryDim {
|
||||
r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding)))
|
||||
continue
|
||||
}
|
||||
if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel {
|
||||
r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel))
|
||||
continue
|
||||
}
|
||||
|
||||
similarity := cosineSimilarity(queryEmbedding, embedding)
|
||||
candidates = append(candidates, candidate{
|
||||
chunk: &KnowledgeChunk{
|
||||
ID: chunkID,
|
||||
ItemID: itemID,
|
||||
ChunkIndex: chunkIndex,
|
||||
ChunkText: chunkText,
|
||||
Embedding: embedding,
|
||||
},
|
||||
item: &KnowledgeItem{
|
||||
ID: itemID,
|
||||
Category: category,
|
||||
Title: title,
|
||||
},
|
||||
similarity: similarity,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].similarity > candidates[j].similarity
|
||||
})
|
||||
|
||||
filtered := make([]candidate, 0, len(candidates))
|
||||
for _, c := range candidates {
|
||||
if c.similarity >= threshold {
|
||||
filtered = append(filtered, c)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filtered) > topK {
|
||||
filtered = filtered[:topK]
|
||||
}
|
||||
|
||||
results := make([]*RetrievalResult, len(filtered))
|
||||
for i, c := range filtered {
|
||||
results[i] = &RetrievalResult{
|
||||
Chunk: c.chunk,
|
||||
Item: c.item,
|
||||
Similarity: c.similarity,
|
||||
Score: c.similarity,
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// AsEinoRetriever exposes pure vector retrieval as an Eino
// [retriever.Retriever].
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
	return NewVectorEinoRetriever(r)
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata.
|
||||
func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("db is nil")
|
||||
}
|
||||
var n int
|
||||
if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes",
|
||||
`ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model",
|
||||
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim",
|
||||
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addKnowledgeEmbeddingsColumnIfMissing executes alterSQL unless the
// knowledge_embeddings table already has the named column.
func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error {
	const probe = `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
	var have int
	if err := db.QueryRow(probe, column).Scan(&have); err != nil {
		return err
	}
	if have > 0 {
		return nil // column already present
	}
	_, err := db.Exec(alterSQL)
	return err
}
|
||||
|
||||
// ensureKnowledgeEmbeddingsSubIndexesColumn is kept for backward
// compatibility and simply delegates to the full schema migration.
//
// Deprecated: use [EnsureKnowledgeEmbeddingsSchema] instead.
func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error {
	return EnsureKnowledgeEmbeddingsSchema(db)
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RegisterKnowledgeTool 注册知识检索工具到MCP服务器
|
||||
func RegisterKnowledgeTool(
|
||||
mcpServer *mcp.Server,
|
||||
retriever *Retriever,
|
||||
manager *Manager,
|
||||
logger *zap.Logger,
|
||||
) {
|
||||
// 注册第一个工具:获取所有可用的风险类型列表
|
||||
listRiskTypesTool := mcp.Tool{
|
||||
Name: builtin.ToolListKnowledgeRiskTypes,
|
||||
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
|
||||
ShortDescription: "获取知识库中所有可用的风险类型列表",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
"required": []string{},
|
||||
},
|
||||
}
|
||||
|
||||
listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
categories, err := manager.GetCategories()
|
||||
if err != nil {
|
||||
logger.Error("获取风险类型列表失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("获取风险类型列表失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(categories) == 0 {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "知识库中暂无风险类型。",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var resultText strings.Builder
|
||||
resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories)))
|
||||
for i, category := range categories {
|
||||
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
|
||||
}
|
||||
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: resultText.String(),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler)
|
||||
logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name))
|
||||
|
||||
// 注册第二个工具:搜索知识库(保持原有功能)
|
||||
searchTool := mcp.Tool{
|
||||
Name: builtin.ToolSearchKnowledgeBase,
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
|
||||
ShortDescription: "搜索知识库中的安全知识(向量语义检索)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "搜索查询内容,描述你想要了解的安全知识主题",
|
||||
},
|
||||
"risk_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
|
||||
searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || query == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: 查询参数不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
riskType := ""
|
||||
if rt, ok := args["risk_type"].(string); ok && rt != "" {
|
||||
riskType = rt
|
||||
}
|
||||
|
||||
logger.Info("执行知识库检索",
|
||||
zap.String("query", query),
|
||||
zap.String("riskType", riskType),
|
||||
)
|
||||
|
||||
// 检索统一走 Retriever.Search → VectorEinoRetriever(Eino retriever 语义)。
|
||||
searchReq := &SearchRequest{
|
||||
Query: query,
|
||||
RiskType: riskType,
|
||||
TopK: 5,
|
||||
}
|
||||
|
||||
results, err := retriever.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
logger.Error("知识库检索失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("检索失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
|
||||
// 按余弦相似度(Score)降序
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].Score > results[j].Score
|
||||
})
|
||||
|
||||
// 按文档分组结果,以便更好地展示上下文
|
||||
type itemGroup struct {
|
||||
itemID string
|
||||
results []*RetrievalResult
|
||||
maxScore float64 // 该文档块的最高相似度
|
||||
}
|
||||
itemGroups := make([]*itemGroup, 0)
|
||||
itemMap := make(map[string]*itemGroup)
|
||||
|
||||
for _, result := range results {
|
||||
itemID := result.Item.ID
|
||||
group, exists := itemMap[itemID]
|
||||
if !exists {
|
||||
group = &itemGroup{
|
||||
itemID: itemID,
|
||||
results: make([]*RetrievalResult, 0),
|
||||
maxScore: result.Score,
|
||||
}
|
||||
itemMap[itemID] = group
|
||||
itemGroups = append(itemGroups, group)
|
||||
}
|
||||
group.results = append(group.results, result)
|
||||
if result.Score > group.maxScore {
|
||||
group.maxScore = result.Score
|
||||
}
|
||||
}
|
||||
|
||||
// 按文档内最高相似度排序
|
||||
sort.Slice(itemGroups, func(i, j int) bool {
|
||||
return itemGroups[i].maxScore > itemGroups[j].maxScore
|
||||
})
|
||||
|
||||
// 收集检索到的知识项ID(用于日志)
|
||||
retrievedItemIDs := make([]string, 0, len(itemGroups))
|
||||
|
||||
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results)))
|
||||
|
||||
resultIndex := 1
|
||||
for _, group := range itemGroups {
|
||||
itemResults := group.results
|
||||
mainResult := itemResults[0]
|
||||
maxScore := mainResult.Score
|
||||
for _, result := range itemResults {
|
||||
if result.Score > maxScore {
|
||||
maxScore = result.Score
|
||||
mainResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序)
|
||||
sort.Slice(itemResults, func(i, j int) bool {
|
||||
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
|
||||
})
|
||||
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n",
|
||||
resultIndex, mainResult.Similarity*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
|
||||
|
||||
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
|
||||
if len(itemResults) == 1 {
|
||||
// 只有一个chunk,直接显示
|
||||
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText))
|
||||
} else {
|
||||
// 多个chunk,按逻辑顺序显示
|
||||
resultText.WriteString("内容片段(按文档顺序):\n")
|
||||
for i, result := range itemResults {
|
||||
// 标记主结果
|
||||
marker := ""
|
||||
if result.Chunk.ID == mainResult.Chunk.ID {
|
||||
marker = " [主匹配]"
|
||||
}
|
||||
resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText))
|
||||
}
|
||||
}
|
||||
resultText.WriteString("\n")
|
||||
|
||||
if !contains(retrievedItemIDs, group.itemID) {
|
||||
retrievedItemIDs = append(retrievedItemIDs, group.itemID)
|
||||
}
|
||||
resultIndex++
|
||||
}
|
||||
|
||||
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
|
||||
// 使用特殊标记,避免影响AI阅读结果
|
||||
if len(retrievedItemIDs) > 0 {
|
||||
metadataJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"_metadata": map[string]interface{}{
|
||||
"retrievedItemIDs": retrievedItemIDs,
|
||||
},
|
||||
})
|
||||
resultText.WriteString(fmt.Sprintf("\n<!-- METADATA: %s -->", string(metadataJSON)))
|
||||
}
|
||||
|
||||
// 记录检索日志(异步,不阻塞)
|
||||
// 注意:这里没有conversationID和messageID,需要在Agent层面记录
|
||||
// 实际的日志记录应该在Agent的progressCallback中完成
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: resultText.String(),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(searchTool, searchHandler)
|
||||
logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name))
|
||||
}
|
||||
|
||||
// contains reports whether item appears in slice.
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
|
||||
|
||||
// GetRetrievalMetadata extracts the query and risk-type strings from a
// tool-call argument map (used for retrieval logging). A missing key or
// a non-string value yields the empty string for that field.
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
	// Comma-ok type assertions with the flag discarded: on failure the
	// zero value "" is returned, matching the original assign-if-ok logic.
	q, _ := args["query"].(string)
	rt, _ := args["risk_type"].(string)
	return q, rt
}
|
||||
|
||||
// FormatRetrievalResults 格式化检索结果为字符串(用于日志)
|
||||
func FormatRetrievalResults(results []*RetrievalResult) string {
|
||||
if len(results) == 0 {
|
||||
return "未找到相关结果"
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results)))
|
||||
|
||||
itemIDs := make(map[string]bool)
|
||||
for i, result := range results {
|
||||
builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n",
|
||||
i+1, result.Item.Category, result.Item.Title, result.Similarity*100))
|
||||
itemIDs[result.Item.ID] = true
|
||||
}
|
||||
|
||||
// 返回知识项ID列表(JSON格式)
|
||||
ids := make([]string, 0, len(itemIDs))
|
||||
for id := range itemIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
idsJSON, _ := json.Marshal(ids)
|
||||
builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON)))
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
|
||||
func formatTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// KnowledgeItem is a knowledge-base entry backed by a file on disk.
type KnowledgeItem struct {
	ID        string    `json:"id"`
	Category  string    `json:"category"` // risk type (folder name)
	Title     string    `json:"title"`    // title (file name)
	FilePath  string    `json:"filePath"` // path of the backing file
	Content   string    `json:"content"`  // full file content
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
||||
|
||||
// KnowledgeItemSummary is a list-view projection of a knowledge item
// (used for listings; does not carry the full content).
type KnowledgeItemSummary struct {
	ID        string    `json:"id"`
	Category  string    `json:"category"`
	Title     string    `json:"title"`
	FilePath  string    `json:"filePath"`
	Content   string    `json:"content,omitempty"` // optional content preview (when present, typically only the first 150 characters)
	CreatedAt time.Time `json:"createdAt"`
	UpdatedAt time.Time `json:"updatedAt"`
}
|
||||
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItemSummary
|
||||
aux := &struct {
|
||||
*Alias
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItem
|
||||
aux := &struct {
|
||||
*Alias
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// KnowledgeChunk is one vectorizable slice of a knowledge item.
type KnowledgeChunk struct {
	ID         string    `json:"id"`
	ItemID     string    `json:"itemId"`
	ChunkIndex int       `json:"chunkIndex"`
	ChunkText  string    `json:"chunkText"`
	Embedding  []float32 `json:"-"` // embedding vector; excluded from JSON serialization
	CreatedAt  time.Time `json:"createdAt"`
}
|
||||
|
||||
// RetrievalResult pairs a matched chunk with its parent item and scores.
type RetrievalResult struct {
	Chunk      *KnowledgeChunk `json:"chunk"`
	Item       *KnowledgeItem  `json:"item"`
	Similarity float64         `json:"similarity"` // similarity score
	Score      float64         `json:"score"`      // same as Similarity: cosine similarity
}
|
||||
|
||||
// RetrievalLog records a single knowledge-base retrieval.
type RetrievalLog struct {
	ID             string    `json:"id"`
	ConversationID string    `json:"conversationId,omitempty"`
	MessageID      string    `json:"messageId,omitempty"`
	Query          string    `json:"query"`
	RiskType       string    `json:"riskType,omitempty"`
	RetrievedItems []string  `json:"retrievedItems"` // IDs of the knowledge items that were retrieved
	CreatedAt      time.Time `json:"createdAt"`
}
|
||||
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
||||
type Alias RetrievalLog
|
||||
return json.Marshal(&struct {
|
||||
*Alias
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}{
|
||||
Alias: (*Alias)(r),
|
||||
CreatedAt: formatTime(r.CreatedAt),
|
||||
})
|
||||
}
|
||||
|
||||
// CategoryWithItems groups a category with the knowledge items it contains
// (used for paging items by category).
type CategoryWithItems struct {
	Category  string                  `json:"category"`  // category name
	ItemCount int                     `json:"itemCount"` // total number of knowledge items in this category
	Items     []*KnowledgeItemSummary `json:"items"`     // knowledge items under this category
}
|
||||
|
||||
// SearchRequest describes a knowledge-base search.
type SearchRequest struct {
	Query          string  `json:"query"`
	RiskType       string  `json:"riskType,omitempty"`       // optional: restrict to one risk type
	SubIndexFilter string  `json:"subIndexFilter,omitempty"` // optional: keep only rows whose sub_indexes contain this tag (untagged legacy rows are kept too)
	TopK           int     `json:"topK,omitempty"`           // number of top results to return, default 5
	Threshold      float64 `json:"threshold,omitempty"`      // similarity threshold, default 0.7
}
|
||||
Reference in New Issue
Block a user