mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-02 18:55:52 +02:00
Add files via upload
This commit is contained in:
@@ -1477,7 +1477,12 @@ func Default() *Config {
|
||||
},
|
||||
Retrieval: RetrievalConfig{
|
||||
TopK: 5,
|
||||
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
|
||||
SimilarityThreshold: 0.65,
|
||||
MultiQuery: MultiQueryConfig{MaxQueries: 4},
|
||||
Rerank: RerankConfig{},
|
||||
PostRetrieve: PostRetrieveConfig{
|
||||
PrefetchTopK: 20,
|
||||
},
|
||||
},
|
||||
Indexing: IndexingConfig{
|
||||
ChunkStrategy: "markdown_then_recursive",
|
||||
@@ -1573,7 +1578,7 @@ type EmbeddingConfig struct {
|
||||
|
||||
// PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。
|
||||
type PostRetrieveConfig struct {
|
||||
// PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。
|
||||
// PrefetchTopK 向量检索阶段每条 MultiQuery 变体最多保留的候选数;0 表示使用内置默认 max(top_k*4, 20)。
|
||||
PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"`
|
||||
// MaxContextChars 返回文档内容总 Unicode 字符数上限(整段 chunk,不截断半段);0 表示不限制。
|
||||
MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"`
|
||||
@@ -1581,13 +1586,62 @@ type PostRetrieveConfig struct {
|
||||
MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// MultiQueryConfig configures Eino MultiQuery query rewriting, which is always
// enabled (there is no switch to turn it off).
type MultiQueryConfig struct {
	// MaxQueries is the upper bound on retrieval variants generated by the LLM
	// (including semantic coverage of the original question); 0 means the default of 4.
	MaxQueries int `yaml:"max_queries,omitempty" json:"max_queries,omitempty"`
}

// MaxQueriesEffective returns the variant limit to use: 4 when MaxQueries is
// zero or negative, 8 when it exceeds 8, and MaxQueries itself otherwise.
func (c MultiQueryConfig) MaxQueriesEffective() int {
	switch n := c.MaxQueries; {
	case n <= 0:
		return 4
	case n > 8:
		return 8
	default:
		return n
	}
}
|
||||
|
||||
// RerankConfig configures retrieval reranking (always enabled); DashScope and
// Cohere-compatible HTTP APIs are supported.
type RerankConfig struct {
	// Provider is "dashscope" or "cohere"; when empty it is inferred from the base URL.
	Provider string `yaml:"provider,omitempty" json:"provider,omitempty"`
	Model    string `yaml:"model,omitempty" json:"model,omitempty"`
	BaseURL  string `yaml:"base_url,omitempty" json:"base_url,omitempty"`
	APIKey   string `yaml:"api_key,omitempty" json:"api_key,omitempty"`
}

// ProviderEffective resolves the rerank provider name. An explicitly
// configured Provider wins (lower-cased and trimmed). Otherwise the provider is
// "dashscope" when baseURL contains "dashscope" (case-insensitive), and
// "cohere" in every other case.
func (c RerankConfig) ProviderEffective(baseURL string) string {
	explicit := strings.TrimSpace(strings.ToLower(c.Provider))
	switch {
	case explicit != "":
		return explicit
	case strings.Contains(strings.ToLower(baseURL), "dashscope"):
		return "dashscope"
	default:
		return "cohere"
	}
}

// ModelEffective resolves the rerank model name. A non-blank configured Model
// wins (trimmed); otherwise the default is "gte-rerank" for the "dashscope"
// provider and "rerank-multilingual-v3.0" for any other provider.
func (c RerankConfig) ModelEffective(provider string) string {
	switch configured := strings.TrimSpace(c.Model); {
	case configured != "":
		return configured
	case provider == "dashscope":
		return "gte-rerank"
	default:
		return "rerank-multilingual-v3.0"
	}
}
|
||||
|
||||
// RetrievalConfig 检索配置
|
||||
type RetrievalConfig struct {
|
||||
TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K
|
||||
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值
|
||||
// SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。
|
||||
SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"`
|
||||
// PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 [knowledge.DocumentReranker]。
|
||||
MultiQuery MultiQueryConfig `yaml:"multi_query" json:"multi_query"`
|
||||
Rerank RerankConfig `yaml:"rerank" json:"rerank"`
|
||||
// PostRetrieve 检索后处理(去重、预算截断);精排在 MultiQuery 融合后执行。
|
||||
PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// knowledgePipelineRetriever: MultiQuery → vector candidates → rerank → post-process.
|
||||
type knowledgePipelineRetriever struct {
|
||||
inner retriever.Retriever
|
||||
base *Retriever
|
||||
}
|
||||
|
||||
func newKnowledgePipelineRetriever(inner retriever.Retriever, base *Retriever) *knowledgePipelineRetriever {
|
||||
if inner == nil || base == nil {
|
||||
return nil
|
||||
}
|
||||
return &knowledgePipelineRetriever{inner: inner, base: base}
|
||||
}
|
||||
|
||||
func (p *knowledgePipelineRetriever) GetType() string {
|
||||
return "KnowledgeRAGPipeline"
|
||||
}
|
||||
|
||||
func (p *knowledgePipelineRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
|
||||
if p == nil || p.inner == nil || p.base == nil {
|
||||
return nil, fmt.Errorf("knowledge pipeline retriever: nil")
|
||||
}
|
||||
q := strings.TrimSpace(query)
|
||||
if q == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
|
||||
ro := retriever.GetCommonOptions(nil, opts...)
|
||||
finalTopK := p.base.config.TopK
|
||||
if finalTopK <= 0 {
|
||||
finalTopK = 5
|
||||
}
|
||||
if ro.TopK != nil && *ro.TopK > 0 {
|
||||
finalTopK = *ro.TopK
|
||||
}
|
||||
|
||||
ctx = callbacks.EnsureRunInfo(ctx, p.GetType(), components.ComponentOfRetriever)
|
||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{Query: q, TopK: finalTopK, Extra: ro.DSLInfo})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
return
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
|
||||
}()
|
||||
|
||||
out, err = p.inner.Retrieve(ctx, q, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
if rr := p.base.documentReranker(); rr != nil && len(out) > 1 {
|
||||
reranked, rerr := rr.Rerank(ctx, q, out)
|
||||
if rerr != nil {
|
||||
if p.base.logger != nil {
|
||||
p.base.logger.Warn("知识检索重排失败,已使用融合序", zap.Error(rerr))
|
||||
}
|
||||
} else if len(reranked) > 0 {
|
||||
out = reranked
|
||||
}
|
||||
}
|
||||
|
||||
tokenModel := ""
|
||||
if p.base.embedder != nil {
|
||||
tokenModel = p.base.embedder.EmbeddingModelName()
|
||||
}
|
||||
var postPO *config.PostRetrieveConfig
|
||||
if p.base.config != nil {
|
||||
postPO = &p.base.config.PostRetrieve
|
||||
}
|
||||
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Compile-time check that *knowledgePipelineRetriever implements retriever.Retriever.
var _ retriever.Retriever = (*knowledgePipelineRetriever)(nil)
|
||||
@@ -8,8 +8,7 @@ import (
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。
|
||||
// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。
|
||||
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain(MultiQuery → 向量 → 重排 → 后处理)。
|
||||
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("retriever is nil")
|
||||
|
||||
@@ -11,19 +11,10 @@ import (
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
|
||||
//
|
||||
// Options:
|
||||
// - [retriever.WithTopK]
|
||||
// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string)
|
||||
//
|
||||
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
|
||||
//
|
||||
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
|
||||
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
|
||||
// It returns prefetch-sized vector candidates only; rerank and post-process run in [knowledgePipelineRetriever].
|
||||
type VectorEinoRetriever struct {
|
||||
inner *Retriever
|
||||
}
|
||||
@@ -119,26 +110,6 @@ func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts .
|
||||
return nil, err
|
||||
}
|
||||
out = retrievalResultsToDocuments(results)
|
||||
|
||||
if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
|
||||
reranked, rerr := rr.Rerank(ctx, q, out)
|
||||
if rerr != nil {
|
||||
if h.inner.logger != nil {
|
||||
h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
|
||||
}
|
||||
} else if len(reranked) > 0 {
|
||||
out = reranked
|
||||
}
|
||||
}
|
||||
|
||||
tokenModel := ""
|
||||
if h.inner.embedder != nil {
|
||||
tokenModel = h.inner.embedder.EmbeddingModelName()
|
||||
}
|
||||
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HTTPReranker calls a hosted rerank API (DashScope or Cohere-compatible).
|
||||
type HTTPReranker struct {
|
||||
provider string
|
||||
model string
|
||||
baseURL string
|
||||
apiKey string
|
||||
client *http.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewHTTPReranker builds a rerank client from knowledge retrieval config; openAI supplies fallback credentials.
|
||||
func NewHTTPReranker(rc *config.RerankConfig, openAI *config.OpenAIConfig, logger *zap.Logger) (*HTTPReranker, error) {
|
||||
if rc == nil {
|
||||
return nil, fmt.Errorf("rerank config is nil")
|
||||
}
|
||||
baseURL := strings.TrimSpace(rc.BaseURL)
|
||||
apiKey := strings.TrimSpace(rc.APIKey)
|
||||
if openAI != nil {
|
||||
if baseURL == "" {
|
||||
baseURL = strings.TrimSpace(openAI.BaseURL)
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = strings.TrimSpace(openAI.APIKey)
|
||||
}
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("rerank api_key is required")
|
||||
}
|
||||
provider := rc.ProviderEffective(baseURL)
|
||||
model := rc.ModelEffective(provider)
|
||||
return &HTTPReranker{
|
||||
provider: provider,
|
||||
model: model,
|
||||
baseURL: strings.TrimSuffix(baseURL, "/"),
|
||||
apiKey: apiKey,
|
||||
client: &http.Client{Timeout: 60 * time.Second},
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) {
|
||||
if r == nil {
|
||||
return docs, nil
|
||||
}
|
||||
q := strings.TrimSpace(query)
|
||||
if q == "" || len(docs) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
if len(docs) == 1 {
|
||||
return docs, nil
|
||||
}
|
||||
texts := make([]string, 0, len(docs))
|
||||
for _, d := range docs {
|
||||
if d == nil {
|
||||
texts = append(texts, "")
|
||||
continue
|
||||
}
|
||||
texts = append(texts, d.Content)
|
||||
}
|
||||
var order []int
|
||||
var err error
|
||||
switch r.provider {
|
||||
case "dashscope":
|
||||
order, err = r.rerankDashScope(ctx, q, texts, len(docs))
|
||||
default:
|
||||
order, err = r.rerankCohere(ctx, q, texts, len(docs))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]*schema.Document, 0, len(order))
|
||||
for _, idx := range order {
|
||||
if idx < 0 || idx >= len(docs) || docs[idx] == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, docs[idx])
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) rerankCohere(ctx context.Context, query string, documents []string, topN int) ([]int, error) {
|
||||
url := r.cohereRerankURL()
|
||||
body := map[string]any{
|
||||
"model": r.model,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": topN,
|
||||
}
|
||||
raw, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+r.apiKey)
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rerank request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("rerank http %d: %s", resp.StatusCode, truncateForRerankLog(string(respBody)))
|
||||
}
|
||||
var parsed struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
} `json:"results"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("rerank decode: %w", err)
|
||||
}
|
||||
order := make([]int, 0, len(parsed.Results))
|
||||
for _, row := range parsed.Results {
|
||||
order = append(order, row.Index)
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) rerankDashScope(ctx context.Context, query string, documents []string, topN int) ([]int, error) {
|
||||
url := r.dashscopeRerankURL()
|
||||
body := map[string]any{
|
||||
"model": r.model,
|
||||
"input": map[string]any{
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
"parameters": map[string]any{
|
||||
"return_documents": false,
|
||||
"top_n": topN,
|
||||
},
|
||||
}
|
||||
raw, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+r.apiKey)
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dashscope rerank: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("dashscope rerank http %d: %s", resp.StatusCode, truncateForRerankLog(string(respBody)))
|
||||
}
|
||||
var parsed struct {
|
||||
Output struct {
|
||||
Results []struct {
|
||||
Index int `json:"index"`
|
||||
} `json:"results"`
|
||||
} `json:"output"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("dashscope rerank decode: %w", err)
|
||||
}
|
||||
order := make([]int, 0, len(parsed.Output.Results))
|
||||
for _, row := range parsed.Output.Results {
|
||||
order = append(order, row.Index)
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) cohereRerankURL() string {
|
||||
base := r.baseURL
|
||||
if base == "" {
|
||||
base = "https://api.cohere.com"
|
||||
}
|
||||
if strings.HasSuffix(base, "/v1") {
|
||||
return base + "/rerank"
|
||||
}
|
||||
return base + "/v1/rerank"
|
||||
}
|
||||
|
||||
func (r *HTTPReranker) dashscopeRerankURL() string {
|
||||
base := strings.TrimSpace(r.baseURL)
|
||||
if base == "" {
|
||||
return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
|
||||
}
|
||||
if strings.Contains(base, "/api/v1/services/rerank") {
|
||||
return base
|
||||
}
|
||||
if strings.Contains(base, "dashscope.aliyuncs.com") || strings.Contains(base, "compatible-mode") {
|
||||
return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
|
||||
}
|
||||
return strings.TrimSuffix(base, "/")
|
||||
}
|
||||
|
||||
// truncateForRerankLog trims surrounding whitespace from s and limits it to at
// most 512 bytes for use in error messages, appending "..." when it was cut.
//
// The cut is moved back to the nearest rune boundary so a multi-byte UTF-8
// character is never split in half, which would leave invalid UTF-8 in the
// message.
func truncateForRerankLog(s string) string {
	const maxBytes = 512

	s = strings.TrimSpace(s)
	if len(s) <= maxBytes {
		return s
	}
	// Ranging over a string yields the byte offset at which each rune starts;
	// keep the last start offset that does not exceed the byte limit.
	cut := 0
	for i := range s {
		if i > maxBytes {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
|
||||
|
||||
// Compile-time check that *HTTPReranker implements DocumentReranker.
var _ DocumentReranker = (*HTTPReranker)(nil)
|
||||
@@ -0,0 +1,97 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestHTTPReranker_CohereOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/rerank" {
|
||||
t.Fatalf("path %s", r.URL.Path)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"index": 2, "relevance_score": 0.9},
|
||||
{"index": 0, "relevance_score": 0.5},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rr, err := NewHTTPReranker(&config.RerankConfig{
|
||||
Provider: "cohere",
|
||||
Model: "rerank-multilingual-v3.0",
|
||||
BaseURL: srv.URL,
|
||||
APIKey: "test-key",
|
||||
}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
docs := []*schema.Document{
|
||||
{ID: "a", Content: "alpha"},
|
||||
{ID: "b", Content: "beta"},
|
||||
{ID: "c", Content: "gamma"},
|
||||
}
|
||||
out, err := rr.Rerank(context.Background(), "query", docs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out) != 2 || out[0].ID != "c" || out[1].ID != "a" {
|
||||
t.Fatalf("order wrong: %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPReranker_DashScopeOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"output": map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"index": 1, "relevance_score": 0.88},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
rr, err := NewHTTPReranker(&config.RerankConfig{
|
||||
Provider: "dashscope",
|
||||
Model: "gte-rerank",
|
||||
BaseURL: srv.URL,
|
||||
APIKey: "test-key",
|
||||
}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
docs := []*schema.Document{{ID: "a", Content: "a"}, {ID: "b", Content: "b"}}
|
||||
out, err := rr.Rerank(context.Background(), "q", docs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out) != 1 || out[0].ID != "b" {
|
||||
t.Fatalf("got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRerankConfigDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
rc := config.RerankConfig{}
|
||||
if rc.ProviderEffective("https://dashscope.aliyuncs.com/x") != "dashscope" {
|
||||
t.Fatal("dashscope detect")
|
||||
}
|
||||
if rc.ModelEffective("dashscope") != "gte-rerank" {
|
||||
t.Fatal("dashscope model")
|
||||
}
|
||||
if rc.ModelEffective("cohere") != "rerank-multilingual-v3.0" {
|
||||
t.Fatal("cohere model")
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。
|
||||
const postRetrieveMaxPrefetchCap = 200
|
||||
|
||||
// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。
|
||||
// DocumentReranker 精排(HTTP dashscope / Cohere 兼容 API),由 [WireRetrieverPipeline] 注入。
|
||||
type DocumentReranker interface {
|
||||
Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
|
||||
}
|
||||
@@ -167,13 +167,16 @@ func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int,
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
|
||||
// EffectivePrefetchTopK 计算每条 MultiQuery 变体在向量阶段的候选条数(供融合 / 重排 / 后处理)。
|
||||
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
|
||||
if topK < 1 {
|
||||
topK = 5
|
||||
}
|
||||
fetch := topK
|
||||
if po != nil && po.PrefetchTopK > fetch {
|
||||
fetch := topK * 4
|
||||
if fetch < 20 {
|
||||
fetch = 20
|
||||
}
|
||||
if po != nil && po.PrefetchTopK > 0 {
|
||||
fetch = po.PrefetchTopK
|
||||
}
|
||||
if fetch > postRetrieveMaxPrefetchCap {
|
||||
@@ -182,7 +185,7 @@ func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
|
||||
return fetch
|
||||
}
|
||||
|
||||
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
|
||||
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK(精排已在流水线中完成)。
|
||||
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
|
||||
if finalTopK < 1 {
|
||||
finalTopK = 5
|
||||
|
||||
@@ -28,8 +28,8 @@ func TestDedupeByNormalizedContent(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEffectivePrefetchTopK(t *testing.T) {
|
||||
if g := EffectivePrefetchTopK(5, nil); g != 5 {
|
||||
t.Fatalf("got %d", g)
|
||||
if g := EffectivePrefetchTopK(5, nil); g != 20 {
|
||||
t.Fatalf("default prefetch got %d want 20", g)
|
||||
}
|
||||
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
|
||||
t.Fatalf("got %d", g)
|
||||
|
||||
@@ -27,15 +27,19 @@ type Retriever struct {
|
||||
|
||||
rerankMu sync.RWMutex
|
||||
reranker DocumentReranker
|
||||
|
||||
pipeline retriever.Retriever
|
||||
wireOpenAI *config.OpenAIConfig
|
||||
}
|
||||
|
||||
// RetrievalConfig 检索配置
|
||||
type RetrievalConfig struct {
|
||||
TopK int
|
||||
SimilarityThreshold float64
|
||||
// SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。
|
||||
SubIndexFilter string
|
||||
PostRetrieve config.PostRetrieveConfig
|
||||
SubIndexFilter string
|
||||
MultiQuery config.MultiQueryConfig
|
||||
Rerank config.RerankConfig
|
||||
PostRetrieve config.PostRetrieveConfig
|
||||
}
|
||||
|
||||
// NewRetriever 创建新的检索器
|
||||
@@ -48,7 +52,7 @@ func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logge
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 更新检索配置
|
||||
// UpdateConfig 更新检索配置并重建 Eino MultiQuery + 重排流水线。
|
||||
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
|
||||
if cfg != nil {
|
||||
r.config = cfg
|
||||
@@ -57,12 +61,18 @@ func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
|
||||
zap.Int("top_k", cfg.TopK),
|
||||
zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
|
||||
zap.String("sub_index_filter", cfg.SubIndexFilter),
|
||||
zap.Int("multi_query_max", cfg.MultiQuery.MaxQueriesEffective()),
|
||||
zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
|
||||
zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
|
||||
zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
|
||||
)
|
||||
}
|
||||
}
|
||||
if r.wireOpenAI != nil {
|
||||
if err := WireRetrieverPipeline(context.Background(), r, r.wireOpenAI); err != nil && r.logger != nil {
|
||||
r.logger.Warn("检索流水线重建失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。
|
||||
@@ -103,7 +113,7 @@ func cosineSimilarity(a, b []float32) float64 {
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。
|
||||
// Search 搜索知识库(Eino MultiQuery → 向量检索 → 重排 → 后处理)。
|
||||
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("请求不能为空")
|
||||
@@ -113,7 +123,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
opts := r.einoRetrieverOptions(req)
|
||||
docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
|
||||
docs, err := r.activeEinoRetriever().Retrieve(ctx, q, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -143,7 +153,19 @@ func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option
|
||||
|
||||
// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。
|
||||
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
|
||||
return r.activeEinoRetriever().Retrieve(ctx, query, opts...)
|
||||
}
|
||||
|
||||
func (r *Retriever) activeEinoRetriever() retriever.Retriever {
|
||||
if r != nil && r.pipeline != nil {
|
||||
return r.pipeline
|
||||
}
|
||||
return NewVectorEinoRetriever(r)
|
||||
}
|
||||
|
||||
// AsEinoRetriever 将知识库检索流水线暴露为 Eino [retriever.Retriever]。
|
||||
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
|
||||
return r.activeEinoRetriever()
|
||||
}
|
||||
|
||||
func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
|
||||
@@ -299,7 +321,14 @@ func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*Re
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。
|
||||
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
|
||||
return NewVectorEinoRetriever(r)
|
||||
// RetrievalConfigFromYAML maps API/YAML retrieval settings into the knowledge package.
|
||||
func RetrievalConfigFromYAML(r config.RetrievalConfig) *RetrievalConfig {
|
||||
return &RetrievalConfig{
|
||||
TopK: r.TopK,
|
||||
SimilarityThreshold: r.SimilarityThreshold,
|
||||
SubIndexFilter: r.SubIndexFilter,
|
||||
MultiQuery: r.MultiQuery,
|
||||
Rerank: r.Rerank,
|
||||
PostRetrieve: r.PostRetrieve,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/flow/retriever/multiquery"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WireRetrieverPipeline builds Eino MultiQuery + HTTP rerank + post-process pipeline on r.
|
||||
// Call once after NewRetriever; UpdateConfig re-invokes when wireOpenAI is set.
|
||||
func WireRetrieverPipeline(ctx context.Context, r *Retriever, openAI *config.OpenAIConfig) error {
|
||||
if r == nil {
|
||||
return fmt.Errorf("retriever is nil")
|
||||
}
|
||||
if openAI == nil {
|
||||
return fmt.Errorf("openai config is nil")
|
||||
}
|
||||
if r.config == nil {
|
||||
return fmt.Errorf("retrieval config is nil")
|
||||
}
|
||||
r.wireOpenAI = openAI
|
||||
|
||||
httpClient := openai.NewEinoHTTPClient(openAI, &http.Client{Timeout: 120 * time.Second})
|
||||
chatCfg := &einoopenai.ChatModelConfig{
|
||||
APIKey: strings.TrimSpace(openAI.APIKey),
|
||||
BaseURL: strings.TrimSuffix(strings.TrimSpace(openAI.BaseURL), "/"),
|
||||
Model: strings.TrimSpace(openAI.Model),
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
if chatCfg.Model == "" {
|
||||
chatCfg.Model = "gpt-4o"
|
||||
}
|
||||
rewriteLLM, err := einoopenai.NewChatModel(ctx, chatCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("multi_query rewrite model: %w", err)
|
||||
}
|
||||
|
||||
reranker, err := NewHTTPReranker(&r.config.Rerank, openAI, r.logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reranker: %w", err)
|
||||
}
|
||||
r.SetDocumentReranker(reranker)
|
||||
|
||||
vec := NewVectorEinoRetriever(r)
|
||||
mq, err := multiquery.NewRetriever(ctx, &multiquery.Config{
|
||||
RewriteLLM: rewriteLLM,
|
||||
MaxQueriesNum: r.config.MultiQuery.MaxQueriesEffective(),
|
||||
OrigRetriever: vec,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("multi_query: %w", err)
|
||||
}
|
||||
|
||||
r.pipeline = newKnowledgePipelineRetriever(mq, r)
|
||||
if r.logger != nil {
|
||||
provider := r.config.Rerank.ProviderEffective(strings.TrimSpace(openAI.BaseURL))
|
||||
r.logger.Info("知识库检索流水线已启用",
|
||||
zap.String("pipeline", "MultiQuery→Vector→Rerank→PostRetrieve"),
|
||||
zap.Int("multi_query_max", r.config.MultiQuery.MaxQueriesEffective()),
|
||||
zap.String("rerank_provider", provider),
|
||||
zap.String("rerank_model", r.config.Rerank.ModelEffective(provider)),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user