From 2e063dd857f7a8bad6c480b147f30dd3c8753f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Thu, 2 Jul 2026 11:51:27 +0800 Subject: [PATCH] Add files via upload --- internal/config/config.go | 60 ++++- internal/knowledge/eino_pipeline_retriever.go | 96 ++++++++ internal/knowledge/eino_retrieve_chain.go | 3 +- internal/knowledge/eino_retriever_adapter.go | 31 +-- internal/knowledge/rerank_http.go | 226 ++++++++++++++++++ internal/knowledge/rerank_http_test.go | 97 ++++++++ internal/knowledge/retrieval_postprocess.go | 13 +- .../knowledge/retrieval_postprocess_test.go | 4 +- internal/knowledge/retriever.go | 49 +++- internal/knowledge/wire_retriever.go | 74 ++++++ 10 files changed, 601 insertions(+), 52 deletions(-) create mode 100644 internal/knowledge/eino_pipeline_retriever.go create mode 100644 internal/knowledge/rerank_http.go create mode 100644 internal/knowledge/rerank_http_test.go create mode 100644 internal/knowledge/wire_retriever.go diff --git a/internal/config/config.go b/internal/config/config.go index 4cbfe060..a70a2507 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1477,7 +1477,12 @@ func Default() *Config { }, Retrieval: RetrievalConfig{ TopK: 5, - SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检 + SimilarityThreshold: 0.65, + MultiQuery: MultiQueryConfig{MaxQueries: 4}, + Rerank: RerankConfig{}, + PostRetrieve: PostRetrieveConfig{ + PrefetchTopK: 20, + }, }, Indexing: IndexingConfig{ ChunkStrategy: "markdown_then_recursive", @@ -1573,7 +1578,7 @@ type EmbeddingConfig struct { // PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。 type PostRetrieveConfig struct { - // PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。 + // PrefetchTopK 向量检索阶段每条 MultiQuery 变体最多保留的候选数;0 表示使用内置默认 max(top_k*4, 20)。 PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"` // MaxContextChars 返回文档内容总 Unicode 
字符数上限(整段 chunk,不截断半段);0 表示不限制。 MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"` @@ -1581,13 +1586,62 @@ type PostRetrieveConfig struct { MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"` } +// MultiQueryConfig Eino MultiQuery 查询改写(始终启用,无关闭开关)。 +type MultiQueryConfig struct { + // MaxQueries LLM 生成的检索变体上限(含原问语义覆盖);0 表示默认 4。 + MaxQueries int `yaml:"max_queries,omitempty" json:"max_queries,omitempty"` +} + +func (c MultiQueryConfig) MaxQueriesEffective() int { + if c.MaxQueries <= 0 { + return 4 + } + if c.MaxQueries > 8 { + return 8 + } + return c.MaxQueries +} + +// RerankConfig 检索精排(始终启用);支持 dashscope 与 Cohere 兼容 HTTP API。 +type RerankConfig struct { + // Provider: dashscope | cohere;空则按 base_url 自动推断。 + Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` + Model string `yaml:"model,omitempty" json:"model,omitempty"` + BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` + APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"` +} + +func (c RerankConfig) ProviderEffective(baseURL string) string { + p := strings.TrimSpace(strings.ToLower(c.Provider)) + if p != "" { + return p + } + u := strings.ToLower(baseURL) + if strings.Contains(u, "dashscope") { + return "dashscope" + } + return "cohere" +} + +func (c RerankConfig) ModelEffective(provider string) string { + if m := strings.TrimSpace(c.Model); m != "" { + return m + } + if provider == "dashscope" { + return "gte-rerank" + } + return "rerank-multilingual-v3.0" +} + // RetrievalConfig 检索配置 type RetrievalConfig struct { TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值 // SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。 SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"` - // PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 
[knowledge.DocumentReranker]。 + MultiQuery MultiQueryConfig `yaml:"multi_query" json:"multi_query"` + Rerank RerankConfig `yaml:"rerank" json:"rerank"` + // PostRetrieve 检索后处理(去重、预算截断);精排在 MultiQuery 融合后执行。 PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"` } diff --git a/internal/knowledge/eino_pipeline_retriever.go b/internal/knowledge/eino_pipeline_retriever.go new file mode 100644 index 00000000..487b439b --- /dev/null +++ b/internal/knowledge/eino_pipeline_retriever.go @@ -0,0 +1,96 @@ +package knowledge + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// knowledgePipelineRetriever: MultiQuery → vector candidates → rerank → post-process. +type knowledgePipelineRetriever struct { + inner retriever.Retriever + base *Retriever +} + +func newKnowledgePipelineRetriever(inner retriever.Retriever, base *Retriever) *knowledgePipelineRetriever { + if inner == nil || base == nil { + return nil + } + return &knowledgePipelineRetriever{inner: inner, base: base} +} + +func (p *knowledgePipelineRetriever) GetType() string { + return "KnowledgeRAGPipeline" +} + +func (p *knowledgePipelineRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) { + if p == nil || p.inner == nil || p.base == nil { + return nil, fmt.Errorf("knowledge pipeline retriever: nil") + } + q := strings.TrimSpace(query) + if q == "" { + return nil, fmt.Errorf("查询不能为空") + } + + ro := retriever.GetCommonOptions(nil, opts...) 
+ finalTopK := p.base.config.TopK + if finalTopK <= 0 { + finalTopK = 5 + } + if ro.TopK != nil && *ro.TopK > 0 { + finalTopK = *ro.TopK + } + + ctx = callbacks.EnsureRunInfo(ctx, p.GetType(), components.ComponentOfRetriever) + ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{Query: q, TopK: finalTopK, Extra: ro.DSLInfo}) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out}) + }() + + out, err = p.inner.Retrieve(ctx, q, opts...) + if err != nil { + return nil, err + } + if len(out) == 0 { + return out, nil + } + + if rr := p.base.documentReranker(); rr != nil && len(out) > 1 { + reranked, rerr := rr.Rerank(ctx, q, out) + if rerr != nil { + if p.base.logger != nil { + p.base.logger.Warn("知识检索重排失败,已使用融合序", zap.Error(rerr)) + } + } else if len(reranked) > 0 { + out = reranked + } + } + + tokenModel := "" + if p.base.embedder != nil { + tokenModel = p.base.embedder.EmbeddingModelName() + } + var postPO *config.PostRetrieveConfig + if p.base.config != nil { + postPO = &p.base.config.PostRetrieve + } + out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK) + if err != nil { + return nil, err + } + return out, nil +} + +var _ retriever.Retriever = (*knowledgePipelineRetriever)(nil) diff --git a/internal/knowledge/eino_retrieve_chain.go b/internal/knowledge/eino_retrieve_chain.go index 2d1b72eb..81fa4159 100644 --- a/internal/knowledge/eino_retrieve_chain.go +++ b/internal/knowledge/eino_retrieve_chain.go @@ -8,8 +8,7 @@ import ( "github.com/cloudwego/eino/schema" ) -// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。 -// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。 +// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain(MultiQuery → 向量 → 重排 → 后处理)。 func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) { if r == nil 
{ return nil, fmt.Errorf("retriever is nil") diff --git a/internal/knowledge/eino_retriever_adapter.go b/internal/knowledge/eino_retriever_adapter.go index f5635121..712b4734 100644 --- a/internal/knowledge/eino_retriever_adapter.go +++ b/internal/knowledge/eino_retriever_adapter.go @@ -11,19 +11,10 @@ import ( "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" - "go.uber.org/zap" ) // VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity. -// -// Options: -// - [retriever.WithTopK] -// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string) -// -// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric. -// -// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then -// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig]. +// It returns prefetch-sized vector candidates only; rerank and post-process run in [knowledgePipelineRetriever]. type VectorEinoRetriever struct { inner *Retriever } @@ -119,26 +110,6 @@ func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts . 
return nil, err } out = retrievalResultsToDocuments(results) - - if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 { - reranked, rerr := rr.Rerank(ctx, q, out) - if rerr != nil { - if h.inner.logger != nil { - h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr)) - } - } else if len(reranked) > 0 { - out = reranked - } - } - - tokenModel := "" - if h.inner.embedder != nil { - tokenModel = h.inner.embedder.EmbeddingModelName() - } - out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK) - if err != nil { - return nil, err - } return out, nil } diff --git a/internal/knowledge/rerank_http.go b/internal/knowledge/rerank_http.go new file mode 100644 index 00000000..61e173ed --- /dev/null +++ b/internal/knowledge/rerank_http.go @@ -0,0 +1,226 @@ +package knowledge + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// HTTPReranker calls a hosted rerank API (DashScope or Cohere-compatible). +type HTTPReranker struct { + provider string + model string + baseURL string + apiKey string + client *http.Client + logger *zap.Logger +} + +// NewHTTPReranker builds a rerank client from knowledge retrieval config; openAI supplies fallback credentials. 
+func NewHTTPReranker(rc *config.RerankConfig, openAI *config.OpenAIConfig, logger *zap.Logger) (*HTTPReranker, error) { + if rc == nil { + return nil, fmt.Errorf("rerank config is nil") + } + baseURL := strings.TrimSpace(rc.BaseURL) + apiKey := strings.TrimSpace(rc.APIKey) + if openAI != nil { + if baseURL == "" { + baseURL = strings.TrimSpace(openAI.BaseURL) + } + if apiKey == "" { + apiKey = strings.TrimSpace(openAI.APIKey) + } + } + if apiKey == "" { + return nil, fmt.Errorf("rerank api_key is required") + } + provider := rc.ProviderEffective(baseURL) + model := rc.ModelEffective(provider) + return &HTTPReranker{ + provider: provider, + model: model, + baseURL: strings.TrimSuffix(baseURL, "/"), + apiKey: apiKey, + client: &http.Client{Timeout: 60 * time.Second}, + logger: logger, + }, nil +} + +func (r *HTTPReranker) Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) { + if r == nil { + return docs, nil + } + q := strings.TrimSpace(query) + if q == "" || len(docs) == 0 { + return docs, nil + } + if len(docs) == 1 { + return docs, nil + } + texts := make([]string, 0, len(docs)) + for _, d := range docs { + if d == nil { + texts = append(texts, "") + continue + } + texts = append(texts, d.Content) + } + var order []int + var err error + switch r.provider { + case "dashscope": + order, err = r.rerankDashScope(ctx, q, texts, len(docs)) + default: + order, err = r.rerankCohere(ctx, q, texts, len(docs)) + } + if err != nil { + return nil, err + } + out := make([]*schema.Document, 0, len(order)) + for _, idx := range order { + if idx < 0 || idx >= len(docs) || docs[idx] == nil { + continue + } + out = append(out, docs[idx]) + } + if len(out) == 0 { + return docs, nil + } + return out, nil +} + +func (r *HTTPReranker) rerankCohere(ctx context.Context, query string, documents []string, topN int) ([]int, error) { + url := r.cohereRerankURL() + body := map[string]any{ + "model": r.model, + "query": query, + "documents": 
documents, + "top_n": topN, + } + raw, err := json.Marshal(body) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+r.apiKey) + resp, err := r.client.Do(req) + if err != nil { + return nil, fmt.Errorf("rerank request: %w", err) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("rerank http %d: %s", resp.StatusCode, truncateForRerankLog(string(respBody))) + } + var parsed struct { + Results []struct { + Index int `json:"index"` + } `json:"results"` + } + if err := json.Unmarshal(respBody, &parsed); err != nil { + return nil, fmt.Errorf("rerank decode: %w", err) + } + order := make([]int, 0, len(parsed.Results)) + for _, row := range parsed.Results { + order = append(order, row.Index) + } + return order, nil +} + +func (r *HTTPReranker) rerankDashScope(ctx context.Context, query string, documents []string, topN int) ([]int, error) { + url := r.dashscopeRerankURL() + body := map[string]any{ + "model": r.model, + "input": map[string]any{ + "query": query, + "documents": documents, + }, + "parameters": map[string]any{ + "return_documents": false, + "top_n": topN, + }, + } + raw, err := json.Marshal(body) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+r.apiKey) + resp, err := r.client.Do(req) + if err != nil { + return nil, fmt.Errorf("dashscope rerank: %w", err) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("dashscope rerank http %d: %s", 
resp.StatusCode, truncateForRerankLog(string(respBody))) + } + var parsed struct { + Output struct { + Results []struct { + Index int `json:"index"` + } `json:"results"` + } `json:"output"` + } + if err := json.Unmarshal(respBody, &parsed); err != nil { + return nil, fmt.Errorf("dashscope rerank decode: %w", err) + } + order := make([]int, 0, len(parsed.Output.Results)) + for _, row := range parsed.Output.Results { + order = append(order, row.Index) + } + return order, nil +} + +func (r *HTTPReranker) cohereRerankURL() string { + base := r.baseURL + if base == "" { + base = "https://api.cohere.com" + } + if strings.HasSuffix(base, "/v1") { + return base + "/rerank" + } + return base + "/v1/rerank" +} + +func (r *HTTPReranker) dashscopeRerankURL() string { + base := strings.TrimSpace(r.baseURL) + if base == "" { + return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank" + } + if strings.Contains(base, "/api/v1/services/rerank") { + return base + } + if strings.Contains(base, "dashscope.aliyuncs.com") || strings.Contains(base, "compatible-mode") { + return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank" + } + return strings.TrimSuffix(base, "/") +} + +func truncateForRerankLog(s string) string { + s = strings.TrimSpace(s) + if len(s) > 512 { + return s[:512] + "..." 
+ } + return s +} + +var _ DocumentReranker = (*HTTPReranker)(nil) diff --git a/internal/knowledge/rerank_http_test.go b/internal/knowledge/rerank_http_test.go new file mode 100644 index 00000000..013ad7d4 --- /dev/null +++ b/internal/knowledge/rerank_http_test.go @@ -0,0 +1,97 @@ +package knowledge + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" +) + +func TestHTTPReranker_CohereOrder(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/rerank" { + t.Fatalf("path %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "results": []map[string]any{ + {"index": 2, "relevance_score": 0.9}, + {"index": 0, "relevance_score": 0.5}, + }, + }) + })) + defer srv.Close() + + rr, err := NewHTTPReranker(&config.RerankConfig{ + Provider: "cohere", + Model: "rerank-multilingual-v3.0", + BaseURL: srv.URL, + APIKey: "test-key", + }, nil, nil) + if err != nil { + t.Fatal(err) + } + docs := []*schema.Document{ + {ID: "a", Content: "alpha"}, + {ID: "b", Content: "beta"}, + {ID: "c", Content: "gamma"}, + } + out, err := rr.Rerank(context.Background(), "query", docs) + if err != nil { + t.Fatal(err) + } + if len(out) != 2 || out[0].ID != "c" || out[1].ID != "a" { + t.Fatalf("order wrong: %#v", out) + } +} + +func TestHTTPReranker_DashScopeOrder(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "output": map[string]any{ + "results": []map[string]any{ + {"index": 1, "relevance_score": 0.88}, + }, + }, + }) + })) + defer srv.Close() + + rr, err := NewHTTPReranker(&config.RerankConfig{ + Provider: "dashscope", + Model: "gte-rerank", + BaseURL: srv.URL, + APIKey: "test-key", + }, nil, nil) + if err != nil { + t.Fatal(err) + } + docs := 
[]*schema.Document{{ID: "a", Content: "a"}, {ID: "b", Content: "b"}} + out, err := rr.Rerank(context.Background(), "q", docs) + if err != nil { + t.Fatal(err) + } + if len(out) != 1 || out[0].ID != "b" { + t.Fatalf("got %#v", out) + } +} + +func TestRerankConfigDefaults(t *testing.T) { + t.Parallel() + rc := config.RerankConfig{} + if rc.ProviderEffective("https://dashscope.aliyuncs.com/x") != "dashscope" { + t.Fatal("dashscope detect") + } + if rc.ModelEffective("dashscope") != "gte-rerank" { + t.Fatal("dashscope model") + } + if rc.ModelEffective("cohere") != "rerank-multilingual-v3.0" { + t.Fatal("cohere model") + } +} diff --git a/internal/knowledge/retrieval_postprocess.go b/internal/knowledge/retrieval_postprocess.go index eb69e4c3..20e07110 100644 --- a/internal/knowledge/retrieval_postprocess.go +++ b/internal/knowledge/retrieval_postprocess.go @@ -19,7 +19,7 @@ import ( // postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。 const postRetrieveMaxPrefetchCap = 200 -// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。 +// DocumentReranker 精排(HTTP dashscope / Cohere 兼容 API),由 [WireRetrieverPipeline] 注入。 type DocumentReranker interface { Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) } @@ -167,13 +167,16 @@ func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, return out, nil } -// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。 +// EffectivePrefetchTopK 计算每条 MultiQuery 变体在向量阶段的候选条数(供融合 / 重排 / 后处理)。 func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int { if topK < 1 { topK = 5 } - fetch := topK - if po != nil && po.PrefetchTopK > fetch { + fetch := topK * 4 + if fetch < 20 { + fetch = 20 + } + if po != nil && po.PrefetchTopK > 0 { fetch = po.PrefetchTopK } if fetch > postRetrieveMaxPrefetchCap { @@ -182,7 +185,7 @@ func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int { return fetch } -// 
ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。 +// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK(精排已在流水线中完成)。 func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) { if finalTopK < 1 { finalTopK = 5 diff --git a/internal/knowledge/retrieval_postprocess_test.go b/internal/knowledge/retrieval_postprocess_test.go index 10c661a8..889d5c62 100644 --- a/internal/knowledge/retrieval_postprocess_test.go +++ b/internal/knowledge/retrieval_postprocess_test.go @@ -28,8 +28,8 @@ func TestDedupeByNormalizedContent(t *testing.T) { } func TestEffectivePrefetchTopK(t *testing.T) { - if g := EffectivePrefetchTopK(5, nil); g != 5 { - t.Fatalf("got %d", g) + if g := EffectivePrefetchTopK(5, nil); g != 20 { + t.Fatalf("default prefetch got %d want 20", g) } if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 { t.Fatalf("got %d", g) diff --git a/internal/knowledge/retriever.go b/internal/knowledge/retriever.go index 9145b2c6..c75e8a13 100644 --- a/internal/knowledge/retriever.go +++ b/internal/knowledge/retriever.go @@ -27,15 +27,19 @@ type Retriever struct { rerankMu sync.RWMutex reranker DocumentReranker + + pipeline retriever.Retriever + wireOpenAI *config.OpenAIConfig } // RetrievalConfig 检索配置 type RetrievalConfig struct { TopK int SimilarityThreshold float64 - // SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。 - SubIndexFilter string - PostRetrieve config.PostRetrieveConfig + SubIndexFilter string + MultiQuery config.MultiQueryConfig + Rerank config.RerankConfig + PostRetrieve config.PostRetrieveConfig } // NewRetriever 创建新的检索器 @@ -48,7 +52,7 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge } } -// UpdateConfig 更新检索配置 +// UpdateConfig 更新检索配置并重建 Eino MultiQuery + 重排流水线。 func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) { if cfg != nil { r.config = cfg 
@@ -57,12 +61,18 @@ func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) { zap.Int("top_k", cfg.TopK), zap.Float64("similarity_threshold", cfg.SimilarityThreshold), zap.String("sub_index_filter", cfg.SubIndexFilter), + zap.Int("multi_query_max", cfg.MultiQuery.MaxQueriesEffective()), zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK), zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars), zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens), ) } } + if r.wireOpenAI != nil { + if err := WireRetrieverPipeline(context.Background(), r, r.wireOpenAI); err != nil && r.logger != nil { + r.logger.Warn("检索流水线重建失败", zap.Error(err)) + } + } } // SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。 @@ -103,7 +113,7 @@ func cosineSimilarity(a, b []float32) float64 { return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) } -// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。 +// Search 搜索知识库(Eino MultiQuery → 向量检索 → 重排 → 后处理)。 func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { if req == nil { return nil, fmt.Errorf("请求不能为空") @@ -113,7 +123,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva return nil, fmt.Errorf("查询不能为空") } opts := r.einoRetrieverOptions(req) - docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...) + docs, err := r.activeEinoRetriever().Retrieve(ctx, q, opts...) if err != nil { return nil, err } @@ -143,7 +153,19 @@ func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option // EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。 func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { - return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...) + return r.activeEinoRetriever().Retrieve(ctx, query, opts...) 
+} + +func (r *Retriever) activeEinoRetriever() retriever.Retriever { + if r != nil && r.pipeline != nil { + return r.pipeline + } + return NewVectorEinoRetriever(r) +} + +// AsEinoRetriever 将知识库检索流水线暴露为 Eino [retriever.Retriever]。 +func (r *Retriever) AsEinoRetriever() retriever.Retriever { + return r.activeEinoRetriever() } func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) { @@ -299,7 +321,14 @@ func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*Re return results, nil } -// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。 -func (r *Retriever) AsEinoRetriever() retriever.Retriever { - return NewVectorEinoRetriever(r) +// RetrievalConfigFromYAML maps API/YAML retrieval settings into the knowledge package. +func RetrievalConfigFromYAML(r config.RetrievalConfig) *RetrievalConfig { + return &RetrievalConfig{ + TopK: r.TopK, + SimilarityThreshold: r.SimilarityThreshold, + SubIndexFilter: r.SubIndexFilter, + MultiQuery: r.MultiQuery, + Rerank: r.Rerank, + PostRetrieve: r.PostRetrieve, + } } diff --git a/internal/knowledge/wire_retriever.go b/internal/knowledge/wire_retriever.go new file mode 100644 index 00000000..d2155add --- /dev/null +++ b/internal/knowledge/wire_retriever.go @@ -0,0 +1,74 @@ +package knowledge + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/openai" + + einoopenai "github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/flow/retriever/multiquery" + "go.uber.org/zap" +) + +// WireRetrieverPipeline builds Eino MultiQuery + HTTP rerank + post-process pipeline on r. +// Call once after NewRetriever; UpdateConfig re-invokes when wireOpenAI is set. 
+func WireRetrieverPipeline(ctx context.Context, r *Retriever, openAI *config.OpenAIConfig) error {
+	if r == nil {
+		return fmt.Errorf("retriever is nil")
+	}
+	if openAI == nil {
+		return fmt.Errorf("openai config is nil")
+	}
+	if r.config == nil {
+		return fmt.Errorf("retrieval config is nil")
+	}
+	r.wireOpenAI = openAI // stored before any fallible step; UpdateConfig only re-wires when this is set
+	// Chat model handed to MultiQuery as RewriteLLM; falls back to "gpt-4o" when no model is configured.
+	httpClient := openai.NewEinoHTTPClient(openAI, &http.Client{Timeout: 120 * time.Second})
+	chatCfg := &einoopenai.ChatModelConfig{
+		APIKey:     strings.TrimSpace(openAI.APIKey),
+		BaseURL:    strings.TrimSuffix(strings.TrimSpace(openAI.BaseURL), "/"),
+		Model:      strings.TrimSpace(openAI.Model),
+		HTTPClient: httpClient,
+	}
+	if chatCfg.Model == "" {
+		chatCfg.Model = "gpt-4o"
+	}
+	rewriteLLM, err := einoopenai.NewChatModel(ctx, chatCfg)
+	if err != nil {
+		return fmt.Errorf("multi_query rewrite model: %w", err)
+	}
+	// Build the reranker now but install it only after every stage has been constructed (see below).
+	reranker, err := NewHTTPReranker(&r.config.Rerank, openAI, r.logger)
+	if err != nil {
+		return fmt.Errorf("reranker: %w", err)
+	}
+	// MultiQuery wraps the plain vector retriever as its OrigRetriever.
+	vec := NewVectorEinoRetriever(r)
+	mq, err := multiquery.NewRetriever(ctx, &multiquery.Config{
+		RewriteLLM:    rewriteLLM,
+		MaxQueriesNum: r.config.MultiQuery.MaxQueriesEffective(),
+		OrigRetriever: vec,
+	})
+	if err != nil {
+		return fmt.Errorf("multi_query: %w", err)
+	}
+	// Nothing can fail past this point: swap reranker and pipeline together so an error above keeps the old pair.
+	r.SetDocumentReranker(reranker)
+	r.pipeline = newKnowledgePipelineRetriever(mq, r)
+	if r.logger != nil {
+		// Log what the reranker actually resolved (rerank.base_url wins over openAI.BaseURL), not a re-derivation.
+		r.logger.Info("知识库检索流水线已启用",
+			zap.String("pipeline", "MultiQuery→Vector→Rerank→PostRetrieve"),
+			zap.Int("multi_query_max", r.config.MultiQuery.MaxQueriesEffective()),
+			zap.String("rerank_provider", reranker.provider),
+			zap.String("rerank_model", reranker.model),
+		)
+	}
+	return nil
+}