// Mirror of https://github.com/Ed1s0nZ/CyberStrikeAI.git
// (synced 2026-04-21 18:26:38 +02:00; 143 lines, 4.4 KiB, Go).

package knowledge
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/cloudwego/eino/callbacks"
|
|
"github.com/cloudwego/eino/components"
|
|
"github.com/cloudwego/eino/components/indexer"
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema.
// Store embeds each document and inserts one row per document into that table.
type SQLiteIndexer struct {
	// db is the database handle Store opens its insert transaction on.
	db *sql.DB
	// batchSize is the number of texts sent to the embedder per call;
	// values <= 0 make Store fall back to a default of 64.
	batchSize int
	// embeddingModel is the model name written to the embedding_model column
	// of every inserted row (may be empty). NewSQLiteIndexer trims it.
	embeddingModel string
}
|
|
|
|
// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call.
|
|
// batchSize is the embedding batch size; if <= 0, default 64 is used.
|
|
// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty).
|
|
func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer {
|
|
return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)}
|
|
}
|
|
|
|
// GetType implements eino callback run info.
|
|
func (s *SQLiteIndexer) GetType() string {
|
|
return "SQLiteKnowledgeIndexer"
|
|
}
|
|
|
|
// Store embeds documents and inserts rows. Each doc must carry MetaData:
|
|
// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only.
|
|
func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
|
|
options := indexer.GetCommonOptions(nil, opts...)
|
|
if options.Embedding == nil {
|
|
return nil, fmt.Errorf("sqlite indexer: embedding is required")
|
|
}
|
|
if len(docs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer)
|
|
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
|
|
defer func() {
|
|
if err != nil {
|
|
_ = callbacks.OnError(ctx, err)
|
|
return
|
|
}
|
|
_ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})
|
|
}()
|
|
|
|
subIdxStr := strings.Join(options.SubIndexes, ",")
|
|
|
|
texts := make([]string, len(docs))
|
|
for i, d := range docs {
|
|
if d == nil {
|
|
return nil, fmt.Errorf("sqlite indexer: nil document at %d", i)
|
|
}
|
|
cat := MetaLookupString(d.MetaData, metaKBCategory)
|
|
title := MetaLookupString(d.MetaData, metaKBTitle)
|
|
texts[i] = FormatEmbeddingInput(cat, title, d.Content)
|
|
}
|
|
|
|
bs := s.batchSize
|
|
if bs <= 0 {
|
|
bs = 64
|
|
}
|
|
|
|
var allVecs [][]float64
|
|
for start := 0; start < len(texts); start += bs {
|
|
end := start + bs
|
|
if end > len(texts) {
|
|
end = len(texts)
|
|
}
|
|
batch := texts[start:end]
|
|
vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch)
|
|
if embedErr != nil {
|
|
return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr)
|
|
}
|
|
if len(vecs) != len(batch) {
|
|
return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch))
|
|
}
|
|
allVecs = append(allVecs, vecs...)
|
|
}
|
|
|
|
embedDim := 0
|
|
if len(allVecs) > 0 {
|
|
embedDim = len(allVecs[0])
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
ids = make([]string, 0, len(docs))
|
|
for i, d := range docs {
|
|
chunkID := uuid.New().String()
|
|
itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID)
|
|
if metaErr != nil {
|
|
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
|
|
}
|
|
chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex)
|
|
if metaErr != nil {
|
|
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
|
|
}
|
|
vec := allVecs[i]
|
|
if embedDim > 0 && len(vec) != embedDim {
|
|
return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim)
|
|
}
|
|
vec32 := make([]float32, len(vec))
|
|
for j, v := range vec {
|
|
vec32[j] = float32(v)
|
|
}
|
|
embeddingJSON, jsonErr := json.Marshal(vec32)
|
|
if jsonErr != nil {
|
|
return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr)
|
|
}
|
|
_, err = tx.ExecContext(ctx,
|
|
`INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`,
|
|
chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err)
|
|
}
|
|
ids = append(ids, chunkID)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("sqlite indexer: commit: %w", err)
|
|
}
|
|
return ids, nil
|
|
}
|
|
|
|
// Compile-time assertion that *SQLiteIndexer satisfies indexer.Indexer.
var _ indexer.Indexer = (*SQLiteIndexer)(nil)
|