mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-02 18:55:52 +02:00
227 lines
5.7 KiB
Go
227 lines
5.7 KiB
Go
package knowledge
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"cyberstrike-ai/internal/config"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// HTTPReranker calls a hosted rerank API (DashScope or Cohere-compatible).
|
|
type HTTPReranker struct {
|
|
provider string
|
|
model string
|
|
baseURL string
|
|
apiKey string
|
|
client *http.Client
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewHTTPReranker builds a rerank client from knowledge retrieval config; openAI supplies fallback credentials.
|
|
func NewHTTPReranker(rc *config.RerankConfig, openAI *config.OpenAIConfig, logger *zap.Logger) (*HTTPReranker, error) {
|
|
if rc == nil {
|
|
return nil, fmt.Errorf("rerank config is nil")
|
|
}
|
|
baseURL := strings.TrimSpace(rc.BaseURL)
|
|
apiKey := strings.TrimSpace(rc.APIKey)
|
|
if openAI != nil {
|
|
if baseURL == "" {
|
|
baseURL = strings.TrimSpace(openAI.BaseURL)
|
|
}
|
|
if apiKey == "" {
|
|
apiKey = strings.TrimSpace(openAI.APIKey)
|
|
}
|
|
}
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("rerank api_key is required")
|
|
}
|
|
provider := rc.ProviderEffective(baseURL)
|
|
model := rc.ModelEffective(provider)
|
|
return &HTTPReranker{
|
|
provider: provider,
|
|
model: model,
|
|
baseURL: strings.TrimSuffix(baseURL, "/"),
|
|
apiKey: apiKey,
|
|
client: &http.Client{Timeout: 60 * time.Second},
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
func (r *HTTPReranker) Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) {
|
|
if r == nil {
|
|
return docs, nil
|
|
}
|
|
q := strings.TrimSpace(query)
|
|
if q == "" || len(docs) == 0 {
|
|
return docs, nil
|
|
}
|
|
if len(docs) == 1 {
|
|
return docs, nil
|
|
}
|
|
texts := make([]string, 0, len(docs))
|
|
for _, d := range docs {
|
|
if d == nil {
|
|
texts = append(texts, "")
|
|
continue
|
|
}
|
|
texts = append(texts, d.Content)
|
|
}
|
|
var order []int
|
|
var err error
|
|
switch r.provider {
|
|
case "dashscope":
|
|
order, err = r.rerankDashScope(ctx, q, texts, len(docs))
|
|
default:
|
|
order, err = r.rerankCohere(ctx, q, texts, len(docs))
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out := make([]*schema.Document, 0, len(order))
|
|
for _, idx := range order {
|
|
if idx < 0 || idx >= len(docs) || docs[idx] == nil {
|
|
continue
|
|
}
|
|
out = append(out, docs[idx])
|
|
}
|
|
if len(out) == 0 {
|
|
return docs, nil
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (r *HTTPReranker) rerankCohere(ctx context.Context, query string, documents []string, topN int) ([]int, error) {
|
|
url := r.cohereRerankURL()
|
|
body := map[string]any{
|
|
"model": r.model,
|
|
"query": query,
|
|
"documents": documents,
|
|
"top_n": topN,
|
|
}
|
|
raw, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+r.apiKey)
|
|
resp, err := r.client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("rerank request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("rerank http %d: %s", resp.StatusCode, truncateForRerankLog(string(respBody)))
|
|
}
|
|
var parsed struct {
|
|
Results []struct {
|
|
Index int `json:"index"`
|
|
} `json:"results"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
|
return nil, fmt.Errorf("rerank decode: %w", err)
|
|
}
|
|
order := make([]int, 0, len(parsed.Results))
|
|
for _, row := range parsed.Results {
|
|
order = append(order, row.Index)
|
|
}
|
|
return order, nil
|
|
}
|
|
|
|
func (r *HTTPReranker) rerankDashScope(ctx context.Context, query string, documents []string, topN int) ([]int, error) {
|
|
url := r.dashscopeRerankURL()
|
|
body := map[string]any{
|
|
"model": r.model,
|
|
"input": map[string]any{
|
|
"query": query,
|
|
"documents": documents,
|
|
},
|
|
"parameters": map[string]any{
|
|
"return_documents": false,
|
|
"top_n": topN,
|
|
},
|
|
}
|
|
raw, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+r.apiKey)
|
|
resp, err := r.client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dashscope rerank: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, fmt.Errorf("dashscope rerank http %d: %s", resp.StatusCode, truncateForRerankLog(string(respBody)))
|
|
}
|
|
var parsed struct {
|
|
Output struct {
|
|
Results []struct {
|
|
Index int `json:"index"`
|
|
} `json:"results"`
|
|
} `json:"output"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
|
return nil, fmt.Errorf("dashscope rerank decode: %w", err)
|
|
}
|
|
order := make([]int, 0, len(parsed.Output.Results))
|
|
for _, row := range parsed.Output.Results {
|
|
order = append(order, row.Index)
|
|
}
|
|
return order, nil
|
|
}
|
|
|
|
func (r *HTTPReranker) cohereRerankURL() string {
|
|
base := r.baseURL
|
|
if base == "" {
|
|
base = "https://api.cohere.com"
|
|
}
|
|
if strings.HasSuffix(base, "/v1") {
|
|
return base + "/rerank"
|
|
}
|
|
return base + "/v1/rerank"
|
|
}
|
|
|
|
func (r *HTTPReranker) dashscopeRerankURL() string {
|
|
base := strings.TrimSpace(r.baseURL)
|
|
if base == "" {
|
|
return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
|
|
}
|
|
if strings.Contains(base, "/api/v1/services/rerank") {
|
|
return base
|
|
}
|
|
if strings.Contains(base, "dashscope.aliyuncs.com") || strings.Contains(base, "compatible-mode") {
|
|
return "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
|
|
}
|
|
return strings.TrimSuffix(base, "/")
|
|
}
|
|
|
|
// truncateForRerankLog trims surrounding whitespace from s and caps it at 512
// bytes for use in error messages, appending "..." when anything was cut.
//
// The cut never splits a multi-byte UTF-8 character: when byte 512 falls in
// the middle of one, the cut moves back to where that character starts, so the
// result stays valid UTF-8 whenever the input was.
func truncateForRerankLog(s string) string {
	const maxLen = 512

	s = strings.TrimSpace(s)
	if len(s) <= maxLen {
		return s
	}
	// Ranging over a string yields the byte offset at which each character
	// starts; keep the last such offset that does not exceed the limit.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
|
|
|
|
// Compile-time assertion that *HTTPReranker implements DocumentReranker.
var _ DocumentReranker = (*HTTPReranker)(nil)
|