diff --git a/knowledge/chunk_eino.go b/knowledge/chunk_eino.go deleted file mode 100644 index 6592f350..00000000 --- a/knowledge/chunk_eino.go +++ /dev/null @@ -1,67 +0,0 @@ -package knowledge - -import ( - "context" - "fmt" - "strings" - - "github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown" - "github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive" - "github.com/cloudwego/eino/components/document" - "github.com/pkoukk/tiktoken-go" -) - -func tokenizerLenFunc(embeddingModel string) func(string) int { - fallback := func(s string) int { - r := []rune(s) - if len(r) == 0 { - return 0 - } - return (len(r) + 3) / 4 - } - m := strings.TrimSpace(embeddingModel) - if m == "" { - return fallback - } - tok, err := tiktoken.EncodingForModel(m) - if err != nil { - return fallback - } - return func(s string) int { - return len(tok.Encode(s, nil, nil)) - } -} - -// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for -// embeddingModel when available, else rune/4 approximation. -func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) { - if chunkSize <= 0 { - return nil, fmt.Errorf("chunk size must be positive") - } - if overlap < 0 { - overlap = 0 - } - return recursive.NewSplitter(context.Background(), &recursive.Config{ - ChunkSize: chunkSize, - OverlapSize: overlap, - LenFunc: tokenizerLenFunc(embeddingModel), - Separators: []string{ - "\n\n", "\n## ", "\n### ", "\n#### ", "\n", - "。", "!", "?", ". ", "? ", "! ", - " ", - }, - }) -} - -// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#~####),适合技术/Markdown 知识库。 -func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) { - return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{ - Headers: map[string]string{ - "#": "h1", - "##": "h2", - "###": "h3", - "####": "h4", - }, - TrimHeaders: false, - }) -} diff --git a/knowledge/eino_meta.go b/knowledge/eino_meta.go deleted file mode 100644 index 2ae419c4..00000000 --- a/knowledge/eino_meta.go +++ /dev/null @@ -1,129 +0,0 @@ -package knowledge - -import ( - "fmt" - "strings" -) - -// Document metadata keys for Eino schema.Document flowing through the RAG pipeline. -const ( - metaKBCategory = "kb_category" - metaKBTitle = "kb_title" - metaKBItemID = "kb_item_id" - metaKBChunkIndex = "kb_chunk_index" - metaSimilarity = "similarity" -) - -// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo]. -const ( - DSLRiskType = "risk_type" - DSLSimilarityThreshold = "similarity_threshold" - DSLSubIndexFilter = "sub_index_filter" -) - -// FormatEmbeddingInput matches the historical indexing format so existing embeddings -// stay comparable if users skip reindex; new indexes use the same string shape. -func FormatEmbeddingInput(category, title, chunkText string) string { - return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText) -} - -// FormatQueryEmbeddingText builds the string embedded at query time so it matches -// [FormatEmbeddingInput] for the same risk category (title left empty for queries). -func FormatQueryEmbeddingText(riskType, query string) string { - q := strings.TrimSpace(query) - rt := strings.TrimSpace(riskType) - if rt != "" { - return FormatEmbeddingInput(rt, "", q) - } - return q -} - -// MetaLookupString returns metadata string value or "" if absent. -func MetaLookupString(md map[string]any, key string) string { - if md == nil { - return "" - } - v, ok := md[key] - if !ok || v == nil { - return "" - } - switch t := v.(type) { - case string: - return t - default: - return strings.TrimSpace(fmt.Sprint(t)) - } -} - -// MetaStringOK returns trimmed non-empty string and true if present and non-empty. -func MetaStringOK(md map[string]any, key string) (string, bool) { - s := strings.TrimSpace(MetaLookupString(md, key)) - if s == "" { - return "", false - } - return s, true -} - -// RequireMetaString requires a non-empty string metadata field. -func RequireMetaString(md map[string]any, key string) (string, error) { - s, ok := MetaStringOK(md, key) - if !ok { - return "", fmt.Errorf("missing or empty metadata %q", key) - } - return s, nil -} - -// RequireMetaInt requires an integer metadata field. -func RequireMetaInt(md map[string]any, key string) (int, error) { - if md == nil { - return 0, fmt.Errorf("missing metadata key %q", key) - } - v, ok := md[key] - if !ok { - return 0, fmt.Errorf("missing metadata key %q", key) - } - switch t := v.(type) { - case int: - return t, nil - case int32: - return int(t), nil - case int64: - return int(t), nil - case float64: - return int(t), nil - default: - return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v) - } -} - -// DSLNumeric coerces DSL map values (e.g. from JSON) to float64. -func DSLNumeric(v any) (float64, bool) { - switch t := v.(type) { - case float64: - return t, true - case float32: - return float64(t), true - case int: - return float64(t), true - case int64: - return float64(t), true - case uint32: - return float64(t), true - case uint64: - return float64(t), true - default: - return 0, false - } -} - -// MetaFloat64OK reads a float metadata value. -func MetaFloat64OK(md map[string]any, key string) (float64, bool) { - if md == nil { - return 0, false - } - v, ok := md[key] - if !ok { - return 0, false - } - return DSLNumeric(v) -} diff --git a/knowledge/eino_meta_test.go b/knowledge/eino_meta_test.go deleted file mode 100644 index ba3f60da..00000000 --- a/knowledge/eino_meta_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package knowledge - -import "testing" - -func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) { - q := FormatQueryEmbeddingText("XSS", "payload") - want := FormatEmbeddingInput("XSS", "", "payload") - if q != want { - t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want) - } - if FormatQueryEmbeddingText("", "hello") != "hello" { - t.Fatalf("expected bare query without risk type") - } -} diff --git a/knowledge/eino_retrieve_chain.go b/knowledge/eino_retrieve_chain.go deleted file mode 100644 index 2d1b72eb..00000000 --- a/knowledge/eino_retrieve_chain.go +++ /dev/null @@ -1,25 +0,0 @@ -package knowledge - -import ( - "context" - "fmt" - - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/schema" -) - -// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。 -// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。 -func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) { - if r == nil { - return nil, fmt.Errorf("retriever is nil") - } - ch := compose.NewChain[string, []*schema.Document]() - ch.AppendRetriever(r.AsEinoRetriever()) - return ch.Compile(ctx) -} - -// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。 -func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) { - return BuildKnowledgeRetrieveChain(ctx, r) -} diff --git a/knowledge/eino_retrieve_chain_test.go b/knowledge/eino_retrieve_chain_test.go deleted file mode 100644 index c74a6900..00000000 --- a/knowledge/eino_retrieve_chain_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package knowledge - -import ( - "context" - "testing" - - "go.uber.org/zap" -) - -func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) { - r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop()) - _, err := BuildKnowledgeRetrieveChain(context.Background(), r) - if err != nil { - t.Fatal(err) - } -} - -func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) { - _, err := BuildKnowledgeRetrieveChain(context.Background(), nil) - if err == nil { - t.Fatal("expected error for nil retriever") - } -} diff --git a/knowledge/eino_retriever_adapter.go b/knowledge/eino_retriever_adapter.go deleted file mode 100644 index f5635121..00000000 --- a/knowledge/eino_retriever_adapter.go +++ /dev/null @@ -1,202 +0,0 @@ -package knowledge - -import ( - "context" - "fmt" - "strings" - - "cyberstrike-ai/internal/config" - - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/retriever" - "github.com/cloudwego/eino/schema" - "go.uber.org/zap" -) - -// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity. -// -// Options: -// - [retriever.WithTopK] -// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string) -// -// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric. -// -// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then -// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig]. -type VectorEinoRetriever struct { - inner *Retriever -} - -// NewVectorEinoRetriever wraps r for Eino compose / tooling. -func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever { - if r == nil { - return nil - } - return &VectorEinoRetriever{inner: r} -} - -// GetType identifies this retriever for Eino callbacks. -func (h *VectorEinoRetriever) GetType() string { - return "SQLiteVectorKnowledgeRetriever" -} - -// Retrieve runs vector search and returns [schema.Document] rows. -func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) { - if h == nil || h.inner == nil { - return nil, fmt.Errorf("VectorEinoRetriever: nil retriever") - } - q := strings.TrimSpace(query) - if q == "" { - return nil, fmt.Errorf("查询不能为空") - } - - ro := retriever.GetCommonOptions(nil, opts...) - cfg := h.inner.config - - req := &SearchRequest{Query: q} - - if ro.TopK != nil && *ro.TopK > 0 { - req.TopK = *ro.TopK - } else if cfg != nil && cfg.TopK > 0 { - req.TopK = cfg.TopK - } else { - req.TopK = 5 - } - - req.Threshold = 0 - if ro.DSLInfo != nil { - if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok { - req.RiskType = strings.TrimSpace(rt) - } - if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok { - if f, ok2 := DSLNumeric(v); ok2 && f > 0 { - req.Threshold = f - } - } - if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok { - req.SubIndexFilter = strings.TrimSpace(sf) - } - } - if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" { - req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter) - } - if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 { - req.Threshold = cfg.SimilarityThreshold - } - if req.Threshold <= 0 { - req.Threshold = 0.7 - } - - finalTopK := req.TopK - var postPO *config.PostRetrieveConfig - if cfg != nil { - postPO = &cfg.PostRetrieve - } - fetchK := EffectivePrefetchTopK(finalTopK, postPO) - searchReq := *req - searchReq.TopK = fetchK - - ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever) - th := req.Threshold - st := &th - ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{ - Query: q, - TopK: finalTopK, - ScoreThreshold: st, - Extra: ro.DSLInfo, - }) - defer func() { - if err != nil { - _ = callbacks.OnError(ctx, err) - return - } - _ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out}) - }() - - results, err := h.inner.vectorSearch(ctx, &searchReq) - if err != nil { - return nil, err - } - out = retrievalResultsToDocuments(results) - - if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 { - reranked, rerr := rr.Rerank(ctx, q, out) - if rerr != nil { - if h.inner.logger != nil { - h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr)) - } - } else if len(reranked) > 0 { - out = reranked - } - } - - tokenModel := "" - if h.inner.embedder != nil { - tokenModel = h.inner.embedder.EmbeddingModelName() - } - out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK) - if err != nil { - return nil, err - } - return out, nil -} - -func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document { - out := make([]*schema.Document, 0, len(results)) - for _, res := range results { - if res == nil || res.Chunk == nil || res.Item == nil { - continue - } - d := &schema.Document{ - ID: res.Chunk.ID, - Content: res.Chunk.ChunkText, - MetaData: map[string]any{ - metaKBItemID: res.Item.ID, - metaKBCategory: res.Item.Category, - metaKBTitle: res.Item.Title, - metaKBChunkIndex: res.Chunk.ChunkIndex, - metaSimilarity: res.Similarity, - }, - } - d.WithScore(res.Score) - out = append(out, d) - } - return out -} - -func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) { - out := make([]*RetrievalResult, 0, len(docs)) - for i, d := range docs { - if d == nil { - continue - } - itemID, err := RequireMetaString(d.MetaData, metaKBItemID) - if err != nil { - return nil, fmt.Errorf("document %d: %w", i, err) - } - cat := MetaLookupString(d.MetaData, metaKBCategory) - title := MetaLookupString(d.MetaData, metaKBTitle) - chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex) - if err != nil { - return nil, fmt.Errorf("document %d: %w", i, err) - } - sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity) - item := &KnowledgeItem{ID: itemID, Category: cat, Title: title} - chunk := &KnowledgeChunk{ - ID: d.ID, - ItemID: itemID, - ChunkIndex: chunkIdx, - ChunkText: d.Content, - } - out = append(out, &RetrievalResult{ - Chunk: chunk, - Item: item, - Similarity: sim, - Score: d.Score(), - }) - } - return out, nil -} - -var _ retriever.Retriever = (*VectorEinoRetriever)(nil) diff --git a/knowledge/eino_sqlite_indexer.go b/knowledge/eino_sqlite_indexer.go deleted file mode 100644 index a0bbdcdc..00000000 --- a/knowledge/eino_sqlite_indexer.go +++ /dev/null @@ -1,142 +0,0 @@ -package knowledge - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "strings" - - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/indexer" - "github.com/cloudwego/eino/schema" - "github.com/google/uuid" -) - -// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema. -type SQLiteIndexer struct { - db *sql.DB - batchSize int - embeddingModel string -} - -// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call. -// batchSize is the embedding batch size; if <= 0, default 64 is used. -// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty). -func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer { - return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)} -} - -// GetType implements eino callback run info. -func (s *SQLiteIndexer) GetType() string { - return "SQLiteKnowledgeIndexer" -} - -// Store embeds documents and inserts rows. Each doc must carry MetaData: -// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only. -func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { - options := indexer.GetCommonOptions(nil, opts...) - if options.Embedding == nil { - return nil, fmt.Errorf("sqlite indexer: embedding is required") - } - if len(docs) == 0 { - return nil, nil - } - - ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer) - ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs}) - defer func() { - if err != nil { - _ = callbacks.OnError(ctx, err) - return - } - _ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids}) - }() - - subIdxStr := strings.Join(options.SubIndexes, ",") - - texts := make([]string, len(docs)) - for i, d := range docs { - if d == nil { - return nil, fmt.Errorf("sqlite indexer: nil document at %d", i) - } - cat := MetaLookupString(d.MetaData, metaKBCategory) - title := MetaLookupString(d.MetaData, metaKBTitle) - texts[i] = FormatEmbeddingInput(cat, title, d.Content) - } - - bs := s.batchSize - if bs <= 0 { - bs = 64 - } - - var allVecs [][]float64 - for start := 0; start < len(texts); start += bs { - end := start + bs - if end > len(texts) { - end = len(texts) - } - batch := texts[start:end] - vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch) - if embedErr != nil { - return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr) - } - if len(vecs) != len(batch) { - return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch)) - } - allVecs = append(allVecs, vecs...) - } - - embedDim := 0 - if len(allVecs) > 0 { - embedDim = len(allVecs[0]) - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err) - } - defer tx.Rollback() - - ids = make([]string, 0, len(docs)) - for i, d := range docs { - chunkID := uuid.New().String() - itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID) - if metaErr != nil { - return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) - } - chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex) - if metaErr != nil { - return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) - } - vec := allVecs[i] - if embedDim > 0 && len(vec) != embedDim { - return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim) - } - vec32 := make([]float32, len(vec)) - for j, v := range vec { - vec32[j] = float32(v) - } - embeddingJSON, jsonErr := json.Marshal(vec32) - if jsonErr != nil { - return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr) - } - _, err = tx.ExecContext(ctx, - `INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`, - chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim, - ) - if err != nil { - return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err) - } - ids = append(ids, chunkID) - } - - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("sqlite indexer: commit: %w", err) - } - return ids, nil -} - -var _ indexer.Indexer = (*SQLiteIndexer)(nil) diff --git a/knowledge/embedder.go b/knowledge/embedder.go deleted file mode 100644 index d9ce8afa..00000000 --- a/knowledge/embedder.go +++ /dev/null @@ -1,251 +0,0 @@ -package knowledge - -import ( - "context" - "fmt" - "net/http" - "strings" - "sync" - "time" - - "cyberstrike-ai/internal/config" - - einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai" - "github.com/cloudwego/eino/components/embedding" - "go.uber.org/zap" - "golang.org/x/time/rate" -) - -// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。 -type Embedder struct { - eino embedding.Embedder - config *config.KnowledgeConfig - logger *zap.Logger - - rateLimiter *rate.Limiter - rateLimitDelay time.Duration - maxRetries int - retryDelay time.Duration - mu sync.Mutex -} - -// NewEmbedder 基于 Eino eino-ext OpenAI Embedder;openAIConfig 用于在知识库未单独配置 key 时回退 API Key。 -func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) { - if cfg == nil { - return nil, fmt.Errorf("knowledge config is nil") - } - - var rateLimiter *rate.Limiter - var rateLimitDelay time.Duration - if cfg.Indexing.MaxRPM > 0 { - rpm := cfg.Indexing.MaxRPM - rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm) - if logger != nil { - logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm)) - } - } else if cfg.Indexing.RateLimitDelayMs > 0 { - rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond - if logger != nil { - logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay)) - } - } - - maxRetries := 3 - retryDelay := 1000 * time.Millisecond - if cfg.Indexing.MaxRetries > 0 { - maxRetries = cfg.Indexing.MaxRetries - } - if cfg.Indexing.RetryDelayMs > 0 { - retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond - } - - model := strings.TrimSpace(cfg.Embedding.Model) - if model == "" { - model = "text-embedding-3-small" - } - - baseURL := strings.TrimSpace(cfg.Embedding.BaseURL) - baseURL = strings.TrimSuffix(baseURL, "/") - if baseURL == "" { - baseURL = "https://api.openai.com/v1" - } - - apiKey := strings.TrimSpace(cfg.Embedding.APIKey) - if apiKey == "" && openAIConfig != nil { - apiKey = strings.TrimSpace(openAIConfig.APIKey) - } - if apiKey == "" { - return nil, fmt.Errorf("embedding API key 未配置") - } - - timeout := 120 * time.Second - if cfg.Indexing.RequestTimeoutSeconds > 0 { - timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second - } - httpClient := &http.Client{Timeout: timeout} - - inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{ - APIKey: apiKey, - BaseURL: baseURL, - ByAzure: false, - Model: model, - HTTPClient: httpClient, - }) - if err != nil { - return nil, fmt.Errorf("eino OpenAI embedder: %w", err) - } - - return &Embedder{ - eino: inner, - config: cfg, - logger: logger, - rateLimiter: rateLimiter, - rateLimitDelay: rateLimitDelay, - maxRetries: maxRetries, - retryDelay: retryDelay, - }, nil -} - -// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。 -func (e *Embedder) EmbeddingModelName() string { - if e == nil || e.config == nil { - return "" - } - s := strings.TrimSpace(e.config.Embedding.Model) - if s != "" { - return s - } - return "text-embedding-3-small" -} - -func (e *Embedder) waitRateLimiter() { - e.mu.Lock() - defer e.mu.Unlock() - - if e.rateLimiter != nil { - ctx := context.Background() - if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil { - e.logger.Warn("速率限制器等待失败", zap.Error(err)) - } - } - if e.rateLimitDelay > 0 { - time.Sleep(e.rateLimitDelay) - } -} - -// EmbedText 单条嵌入(float32,与历史存储格式一致)。 -func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) { - vecs, err := e.EmbedStrings(ctx, []string{text}) - if err != nil { - return nil, err - } - if len(vecs) != 1 { - return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs)) - } - return vecs[0], nil -} - -// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。 -func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) { - if e == nil || e.eino == nil { - return nil, fmt.Errorf("embedder not initialized") - } - if len(texts) == 0 { - return nil, nil - } - - var lastErr error - for attempt := 0; attempt < e.maxRetries; attempt++ { - if attempt > 0 { - wait := e.retryDelay * time.Duration(attempt) - if e.logger != nil { - e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait)) - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(wait): - } - } else { - e.waitRateLimiter() - } - - raw, err := e.eino.EmbedStrings(ctx, texts, opts...) - if err == nil { - out := make([][]float32, len(raw)) - for i, row := range raw { - out[i] = make([]float32, len(row)) - for j, v := range row { - out[i][j] = float32(v) - } - } - return out, nil - } - lastErr = err - if !e.isRetryableError(err) { - return nil, err - } - if e.logger != nil { - e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err)) - } - } - return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr) -} - -// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。 -func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { - return e.EmbedStrings(ctx, texts) -} - -func (e *Embedder) isRetryableError(err error) bool { - if err == nil { - return false - } - errStr := err.Error() - if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") { - return true - } - if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") || - strings.Contains(errStr, "503") || strings.Contains(errStr, "504") { - return true - } - if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") || - strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") { - return true - } - return false -} - -// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store. -type einoFloatEmbedder struct { - inner *Embedder -} - -func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { - vec32, err := w.inner.EmbedStrings(ctx, texts, opts...) - if err != nil { - return nil, err - } - out := make([][]float64, len(vec32)) - for i, row := range vec32 { - out[i] = make([]float64, len(row)) - for j, v := range row { - out[i][j] = float64(v) - } - } - return out, nil -} - -func (w *einoFloatEmbedder) GetType() string { - return "CyberStrikeKnowledgeEmbedder" -} - -func (w *einoFloatEmbedder) IsCallbacksEnabled() bool { - return false -} - -// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path -// and produces float64 vectors expected by generic Eino indexer helpers. -func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder { - return &einoFloatEmbedder{inner: e} -} diff --git a/knowledge/index_pipeline.go b/knowledge/index_pipeline.go deleted file mode 100644 index de5d466e..00000000 --- a/knowledge/index_pipeline.go +++ /dev/null @@ -1,91 +0,0 @@ -package knowledge - -import ( - "context" - "database/sql" - "fmt" - "strings" - - "cyberstrike-ai/internal/config" - - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/components/document" - "github.com/cloudwego/eino/schema" -) - -// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive". -func normalizeChunkStrategy(s string) string { - v := strings.TrimSpace(strings.ToLower(s)) - switch v { - case "recursive": - return "recursive" - case "markdown_then_recursive", "markdown_recursive", "markdown": - return "markdown_then_recursive" - case "": - return "markdown_then_recursive" - default: - return "markdown_then_recursive" - } -} - -func buildKnowledgeIndexChain( - ctx context.Context, - indexingCfg *config.IndexingConfig, - db *sql.DB, - recursive document.Transformer, - embeddingModel string, -) (compose.Runnable[[]*schema.Document, []string], error) { - if recursive == nil { - return nil, fmt.Errorf("recursive transformer is nil") - } - if db == nil { - return nil, fmt.Errorf("db is nil") - } - strategy := normalizeChunkStrategy("markdown_then_recursive") - batch := 64 - maxChunks := 0 - if indexingCfg != nil { - strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy) - if indexingCfg.BatchSize > 0 { - batch = indexingCfg.BatchSize - } - maxChunks = indexingCfg.MaxChunksPerItem - } - - si := NewSQLiteIndexer(db, batch, embeddingModel) - ch := compose.NewChain[[]*schema.Document, []string]() - if strategy != "recursive" { - md, err := newMarkdownHeaderSplitter(ctx) - if err != nil { - return nil, fmt.Errorf("markdown splitter: %w", err) - } - ch.AppendDocumentTransformer(md) - } - ch.AppendDocumentTransformer(recursive) - ch.AppendLambda(newChunkEnrichLambda(maxChunks)) - ch.AppendIndexer(si) - return ch.Compile(ctx) -} - -func newChunkEnrichLambda(maxChunks int) *compose.Lambda { - return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) { - _ = ctx - out := make([]*schema.Document, 0, len(docs)) - for _, d := range docs { - if d == nil || strings.TrimSpace(d.Content) == "" { - continue - } - out = append(out, d) - } - if maxChunks > 0 && len(out) > maxChunks { - out = out[:maxChunks] - } - for i, d := range out { - if d.MetaData == nil { - d.MetaData = make(map[string]any) - } - d.MetaData[metaKBChunkIndex] = i - } - return out, nil - }) -} diff --git a/knowledge/index_pipeline_test.go b/knowledge/index_pipeline_test.go deleted file mode 100644 index 9e4b03fa..00000000 --- a/knowledge/index_pipeline_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package knowledge - -import "testing" - -func TestNormalizeChunkStrategy(t *testing.T) { - cases := []struct { - in, want string - }{ - {"", "markdown_then_recursive"}, - {"recursive", "recursive"}, - {"RECURSIVE", "recursive"}, - {"markdown_then_recursive", "markdown_then_recursive"}, - {"markdown", "markdown_then_recursive"}, - {"unknown", "markdown_then_recursive"}, - } - for _, tc := range cases { - if got := normalizeChunkStrategy(tc.in); got != tc.want { - t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want) - } - } -} diff --git a/knowledge/indexer.go b/knowledge/indexer.go deleted file mode 100644 index 390835c6..00000000 --- a/knowledge/indexer.go +++ /dev/null @@ -1,352 +0,0 @@ -package knowledge - -import ( - "context" - "database/sql" - "fmt" - "strings" - "sync" - "time" - - "cyberstrike-ai/internal/config" - - fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/components/document" - "github.com/cloudwego/eino/components/indexer" - "github.com/cloudwego/eino/schema" - "go.uber.org/zap" -) - -// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。 -type Indexer struct { - db *sql.DB - embedder *Embedder - logger *zap.Logger - chunkSize int - overlap int - indexingCfg *config.IndexingConfig - - indexChain compose.Runnable[[]*schema.Document, []string] - fileLoader *fileloader.FileLoader - - mu sync.RWMutex - lastError string - lastErrorTime time.Time - errorCount int - - rebuildMu sync.RWMutex - isRebuilding bool - rebuildTotalItems int - rebuildCurrent int - rebuildFailed int - rebuildStartTime time.Time - rebuildLastItemID string - rebuildLastChunks int -} - -// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。 -func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) { - if db == nil { - return nil, fmt.Errorf("db is nil") - } - if embedder == nil { - return nil, fmt.Errorf("embedder is nil") - } - if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil { - return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err) - } - if kcfg == nil { - kcfg = &config.KnowledgeConfig{} - } - indexingCfg := &kcfg.Indexing - - chunkSize := 512 - overlap := 50 - if indexingCfg.ChunkSize > 0 { - chunkSize = indexingCfg.ChunkSize - } - if indexingCfg.ChunkOverlap >= 0 { - overlap = indexingCfg.ChunkOverlap - } - - embedModel := embedder.EmbeddingModelName() - splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel) - if err != nil { - return nil, fmt.Errorf("eino recursive splitter: %w", err) - } - - chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel) - if err != nil { - return nil, fmt.Errorf("knowledge index chain: %w", err) - } - - var fl *fileloader.FileLoader - fl, err = fileloader.NewFileLoader(ctx, nil) - if err != nil { - if logger != nil { - logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err)) - } - fl = nil - err = nil - } - - return &Indexer{ - db: db, - embedder: embedder, - logger: logger, - chunkSize: chunkSize, - overlap: overlap, - indexingCfg: indexingCfg, - indexChain: chain, - fileLoader: fl, - }, nil -} - -// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。 -func (idx *Indexer) RecompileIndexChain(ctx context.Context) error { - if idx == nil || idx.db == nil || idx.embedder == nil { - return fmt.Errorf("indexer 未初始化") - } - if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil { - return err - } - embedModel := idx.embedder.EmbeddingModelName() - splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel) - if err != nil { - return fmt.Errorf("eino recursive splitter: %w", err) - } - chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel) - if err != nil { - return fmt.Errorf("knowledge index chain: %w", err) - } - idx.indexChain = chain - return nil -} - -// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。 -func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { - if idx.indexChain == nil { - return fmt.Errorf("索引链未初始化") - } - if idx.embedder == nil { - return fmt.Errorf("嵌入器未初始化") - } - - var content, category, title, filePath string - err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath) - if err != nil { - return fmt.Errorf("获取知识项失败:%w", err) - } - - if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil { - return fmt.Errorf("删除旧向量失败:%w", err) - } - - body := strings.TrimSpace(content) - if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil { - docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)}) - if lerr == nil && len(docs) > 0 { - var b strings.Builder - for i, d := range docs { - if d == nil { - continue - } - if i > 0 { - b.WriteString("\n\n") - } - b.WriteString(d.Content) - } - if s := strings.TrimSpace(b.String()); s != "" { - body = s - } - } else if idx.logger != nil { - idx.logger.Warn("优先源文件读取失败,使用数据库正文", - zap.String("itemId", itemID), - zap.String("path", filePath), - zap.Error(lerr)) - } - } - - root := &schema.Document{ - ID: itemID, - Content: body, - MetaData: map[string]any{ - metaKBCategory: category, - metaKBTitle: title, - metaKBItemID: itemID, - }, - } - - idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())} - if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 { - idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes)) - } - - ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...)) - if err != nil { - msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err) - idx.mu.Lock() - idx.lastError = msg - idx.lastErrorTime = time.Now() - idx.mu.Unlock() - return err - } - - if idx.logger != nil { - idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids))) - } - idx.rebuildMu.Lock() - idx.rebuildLastItemID = itemID - idx.rebuildLastChunks = len(ids) - idx.rebuildMu.Unlock() - return nil -} - -// HasIndex 检查是否存在索引 -func (idx *Indexer) HasIndex() (bool, error) { - var count int - err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count) - if err != nil { - return false, fmt.Errorf("检查索引失败:%w", err) - } - return count > 0, nil -} - -// RebuildIndex 重建所有索引 -func (idx *Indexer) RebuildIndex(ctx context.Context) error { - idx.rebuildMu.Lock() - idx.isRebuilding = true - idx.rebuildTotalItems = 0 - idx.rebuildCurrent = 0 - idx.rebuildFailed = 0 - idx.rebuildStartTime = time.Now() - idx.rebuildLastItemID = "" - idx.rebuildLastChunks = 0 - idx.rebuildMu.Unlock() - - idx.mu.Lock() - idx.lastError = "" - idx.lastErrorTime = time.Time{} - idx.errorCount = 0 - idx.mu.Unlock() - - rows, err := idx.db.Query("SELECT id FROM knowledge_base_items") - if err != nil { - idx.rebuildMu.Lock() - idx.isRebuilding = false - idx.rebuildMu.Unlock() - return fmt.Errorf("查询知识项失败:%w", err) - } - defer rows.Close() - - var itemIDs []string - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - idx.rebuildMu.Lock() - idx.isRebuilding = false - idx.rebuildMu.Unlock() - return fmt.Errorf("扫描知识项 ID 失败:%w", err) - } - itemIDs = append(itemIDs, id) - } - - idx.rebuildMu.Lock() - idx.rebuildTotalItems = len(itemIDs) - idx.rebuildMu.Unlock() - - idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs))) - - failedCount := 0 - consecutiveFailures := 0 - maxConsecutiveFailures := 5 - firstFailureItemID := "" - var firstFailureError error - - for i, itemID := range itemIDs { - if err := idx.IndexItem(ctx, itemID); err != nil { - failedCount++ - consecutiveFailures++ - - if consecutiveFailures == 1 { - firstFailureItemID = itemID - firstFailureError = err - idx.logger.Warn("索引知识项失败", - zap.String("itemId", itemID), - zap.Int("totalItems", len(itemIDs)), - zap.Error(err), - ) - } - - if consecutiveFailures >= maxConsecutiveFailures { - errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError) - idx.mu.Lock() - idx.lastError = errorMsg - idx.lastErrorTime = time.Now() - idx.mu.Unlock() - - idx.logger.Error("连续索引失败次数过多,立即停止索引", - zap.Int("consecutiveFailures", consecutiveFailures), - zap.Int("totalItems", len(itemIDs)), - zap.Int("processedItems", i+1), - zap.String("firstFailureItemId", firstFailureItemID), - zap.Error(firstFailureError), - ) - return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError) - } - - if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 { - errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError) - idx.mu.Lock() - idx.lastError = errorMsg - idx.lastErrorTime = time.Now() - idx.mu.Unlock() - - idx.logger.Error("索引失败的知识项过多,可能存在配置问题", - zap.Int("failedCount", failedCount), - zap.Int("totalItems", len(itemIDs)), - zap.String("firstFailureItemId", firstFailureItemID), - zap.Error(firstFailureError), - ) - } - continue - } - - if consecutiveFailures > 0 { - consecutiveFailures = 0 - firstFailureItemID = "" - firstFailureError = nil - } - - idx.rebuildMu.Lock() - idx.rebuildCurrent = i + 1 - idx.rebuildFailed = failedCount - idx.rebuildMu.Unlock() - - if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) { - idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount)) - } - } - - idx.rebuildMu.Lock() - idx.isRebuilding = false - idx.rebuildMu.Unlock() - - idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount)) - return nil -} - -// GetLastError 获取最近一次错误信息 -func (idx *Indexer) GetLastError() (string, time.Time) { - idx.mu.RLock() - defer idx.mu.RUnlock() - return idx.lastError, idx.lastErrorTime -} - -// GetRebuildStatus 获取重建索引状态 -func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) { - idx.rebuildMu.RLock() - defer idx.rebuildMu.RUnlock() - return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime -} diff --git a/knowledge/manager.go b/knowledge/manager.go deleted file mode 100644 index 7309cc2a..00000000 --- a/knowledge/manager.go +++ /dev/null @@ -1,885 +0,0 @@ -package knowledge - -import ( - "database/sql" - "encoding/json" - "fmt" - "io/fs" - "os" - "path/filepath" - "strings" - "time" - - "github.com/google/uuid" - "go.uber.org/zap" -) - -// Manager 知识库管理器 -type Manager struct { - db *sql.DB - basePath string - logger *zap.Logger -} - -// NewManager 创建新的知识库管理器 -func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager { - return &Manager{ - db: db, - basePath: basePath, - logger: logger, - } -} - -// ScanKnowledgeBase 扫描知识库目录,更新数据库 -// 返回需要索引的知识项ID列表(新添加的或更新的) -func (m *Manager) ScanKnowledgeBase() ([]string, error) { - if m.basePath == "" { - return nil, fmt.Errorf("知识库路径未配置") - } - - // 确保目录存在 - if err := os.MkdirAll(m.basePath, 0755); err != nil { - return nil, fmt.Errorf("创建知识库目录失败: %w", err) - } - - var itemsToIndex []string - - // 遍历知识库目录 - err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - - // 跳过目录和非markdown文件 - if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") { - return nil - } - - // 计算相对路径和分类 - relPath, err := filepath.Rel(m.basePath, path) - if err != nil { - return err - } - - // 第一个目录名作为分类(风险类型) - parts := strings.Split(relPath, string(filepath.Separator)) - category := "未分类" - if len(parts) > 1 { - category = parts[0] - } - - // 文件名为标题 - title := strings.TrimSuffix(filepath.Base(path), ".md") - - // 读取文件内容 - content, err := os.ReadFile(path) - if err != nil { - m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err)) - return nil // 继续处理其他文件 - } - - // 检查是否已存在 - var existingID string - var existingContent string - var existingUpdatedAt time.Time - err = m.db.QueryRow( - "SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?", - path, - ).Scan(&existingID, &existingContent, &existingUpdatedAt) - - if err == sql.ErrNoRows { - // 创建新项 - id := uuid.New().String() - now := time.Now() - _, err = m.db.Exec( - "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - id, category, title, path, string(content), now, now, - ) - if err != nil { - return fmt.Errorf("插入知识项失败: %w", err) - } - m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category)) - // 新添加的项需要索引 - itemsToIndex = append(itemsToIndex, id) - } else if err == nil { - // 检查内容是否有变化 - contentChanged := existingContent != string(content) - if contentChanged { - // 更新现有项 - _, err = m.db.Exec( - "UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?", - category, title, string(content), time.Now(), existingID, - ) - if err != nil { - return fmt.Errorf("更新知识项失败: %w", err) - } - m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title)) - // 内容已更新的项需要重新索引 - itemsToIndex = append(itemsToIndex, existingID) - } else { - m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title)) - } - } else { - return fmt.Errorf("查询知识项失败: %w", err) - } - - return nil - }) - - if err != nil { - return nil, err - } - - return itemsToIndex, nil -} - -// GetCategories 获取所有分类(风险类型) -func (m *Manager) GetCategories() ([]string, error) { - rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category") - if err != nil { - return nil, fmt.Errorf("查询分类失败: %w", err) - } - defer rows.Close() - - var categories []string - for rows.Next() { - var category string - if err := rows.Scan(&category); err != nil { - return nil, fmt.Errorf("扫描分类失败: %w", err) - } - categories = append(categories, category) - } - - return categories, nil -} - -// GetStats 获取知识库统计信息 -func (m *Manager) GetStats() (int, int, error) { - // 获取分类总数 - categories, err := m.GetCategories() - if err != nil { - return 0, 0, fmt.Errorf("获取分类失败: %w", err) - } - totalCategories := len(categories) - - // 获取知识项总数 - var totalItems int - err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) - if err != nil { - return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err) - } - - return totalCategories, totalItems, nil -} - -// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项) -// limit: 每页分类数量(0表示不限制) -// offset: 偏移量(按分类偏移) -func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) { - // 首先获取所有分类(带数量统计) - rows, err := m.db.Query(` - SELECT category, COUNT(*) as item_count - FROM knowledge_base_items - GROUP BY category - ORDER BY category - `) - if err != nil { - return nil, 0, fmt.Errorf("查询分类失败: %w", err) - } - defer rows.Close() - - // 收集所有分类信息 - type categoryInfo struct { - name string - itemCount int - } - var allCategories []categoryInfo - for rows.Next() { - var info categoryInfo - if err := rows.Scan(&info.name, &info.itemCount); err != nil { - return nil, 0, fmt.Errorf("扫描分类失败: %w", err) - } - allCategories = append(allCategories, info) - } - - totalCategories := len(allCategories) - - // 应用分页(按分类分页) - var paginatedCategories []categoryInfo - if limit > 0 { - start := offset - end := offset + limit - if start >= totalCategories { - paginatedCategories = []categoryInfo{} - } else { - if end > totalCategories { - end = totalCategories - } - paginatedCategories = allCategories[start:end] - } - } else { - paginatedCategories = allCategories - } - - // 为每个分类获取其下的知识项(只返回摘要,不包含完整内容) - result := make([]*CategoryWithItems, 0, len(paginatedCategories)) - for _, catInfo := range paginatedCategories { - // 获取该分类下的所有知识项 - items, _, err := m.GetItemsSummary(catInfo.name, 0, 0) - if err != nil { - return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err) - } - - result = append(result, &CategoryWithItems{ - Category: catInfo.name, - ItemCount: catInfo.itemCount, - Items: items, - }) - } - - return result, totalCategories, nil -} - -// GetItems 获取知识项列表(完整内容,用于向后兼容) -func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { - return m.GetItemsWithOptions(category, 0, 0, true) -} - -// GetItemsWithOptions 获取知识项列表(支持分页和可选内容) -// category: 分类筛选(空字符串表示所有分类) -// limit: 每页数量(0表示不限制) -// offset: 偏移量 -// includeContent: 是否包含完整内容(false时只返回摘要) -func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) { - var rows *sql.Rows - var err error - - // 构建SQL查询 - var query string - var args []interface{} - - if includeContent { - query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items" - } else { - query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items" - } - - if category != "" { - query += " WHERE category = ?" - args = append(args, category) - } - - query += " ORDER BY category, title" - - if limit > 0 { - query += " LIMIT ?" - args = append(args, limit) - if offset > 0 { - query += " OFFSET ?" - args = append(args, offset) - } - } - - rows, err = m.db.Query(query, args...) - if err != nil { - return nil, fmt.Errorf("查询知识项失败: %w", err) - } - defer rows.Close() - - var items []*KnowledgeItem - for rows.Next() { - item := &KnowledgeItem{} - var createdAt, updatedAt string - - if includeContent { - if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil { - return nil, fmt.Errorf("扫描知识项失败: %w", err) - } - } else { - if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { - return nil, fmt.Errorf("扫描知识项失败: %w", err) - } - // 不包含内容时,Content为空字符串 - item.Content = "" - } - - // 解析时间 - 支持多种格式 - timeFormats := []string{ - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999Z07:00", - "2006-01-02T15:04:05Z", - "2006-01-02 15:04:05", - time.RFC3339, - time.RFC3339Nano, - } - - // 解析创建时间 - if createdAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, createdAt) - if err == nil && !parsed.IsZero() { - item.CreatedAt = parsed - break - } - } - } - - // 解析更新时间 - if updatedAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, updatedAt) - if err == nil && !parsed.IsZero() { - item.UpdatedAt = parsed - break - } - } - } - - // 如果更新时间为空,使用创建时间 - if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { - item.UpdatedAt = item.CreatedAt - } - - items = append(items, item) - } - - return items, nil -} - -// GetItemsCount 获取知识项总数 -func (m *Manager) GetItemsCount(category string) (int, error) { - var count int - var err error - - if category != "" { - err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count) - } else { - err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count) - } - - if err != nil { - return 0, fmt.Errorf("查询知识项总数失败: %w", err) - } - - return count, nil -} - -// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配) -func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) { - if keyword == "" { - return nil, fmt.Errorf("搜索关键字不能为空") - } - - // 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写) - var query string - var args []interface{} - - // SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数 - // 使用%keyword%进行模糊匹配 - searchPattern := "%" + keyword + "%" - - query = ` - SELECT id, category, title, file_path, created_at, updated_at - FROM knowledge_base_items - WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?)) - ` - args = append(args, searchPattern, searchPattern, searchPattern, searchPattern) - - // 如果指定了分类,添加分类过滤 - if category != "" { - query += " AND category = ?" - args = append(args, category) - } - - query += " ORDER BY category, title" - - rows, err := m.db.Query(query, args...) - if err != nil { - return nil, fmt.Errorf("搜索知识项失败: %w", err) - } - defer rows.Close() - - var items []*KnowledgeItemSummary - for rows.Next() { - item := &KnowledgeItemSummary{} - var createdAt, updatedAt string - - if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { - return nil, fmt.Errorf("扫描知识项失败: %w", err) - } - - // 解析时间 - timeFormats := []string{ - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999Z07:00", - "2006-01-02T15:04:05Z", - "2006-01-02 15:04:05", - time.RFC3339, - time.RFC3339Nano, - } - - if createdAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, createdAt) - if err == nil && !parsed.IsZero() { - item.CreatedAt = parsed - break - } - } - } - - if updatedAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, updatedAt) - if err == nil && !parsed.IsZero() { - item.UpdatedAt = parsed - break - } - } - } - - if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { - item.UpdatedAt = item.CreatedAt - } - - items = append(items, item) - } - - return items, nil -} - -// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页) -func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) { - // 获取总数 - total, err := m.GetItemsCount(category) - if err != nil { - return nil, 0, err - } - - // 获取列表数据(不包含内容) - var rows *sql.Rows - var query string - var args []interface{} - - query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items" - - if category != "" { - query += " WHERE category = ?" - args = append(args, category) - } - - query += " ORDER BY category, title" - - if limit > 0 { - query += " LIMIT ?" - args = append(args, limit) - if offset > 0 { - query += " OFFSET ?" - args = append(args, offset) - } - } - - rows, err = m.db.Query(query, args...) - if err != nil { - return nil, 0, fmt.Errorf("查询知识项失败: %w", err) - } - defer rows.Close() - - var items []*KnowledgeItemSummary - for rows.Next() { - item := &KnowledgeItemSummary{} - var createdAt, updatedAt string - - if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { - return nil, 0, fmt.Errorf("扫描知识项失败: %w", err) - } - - // 解析时间 - timeFormats := []string{ - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999Z07:00", - "2006-01-02T15:04:05Z", - "2006-01-02 15:04:05", - time.RFC3339, - time.RFC3339Nano, - } - - if createdAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, createdAt) - if err == nil && !parsed.IsZero() { - item.CreatedAt = parsed - break - } - } - } - - if updatedAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, updatedAt) - if err == nil && !parsed.IsZero() { - item.UpdatedAt = parsed - break - } - } - } - - if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { - item.UpdatedAt = item.CreatedAt - } - - items = append(items, item) - } - - return items, total, nil -} - -// GetItem 获取单个知识项 -func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { - item := &KnowledgeItem{} - var createdAt, updatedAt string - err := m.db.QueryRow( - "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?", - id, - ).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt) - - if err == sql.ErrNoRows { - return nil, fmt.Errorf("知识项不存在") - } - if err != nil { - return nil, fmt.Errorf("查询知识项失败: %w", err) - } - - // 解析时间 - 支持多种格式 - timeFormats := []string{ - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999Z07:00", - "2006-01-02T15:04:05Z", - "2006-01-02 15:04:05", - time.RFC3339, - time.RFC3339Nano, - } - - // 解析创建时间 - if createdAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, createdAt) - if err == nil && !parsed.IsZero() { - item.CreatedAt = parsed - break - } - } - } - - // 解析更新时间 - if updatedAt != "" { - for _, format := range timeFormats { - parsed, err := time.Parse(format, updatedAt) - if err == nil && !parsed.IsZero() { - item.UpdatedAt = parsed - break - } - } - } - - // 如果更新时间为空,使用创建时间 - if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { - item.UpdatedAt = item.CreatedAt - } - - return item, nil -} - -// CreateItem 创建知识项 -func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) { - id := uuid.New().String() - now := time.Now() - - // 构建文件路径 - filePath := filepath.Join(m.basePath, category, title+".md") - - // 确保目录存在 - if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { - return nil, fmt.Errorf("创建目录失败: %w", err) - } - - // 写入文件 - if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { - return nil, fmt.Errorf("写入文件失败: %w", err) - } - - // 插入数据库 - _, err := m.db.Exec( - "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - id, category, title, filePath, content, now, now, - ) - if err != nil { - return nil, fmt.Errorf("插入知识项失败: %w", err) - } - - return &KnowledgeItem{ - ID: id, - Category: category, - Title: title, - FilePath: filePath, - Content: content, - CreatedAt: now, - UpdatedAt: now, - }, nil -} - -// UpdateItem 更新知识项 -func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) { - // 获取现有项 - item, err := m.GetItem(id) - if err != nil { - return nil, err - } - - // 构建新文件路径 - newFilePath := filepath.Join(m.basePath, category, title+".md") - - // 如果路径改变,需要移动文件 - if item.FilePath != newFilePath { - // 确保新目录存在 - if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil { - return nil, fmt.Errorf("创建目录失败: %w", err) - } - - // 移动文件 - if err := os.Rename(item.FilePath, newFilePath); err != nil { - return nil, fmt.Errorf("移动文件失败: %w", err) - } - - // 删除旧目录(如果为空) - oldDir := filepath.Dir(item.FilePath) - if isEmpty, _ := isEmptyDir(oldDir); isEmpty { - // 只有当目录不是知识库根目录时才删除(避免删除根目录) - if oldDir != m.basePath { - if err := os.Remove(oldDir); err != nil { - m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err)) - } - } - } - } - - // 写入文件 - if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil { - return nil, fmt.Errorf("写入文件失败: %w", err) - } - - // 更新数据库 - _, err = m.db.Exec( - "UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?", - category, title, newFilePath, content, time.Now(), id, - ) - if err != nil { - return nil, fmt.Errorf("更新知识项失败: %w", err) - } - - // 删除旧的向量嵌入(需要重新索引) - _, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id) - if err != nil { - m.logger.Warn("删除旧向量嵌入失败", zap.Error(err)) - } - - return m.GetItem(id) -} - -// DeleteItem 删除知识项 -func (m *Manager) DeleteItem(id string) error { - // 获取文件路径 - var filePath string - err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath) - if err != nil { - return fmt.Errorf("查询知识项失败: %w", err) - } - - // 删除文件 - if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { - m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err)) - } - - // 删除数据库记录(级联删除向量) - _, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id) - if err != nil { - return fmt.Errorf("删除知识项失败: %w", err) - } - - // 删除空目录(如果为空) - dir := filepath.Dir(filePath) - if isEmpty, _ := isEmptyDir(dir); isEmpty { - // 只有当目录不是知识库根目录时才删除(避免删除根目录) - if dir != m.basePath { - if err := os.Remove(dir); err != nil { - m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err)) - } - } - } - - return nil -} - -// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件) -func isEmptyDir(dir string) (bool, error) { - entries, err := os.ReadDir(dir) - if err != nil { - return false, err - } - for _, entry := range entries { - // 忽略隐藏文件(以 . 开头) - if !strings.HasPrefix(entry.Name(), ".") { - return false, nil - } - } - return true, nil -} - -// LogRetrieval 记录检索日志 -func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error { - id := uuid.New().String() - itemsJSON, _ := json.Marshal(retrievedItems) - - _, err := m.db.Exec( - "INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(), - ) - return err -} - -// GetIndexStatus 获取索引状态 -func (m *Manager) GetIndexStatus() (map[string]interface{}, error) { - // 获取总知识项数 - var totalItems int - err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) - if err != nil { - return nil, fmt.Errorf("查询总知识项数失败: %w", err) - } - - // 获取已索引的知识项数(有向量嵌入的) - var indexedItems int - err = m.db.QueryRow(` - SELECT COUNT(DISTINCT item_id) - FROM knowledge_embeddings - `).Scan(&indexedItems) - if err != nil { - return nil, fmt.Errorf("查询已索引项数失败: %w", err) - } - - // 计算进度百分比 - var progressPercent float64 - if totalItems > 0 { - progressPercent = float64(indexedItems) / float64(totalItems) * 100 - } else { - progressPercent = 100.0 - } - - // 判断是否完成 - isComplete := indexedItems >= totalItems && totalItems > 0 - - return map[string]interface{}{ - "total_items": totalItems, - "indexed_items": indexedItems, - "progress_percent": progressPercent, - "is_complete": isComplete, - }, nil -} - -// GetRetrievalLogs 获取检索日志 -func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) { - var rows *sql.Rows - var err error - - if messageID != "" { - rows, err = m.db.Query( - "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?", - messageID, limit, - ) - } else if conversationID != "" { - rows, err = m.db.Query( - "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?", - conversationID, limit, - ) - } else { - rows, err = m.db.Query( - "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?", - limit, - ) - } - - if err != nil { - return nil, fmt.Errorf("查询检索日志失败: %w", err) - } - defer rows.Close() - - var logs []*RetrievalLog - for rows.Next() { - log := &RetrievalLog{} - var createdAt string - var itemsJSON sql.NullString - if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil { - return nil, fmt.Errorf("扫描检索日志失败: %w", err) - } - - // 解析时间 - 支持多种格式 - var err error - timeFormats := []string{ - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999Z07:00", - "2006-01-02T15:04:05Z", - "2006-01-02 15:04:05", - time.RFC3339, - time.RFC3339Nano, - } - - for _, format := range timeFormats { - log.CreatedAt, err = time.Parse(format, createdAt) - if err == nil && !log.CreatedAt.IsZero() { - break - } - } - - // 如果所有格式都失败,记录警告但继续处理 - if log.CreatedAt.IsZero() { - m.logger.Warn("解析检索日志时间失败", - zap.String("timeStr", createdAt), - zap.Error(err), - ) - // 使用当前时间作为fallback - log.CreatedAt = time.Now() - } - - // 解析检索项 - if itemsJSON.Valid { - json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems) - } - - logs = append(logs, log) - } - - return logs, nil -} - -// DeleteRetrievalLog 删除检索日志 -func (m *Manager) DeleteRetrievalLog(id string) error { - result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id) - if err != nil { - return fmt.Errorf("删除检索日志失败: %w", err) - } - - rowsAffected, err := result.RowsAffected() - if err != nil { - return fmt.Errorf("获取删除行数失败: %w", err) - } - - if rowsAffected == 0 { - return fmt.Errorf("检索日志不存在") - } - - return nil -} diff --git a/knowledge/retrieval_postprocess.go b/knowledge/retrieval_postprocess.go deleted file mode 100644 index eb69e4c3..00000000 --- a/knowledge/retrieval_postprocess.go +++ /dev/null @@ -1,213 +0,0 @@ -package knowledge - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "fmt" - "strings" - "sync" - "unicode" - "unicode/utf8" - - "cyberstrike-ai/internal/config" - - "github.com/cloudwego/eino/schema" - "github.com/pkoukk/tiktoken-go" -) - -// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。 -const postRetrieveMaxPrefetchCap = 200 - -// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。 -type DocumentReranker interface { - Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) -} - -// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。 -type NopDocumentReranker struct{} - -// Rerank implements [DocumentReranker] as no-op. -func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) { - return docs, nil -} - -var tiktokenEncMu sync.Mutex -var tiktokenEncCache = map[string]*tiktoken.Tiktoken{} - -func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) { - m := strings.TrimSpace(model) - if m == "" { - m = "gpt-4" - } - tiktokenEncMu.Lock() - defer tiktokenEncMu.Unlock() - if enc, ok := tiktokenEncCache[m]; ok { - return enc, nil - } - enc, err := tiktoken.EncodingForModel(m) - if err != nil { - enc, err = tiktoken.GetEncoding("cl100k_base") - if err != nil { - return nil, err - } - } - tiktokenEncCache[m] = enc - return enc, nil -} - -func countDocTokens(text, model string) (int, error) { - enc, err := encodingForTokenizerModel(model) - if err != nil { - return 0, err - } - toks := enc.Encode(text, nil, nil) - return len(toks), nil -} - -// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。 -func normalizeContentFingerprintKey(s string) string { - s = strings.TrimSpace(s) - var b strings.Builder - b.Grow(len(s)) - prevSpace := false - for _, r := range s { - if unicode.IsSpace(r) { - if !prevSpace { - b.WriteByte(' ') - prevSpace = true - } - continue - } - prevSpace = false - b.WriteRune(r) - } - return b.String() -} - -func contentNormKey(d *schema.Document) string { - if d == nil { - return "" - } - n := normalizeContentFingerprintKey(d.Content) - if n == "" { - return "" - } - sum := sha256.Sum256([]byte(n)) - return hex.EncodeToString(sum[:]) -} - -// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。 -func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document { - if len(docs) < 2 { - return docs - } - seen := make(map[string]struct{}, len(docs)) - out := make([]*schema.Document, 0, len(docs)) - for _, d := range docs { - if d == nil { - continue - } - k := contentNormKey(d) - if k == "" { - out = append(out, d) - continue - } - if _, ok := seen[k]; ok { - continue - } - seen[k] = struct{}{} - out = append(out, d) - } - return out -} - -// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。 -func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) { - if len(docs) == 0 { - return docs, nil - } - unlimitedChars := maxRunes <= 0 - unlimitedTok := maxTokens <= 0 - if unlimitedChars && unlimitedTok { - return docs, nil - } - - remRunes := maxRunes - remTok := maxTokens - out := make([]*schema.Document, 0, len(docs)) - - for _, d := range docs { - if d == nil || strings.TrimSpace(d.Content) == "" { - continue - } - runes := utf8.RuneCountInString(d.Content) - if !unlimitedChars && runes > remRunes { - break - } - var tok int - var err error - if !unlimitedTok { - tok, err = countDocTokens(d.Content, tokenModel) - if err != nil { - return nil, fmt.Errorf("token count: %w", err) - } - if tok > remTok { - break - } - } - out = append(out, d) - if !unlimitedChars { - remRunes -= runes - } - if !unlimitedTok { - remTok -= tok - } - } - return out, nil -} - -// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。 -func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int { - if topK < 1 { - topK = 5 - } - fetch := topK - if po != nil && po.PrefetchTopK > fetch { - fetch = po.PrefetchTopK - } - if fetch > postRetrieveMaxPrefetchCap { - fetch = postRetrieveMaxPrefetchCap - } - return fetch -} - -// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。 -func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) { - if finalTopK < 1 { - finalTopK = 5 - } - if len(docs) == 0 { - return docs, nil - } - - maxChars := 0 - maxTok := 0 - if po != nil { - maxChars = po.MaxContextChars - maxTok = po.MaxContextTokens - } - - out := dedupeByNormalizedContent(docs) - - var err error - out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel) - if err != nil { - return nil, err - } - - if len(out) > finalTopK { - out = out[:finalTopK] - } - return out, nil -} diff --git a/knowledge/retrieval_postprocess_test.go b/knowledge/retrieval_postprocess_test.go deleted file mode 100644 index 10c661a8..00000000 --- a/knowledge/retrieval_postprocess_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package knowledge - -import ( - "testing" - - "cyberstrike-ai/internal/config" - - "github.com/cloudwego/eino/schema" -) - -func doc(id, content string, score float64) *schema.Document { - d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}} - d.WithScore(score) - return d -} - -func TestDedupeByNormalizedContent(t *testing.T) { - a := doc("1", "hello world", 0.9) - b := doc("2", "hello world", 0.8) - c := doc("3", "other", 0.7) - out := dedupeByNormalizedContent([]*schema.Document{a, b, c}) - if len(out) != 2 { - t.Fatalf("len=%d want 2", len(out)) - } - if out[0].ID != "1" || out[1].ID != "3" { - t.Fatalf("order/ids wrong: %#v", out) - } -} - -func TestEffectivePrefetchTopK(t *testing.T) { - if g := EffectivePrefetchTopK(5, nil); g != 5 { - t.Fatalf("got %d", g) - } - if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 { - t.Fatalf("got %d", g) - } - if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap { - t.Fatalf("cap: got %d", g) - } -} - -func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) { - d1 := doc("1", "ab", 0.9) - d2 := doc("2", "cd", 0.8) - d3 := doc("3", "ef", 0.7) - po := &config.PostRetrieveConfig{MaxContextChars: 3} - out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5) - if err != nil { - t.Fatal(err) - } - if len(out) != 1 || out[0].ID != "1" { - t.Fatalf("got %#v", out) - } - - out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2) - if err != nil { - t.Fatal(err) - } - if len(out2) != 2 { - t.Fatalf("topk: len=%d", len(out2)) - } -} diff --git a/knowledge/retriever.go b/knowledge/retriever.go deleted file mode 100644 index 9145b2c6..00000000 --- a/knowledge/retriever.go +++ /dev/null @@ -1,305 +0,0 @@ -package knowledge - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "math" - "sort" - "strings" - "sync" - - "cyberstrike-ai/internal/config" - - "github.com/cloudwego/eino/components/retriever" - "github.com/cloudwego/eino/schema" - "go.uber.org/zap" -) - -// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值), -// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。 -type Retriever struct { - db *sql.DB - embedder *Embedder - config *RetrievalConfig - logger *zap.Logger - - rerankMu sync.RWMutex - reranker DocumentReranker -} - -// RetrievalConfig 检索配置 -type RetrievalConfig struct { - TopK int - SimilarityThreshold float64 - // SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。 - SubIndexFilter string - PostRetrieve config.PostRetrieveConfig -} - -// NewRetriever 创建新的检索器 -func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever { - return &Retriever{ - db: db, - embedder: embedder, - config: config, - logger: logger, - } -} - -// UpdateConfig 更新检索配置 -func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) { - if cfg != nil { - r.config = cfg - if r.logger != nil { - r.logger.Info("检索器配置已更新", - zap.Int("top_k", cfg.TopK), - zap.Float64("similarity_threshold", cfg.SimilarityThreshold), - zap.String("sub_index_filter", cfg.SubIndexFilter), - zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK), - zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars), - zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens), - ) - } - } -} - -// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。 -func (r *Retriever) SetDocumentReranker(rr DocumentReranker) { - if r == nil { - return - } - r.rerankMu.Lock() - defer r.rerankMu.Unlock() - r.reranker = rr -} - -func (r *Retriever) documentReranker() DocumentReranker { - if r == nil { - return nil - } - r.rerankMu.RLock() - defer r.rerankMu.RUnlock() - return r.reranker -} - -func cosineSimilarity(a, b []float32) float64 { - if len(a) != len(b) { - return 0.0 - } - - var dotProduct, normA, normB float64 - for i := range a { - dotProduct += float64(a[i] * b[i]) - normA += float64(a[i] * a[i]) - normB += float64(b[i] * b[i]) - } - - if normA == 0 || normB == 0 { - return 0.0 - } - - return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) -} - -// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。 -func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { - if req == nil { - return nil, fmt.Errorf("请求不能为空") - } - q := strings.TrimSpace(req.Query) - if q == "" { - return nil, fmt.Errorf("查询不能为空") - } - opts := r.einoRetrieverOptions(req) - docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...) - if err != nil { - return nil, err - } - return documentsToRetrievalResults(docs) -} - -func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option { - var opts []retriever.Option - if req.TopK > 0 { - opts = append(opts, retriever.WithTopK(req.TopK)) - } - dsl := map[string]any{} - if strings.TrimSpace(req.RiskType) != "" { - dsl[DSLRiskType] = strings.TrimSpace(req.RiskType) - } - if req.Threshold > 0 { - dsl[DSLSimilarityThreshold] = req.Threshold - } - if strings.TrimSpace(req.SubIndexFilter) != "" { - dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter) - } - if len(dsl) > 0 { - opts = append(opts, retriever.WithDSLInfo(dsl)) - } - return opts -} - -// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。 -func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { - return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...) -} - -func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) { - q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title -FROM knowledge_embeddings e -JOIN knowledge_base_items i ON e.item_id = i.id -WHERE 1=1` - var args []interface{} - if strings.TrimSpace(riskType) != "" { - q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE` - args = append(args, riskType) - } - if tag := strings.TrimSpace(subIndexFilter); tag != "" { - tag = strings.ToLower(strings.ReplaceAll(tag, " ", "")) - q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)` - args = append(args, tag) - } - return q, args -} - -// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。 -func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { - if req.Query == "" { - return nil, fmt.Errorf("查询不能为空") - } - - topK := req.TopK - if topK <= 0 && r.config != nil { - topK = r.config.TopK - } - if topK <= 0 { - topK = 5 - } - - threshold := req.Threshold - if threshold <= 0 && r.config != nil { - threshold = r.config.SimilarityThreshold - } - if threshold <= 0 { - threshold = 0.7 - } - - subIdxFilter := strings.TrimSpace(req.SubIndexFilter) - if subIdxFilter == "" && r.config != nil { - subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter) - } - - queryText := FormatQueryEmbeddingText(req.RiskType, req.Query) - queryEmbedding, err := r.embedder.EmbedText(ctx, queryText) - if err != nil { - return nil, fmt.Errorf("向量化查询失败: %w", err) - } - queryDim := len(queryEmbedding) - expectedModel := "" - if r.embedder != nil { - expectedModel = r.embedder.EmbeddingModelName() - } - - sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter) - rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...) - if err != nil { - return nil, fmt.Errorf("查询向量失败: %w", err) - } - defer rows.Close() - - type candidate struct { - chunk *KnowledgeChunk - item *KnowledgeItem - similarity float64 - } - - candidates := make([]candidate, 0) - rowNum := 0 - for rows.Next() { - rowNum++ - if rowNum%48 == 0 { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - } - - var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string - var chunkIndex, rowDim int - - if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil { - r.logger.Warn("扫描向量失败", zap.Error(err)) - continue - } - - var embedding []float32 - if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil { - r.logger.Warn("解析向量失败", zap.Error(err)) - continue - } - - if rowDim > 0 && len(embedding) != rowDim { - r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding))) - continue - } - if queryDim > 0 && len(embedding) != queryDim { - r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding))) - continue - } - if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel { - r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel)) - continue - } - - similarity := cosineSimilarity(queryEmbedding, embedding) - candidates = append(candidates, candidate{ - chunk: &KnowledgeChunk{ - ID: chunkID, - ItemID: itemID, - ChunkIndex: chunkIndex, - ChunkText: chunkText, - Embedding: embedding, - }, - item: &KnowledgeItem{ - ID: itemID, - Category: category, - Title: title, - }, - similarity: similarity, - }) - } - - sort.Slice(candidates, func(i, j int) bool { - return candidates[i].similarity > candidates[j].similarity - }) - - filtered := make([]candidate, 0, len(candidates)) - for _, c := range candidates { - if c.similarity >= threshold { - filtered = append(filtered, c) - } - } - - if len(filtered) > topK { - filtered = filtered[:topK] - } - - results := make([]*RetrievalResult, len(filtered)) - for i, c := range filtered { - results[i] = &RetrievalResult{ - Chunk: c.chunk, - Item: c.item, - Similarity: c.similarity, - Score: c.similarity, - } - } - return results, nil -} - -// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。 -func (r *Retriever) AsEinoRetriever() retriever.Retriever { - return NewVectorEinoRetriever(r) -} diff --git a/knowledge/schema_migrate.go b/knowledge/schema_migrate.go deleted file mode 100644 index 85fd26e2..00000000 --- a/knowledge/schema_migrate.go +++ /dev/null @@ -1,51 +0,0 @@ -package knowledge - -import ( - "database/sql" - "fmt" -) - -// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata. -func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error { - if db == nil { - return fmt.Errorf("db is nil") - } - var n int - if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil { - return err - } - if n == 0 { - return nil - } - if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes", - `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil { - return err - } - if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model", - `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil { - return err - } - if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim", - `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil { - return err - } - return nil -} - -func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error { - var colCount int - q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?` - if err := db.QueryRow(q, column).Scan(&colCount); err != nil { - return err - } - if colCount > 0 { - return nil - } - _, err := db.Exec(alterSQL) - return err -} - -// ensureKnowledgeEmbeddingsSubIndexesColumn 向后兼容;请使用 [EnsureKnowledgeEmbeddingsSchema]。 -func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error { - return EnsureKnowledgeEmbeddingsSchema(db) -} diff --git a/knowledge/tool.go b/knowledge/tool.go deleted file mode 100644 index c7aa3f68..00000000 --- a/knowledge/tool.go +++ /dev/null @@ -1,323 +0,0 @@ -package knowledge - -import ( - "context" - "encoding/json" - "fmt" - "sort" - "strings" - - "cyberstrike-ai/internal/mcp" - "cyberstrike-ai/internal/mcp/builtin" - - "go.uber.org/zap" -) - -// RegisterKnowledgeTool 注册知识检索工具到MCP服务器 -func RegisterKnowledgeTool( - mcpServer *mcp.Server, - retriever *Retriever, - manager *Manager, - logger *zap.Logger, -) { - // 注册第一个工具:获取所有可用的风险类型列表 - listRiskTypesTool := mcp.Tool{ - Name: builtin.ToolListKnowledgeRiskTypes, - Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。", - ShortDescription: "获取知识库中所有可用的风险类型列表", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{}, - "required": []string{}, - }, - } - - listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - categories, err := manager.GetCategories() - if err != nil { - logger.Error("获取风险类型列表失败", zap.Error(err)) - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: fmt.Sprintf("获取风险类型列表失败: %v", err), - }, - }, - IsError: true, - }, nil - } - - if len(categories) == 0 { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: "知识库中暂无风险类型。", - }, - }, - }, nil - } - - var resultText strings.Builder - resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories))) - for i, category := range categories { - resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category)) - } - resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。") - - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: resultText.String(), - }, - }, - }, nil - } - - mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler) - logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name)) - - // 注册第二个工具:搜索知识库(保持原有功能) - searchTool := mcp.Tool{ - Name: builtin.ToolSearchKnowledgeBase, - Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。", - ShortDescription: "搜索知识库中的安全知识(向量语义检索)", - InputSchema: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "query": map[string]interface{}{ - "type": "string", - "description": "搜索查询内容,描述你想要了解的安全知识主题", - }, - "risk_type": map[string]interface{}{ - "type": "string", - "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。", - }, - }, - "required": []string{"query"}, - }, - } - - searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { - query, ok := args["query"].(string) - if !ok || query == "" { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: "错误: 查询参数不能为空", - }, - }, - IsError: true, - }, nil - } - - riskType := "" - if rt, ok := args["risk_type"].(string); ok && rt != "" { - riskType = rt - } - - logger.Info("执行知识库检索", - zap.String("query", query), - zap.String("riskType", riskType), - ) - - // 检索统一走 Retriever.Search → VectorEinoRetriever(Eino retriever 语义)。 - searchReq := &SearchRequest{ - Query: query, - RiskType: riskType, - TopK: 5, - } - - results, err := retriever.Search(ctx, searchReq) - if err != nil { - logger.Error("知识库检索失败", zap.Error(err)) - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: fmt.Sprintf("检索失败: %v", err), - }, - }, - IsError: true, - }, nil - } - - if len(results) == 0 { - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query), - }, - }, - }, nil - } - - // 格式化结果 - var resultText strings.Builder - - // 按余弦相似度(Score)降序 - sort.Slice(results, func(i, j int) bool { - return results[i].Score > results[j].Score - }) - - // 按文档分组结果,以便更好地展示上下文 - type itemGroup struct { - itemID string - results []*RetrievalResult - maxScore float64 // 该文档块的最高相似度 - } - itemGroups := make([]*itemGroup, 0) - itemMap := make(map[string]*itemGroup) - - for _, result := range results { - itemID := result.Item.ID - group, exists := itemMap[itemID] - if !exists { - group = &itemGroup{ - itemID: itemID, - results: make([]*RetrievalResult, 0), - maxScore: result.Score, - } - itemMap[itemID] = group - itemGroups = append(itemGroups, group) - } - group.results = append(group.results, result) - if result.Score > group.maxScore { - group.maxScore = result.Score - } - } - - // 按文档内最高相似度排序 - sort.Slice(itemGroups, func(i, j int) bool { - return itemGroups[i].maxScore > itemGroups[j].maxScore - }) - - // 收集检索到的知识项ID(用于日志) - retrievedItemIDs := make([]string, 0, len(itemGroups)) - - resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results))) - - resultIndex := 1 - for _, group := range itemGroups { - itemResults := group.results - mainResult := itemResults[0] - maxScore := mainResult.Score - for _, result := range itemResults { - if result.Score > maxScore { - maxScore = result.Score - mainResult = result - } - } - - // 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序) - sort.Slice(itemResults, func(i, j int) bool { - return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex - }) - - resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n", - resultIndex, mainResult.Similarity*100)) - resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID)) - - // 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk) - if len(itemResults) == 1 { - // 只有一个chunk,直接显示 - resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText)) - } else { - // 多个chunk,按逻辑顺序显示 - resultText.WriteString("内容片段(按文档顺序):\n") - for i, result := range itemResults { - // 标记主结果 - marker := "" - if result.Chunk.ID == mainResult.Chunk.ID { - marker = " [主匹配]" - } - resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText)) - } - } - resultText.WriteString("\n") - - if !contains(retrievedItemIDs, group.itemID) { - retrievedItemIDs = append(retrievedItemIDs, group.itemID) - } - resultIndex++ - } - - // 在结果末尾添加元数据(JSON格式,用于提取知识项ID) - // 使用特殊标记,避免影响AI阅读结果 - if len(retrievedItemIDs) > 0 { - metadataJSON, _ := json.Marshal(map[string]interface{}{ - "_metadata": map[string]interface{}{ - "retrievedItemIDs": retrievedItemIDs, - }, - }) - resultText.WriteString(fmt.Sprintf("\n", string(metadataJSON))) - } - - // 记录检索日志(异步,不阻塞) - // 注意:这里没有conversationID和messageID,需要在Agent层面记录 - // 实际的日志记录应该在Agent的progressCallback中完成 - - return &mcp.ToolResult{ - Content: []mcp.Content{ - { - Type: "text", - Text: resultText.String(), - }, - }, - }, nil - } - - mcpServer.RegisterTool(searchTool, searchHandler) - logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name)) -} - -// contains 检查切片是否包含元素 -func contains(slice []string, item string) bool { - for _, s := range slice { - if s == item { - return true - } - } - return false -} - -// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录) -func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) { - if q, ok := args["query"].(string); ok { - query = q - } - if rt, ok := args["risk_type"].(string); ok { - riskType = rt - } - return -} - -// FormatRetrievalResults 格式化检索结果为字符串(用于日志) -func FormatRetrievalResults(results []*RetrievalResult) string { - if len(results) == 0 { - return "未找到相关结果" - } - - var builder strings.Builder - builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results))) - - itemIDs := make(map[string]bool) - for i, result := range results { - builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n", - i+1, result.Item.Category, result.Item.Title, result.Similarity*100)) - itemIDs[result.Item.ID] = true - } - - // 返回知识项ID列表(JSON格式) - ids := make([]string, 0, len(itemIDs)) - for id := range itemIDs { - ids = append(ids, id) - } - idsJSON, _ := json.Marshal(ids) - builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON))) - - return builder.String() -} diff --git a/knowledge/types.go b/knowledge/types.go deleted file mode 100644 index 80d0eb5f..00000000 --- a/knowledge/types.go +++ /dev/null @@ -1,123 +0,0 @@ -package knowledge - -import ( - "encoding/json" - "time" -) - -// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串 -func formatTime(t time.Time) string { - if t.IsZero() { - return "" - } - return t.Format(time.RFC3339) -} - -// KnowledgeItem 知识库项 -type KnowledgeItem struct { - ID string `json:"id"` - Category string `json:"category"` // 风险类型(文件夹名) - Title string `json:"title"` // 标题(文件名) - FilePath string `json:"filePath"` // 文件路径 - Content string `json:"content"` // 文件内容 - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` -} - -// KnowledgeItemSummary 知识库项摘要(用于列表,不包含完整内容) -type KnowledgeItemSummary struct { - ID string `json:"id"` - Category string `json:"category"` - Title string `json:"title"` - FilePath string `json:"filePath"` - Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符) - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` -} - -// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 -func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) { - type Alias KnowledgeItemSummary - aux := &struct { - *Alias - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - }{ - Alias: (*Alias)(k), - } - aux.CreatedAt = formatTime(k.CreatedAt) - aux.UpdatedAt = formatTime(k.UpdatedAt) - return json.Marshal(aux) -} - -// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 -func (k *KnowledgeItem) MarshalJSON() ([]byte, error) { - type Alias KnowledgeItem - aux := &struct { - *Alias - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - }{ - Alias: (*Alias)(k), - } - aux.CreatedAt = formatTime(k.CreatedAt) - aux.UpdatedAt = formatTime(k.UpdatedAt) - return json.Marshal(aux) -} - -// KnowledgeChunk 知识块(用于向量化) -type KnowledgeChunk struct { - ID string `json:"id"` - ItemID string `json:"itemId"` - ChunkIndex int `json:"chunkIndex"` - ChunkText string `json:"chunkText"` - Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON - CreatedAt time.Time `json:"createdAt"` -} - -// RetrievalResult 检索结果 -type RetrievalResult struct { - Chunk *KnowledgeChunk `json:"chunk"` - Item *KnowledgeItem `json:"item"` - Similarity float64 `json:"similarity"` // 相似度分数 - Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度 -} - -// RetrievalLog 检索日志 -type RetrievalLog struct { - ID string `json:"id"` - ConversationID string `json:"conversationId,omitempty"` - MessageID string `json:"messageId,omitempty"` - Query string `json:"query"` - RiskType string `json:"riskType,omitempty"` - RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表 - CreatedAt time.Time `json:"createdAt"` -} - -// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 -func (r *RetrievalLog) MarshalJSON() ([]byte, error) { - type Alias RetrievalLog - return json.Marshal(&struct { - *Alias - CreatedAt string `json:"createdAt"` - }{ - Alias: (*Alias)(r), - CreatedAt: formatTime(r.CreatedAt), - }) -} - -// CategoryWithItems 分类及其下的知识项(用于按分类分页) -type CategoryWithItems struct { - Category string `json:"category"` // 分类名称 - ItemCount int `json:"itemCount"` // 该分类下的知识项总数 - Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表 -} - -// SearchRequest 搜索请求 -type SearchRequest struct { - Query string `json:"query"` - RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型 - SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据) - TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5 - Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7 -}