mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
324 lines
8.9 KiB
Go
324 lines
8.9 KiB
Go
package knowledge
|
||
|
||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"cyberstrike-ai/internal/config"
	"cyberstrike-ai/internal/openai"

	"go.uber.org/zap"
	"golang.org/x/time/rate"
)
|
||
|
||
// Embedder 文本嵌入器
|
||
type Embedder struct {
|
||
openAIClient *openai.Client
|
||
config *config.KnowledgeConfig
|
||
openAIConfig *config.OpenAIConfig // 用于获取 API Key
|
||
logger *zap.Logger
|
||
rateLimiter *rate.Limiter // 速率限制器
|
||
rateLimitDelay time.Duration // 请求间隔时间
|
||
maxRetries int // 最大重试次数
|
||
retryDelay time.Duration // 重试间隔
|
||
mu sync.Mutex // 保护 rateLimiter
|
||
}
|
||
|
||
// NewEmbedder 创建新的嵌入器
|
||
func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder {
|
||
// 初始化速率限制器
|
||
var rateLimiter *rate.Limiter
|
||
var rateLimitDelay time.Duration
|
||
|
||
// 如果配置了 MaxRPM,根据 RPM 计算速率限制
|
||
if cfg.Indexing.MaxRPM > 0 {
|
||
rpm := cfg.Indexing.MaxRPM
|
||
rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
|
||
logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
|
||
} else if cfg.Indexing.RateLimitDelayMs > 0 {
|
||
// 如果没有配置 MaxRPM 但配置了固定延迟,使用固定延迟模式
|
||
rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
|
||
logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
|
||
}
|
||
|
||
// 重试配置
|
||
maxRetries := 3
|
||
retryDelay := 1000 * time.Millisecond
|
||
if cfg.Indexing.MaxRetries > 0 {
|
||
maxRetries = cfg.Indexing.MaxRetries
|
||
}
|
||
if cfg.Indexing.RetryDelayMs > 0 {
|
||
retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
|
||
}
|
||
|
||
return &Embedder{
|
||
openAIClient: openAIClient,
|
||
config: cfg,
|
||
openAIConfig: openAIConfig,
|
||
logger: logger,
|
||
rateLimiter: rateLimiter,
|
||
rateLimitDelay: rateLimitDelay,
|
||
maxRetries: maxRetries,
|
||
retryDelay: retryDelay,
|
||
}
|
||
}
|
||
|
||
// EmbeddingRequest is the JSON body sent to the embeddings endpoint.
type EmbeddingRequest struct {
	Model string   `json:"model"` // embedding model name
	Input []string `json:"input"` // texts to embed
}
|
||
|
||
// EmbeddingResponse OpenAI 嵌入响应
|
||
type EmbeddingResponse struct {
|
||
Data []EmbeddingData `json:"data"`
|
||
Error *EmbeddingError `json:"error,omitempty"`
|
||
}
|
||
|
||
// EmbeddingData is one element of the "data" array in an embeddings response.
type EmbeddingData struct {
	Embedding []float64 `json:"embedding"` // vector components as decoded from JSON
	Index     int       `json:"index"`     // presumably the position of the matching input; confirm against the API
}
|
||
|
||
// EmbeddingError is the "error" object of an embeddings response.
type EmbeddingError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
}
|
||
|
||
// waitRateLimiter 等待速率限制器
|
||
func (e *Embedder) waitRateLimiter() {
|
||
e.mu.Lock()
|
||
defer e.mu.Unlock()
|
||
|
||
if e.rateLimiter != nil {
|
||
// 等待令牌
|
||
ctx := context.Background()
|
||
if err := e.rateLimiter.Wait(ctx); err != nil {
|
||
e.logger.Warn("速率限制器等待失败", zap.Error(err))
|
||
}
|
||
}
|
||
|
||
if e.rateLimitDelay > 0 {
|
||
time.Sleep(e.rateLimitDelay)
|
||
}
|
||
}
|
||
|
||
// EmbedText 对文本进行嵌入(带重试和速率限制)
|
||
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
|
||
if e.openAIClient == nil {
|
||
return nil, fmt.Errorf("OpenAI 客户端未初始化")
|
||
}
|
||
|
||
var lastErr error
|
||
for attempt := 0; attempt < e.maxRetries; attempt++ {
|
||
// 速率限制
|
||
if attempt > 0 {
|
||
// 重试时等待更长时间
|
||
waitTime := e.retryDelay * time.Duration(attempt)
|
||
e.logger.Debug("重试前等待", zap.Int("attempt", attempt+1), zap.Duration("waitTime", waitTime))
|
||
select {
|
||
case <-ctx.Done():
|
||
return nil, ctx.Err()
|
||
case <-time.After(waitTime):
|
||
}
|
||
} else {
|
||
e.waitRateLimiter()
|
||
}
|
||
|
||
result, err := e.doEmbedText(ctx, text)
|
||
if err == nil {
|
||
return result, nil
|
||
}
|
||
|
||
lastErr = err
|
||
|
||
// 检查是否是可重试的错误(429 速率限制、5xx 服务器错误、网络错误)
|
||
if !e.isRetryableError(err) {
|
||
return nil, err
|
||
}
|
||
|
||
e.logger.Debug("嵌入请求失败,准备重试",
|
||
zap.Int("attempt", attempt+1),
|
||
zap.Int("maxRetries", e.maxRetries),
|
||
zap.Error(err))
|
||
}
|
||
|
||
return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
|
||
}
|
||
|
||
// doEmbedText 执行实际的嵌入请求(内部方法)
|
||
func (e *Embedder) doEmbedText(ctx context.Context, text string) ([]float32, error) {
|
||
// 使用配置的嵌入模型
|
||
model := e.config.Embedding.Model
|
||
if model == "" {
|
||
model = "text-embedding-3-small"
|
||
}
|
||
|
||
req := EmbeddingRequest{
|
||
Model: model,
|
||
Input: []string{text},
|
||
}
|
||
|
||
// 清理 baseURL:去除前后空格和尾部斜杠
|
||
baseURL := strings.TrimSpace(e.config.Embedding.BaseURL)
|
||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||
if baseURL == "" {
|
||
baseURL = "https://api.openai.com/v1"
|
||
}
|
||
|
||
// 构建请求
|
||
body, err := json.Marshal(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败:%w", err)
|
||
}
|
||
|
||
requestURL := baseURL + "/embeddings"
|
||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(string(body)))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建请求失败:%w", err)
|
||
}
|
||
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
|
||
// 使用配置的 API Key,如果没有则使用 OpenAI 配置的
|
||
apiKey := strings.TrimSpace(e.config.Embedding.APIKey)
|
||
if apiKey == "" && e.openAIConfig != nil {
|
||
apiKey = e.openAIConfig.APIKey
|
||
}
|
||
if apiKey == "" {
|
||
return nil, fmt.Errorf("API Key 未配置")
|
||
}
|
||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
// 发送请求
|
||
httpClient := &http.Client{
|
||
Timeout: 30 * time.Second,
|
||
}
|
||
resp, err := httpClient.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("发送请求失败:%w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应体以便在错误时输出详细信息
|
||
bodyBytes := make([]byte, 0)
|
||
buf := make([]byte, 4096)
|
||
for {
|
||
n, err := resp.Body.Read(buf)
|
||
if n > 0 {
|
||
bodyBytes = append(bodyBytes, buf[:n]...)
|
||
}
|
||
if err != nil {
|
||
break
|
||
}
|
||
}
|
||
|
||
// 记录请求和响应信息(用于调试)
|
||
requestBodyPreview := string(body)
|
||
if len(requestBodyPreview) > 200 {
|
||
requestBodyPreview = requestBodyPreview[:200] + "..."
|
||
}
|
||
e.logger.Debug("嵌入 API 请求",
|
||
zap.String("url", httpReq.URL.String()),
|
||
zap.String("model", model),
|
||
zap.String("requestBody", requestBodyPreview),
|
||
zap.Int("status", resp.StatusCode),
|
||
zap.Int("bodySize", len(bodyBytes)),
|
||
zap.String("contentType", resp.Header.Get("Content-Type")),
|
||
)
|
||
|
||
var embeddingResp EmbeddingResponse
|
||
if err := json.Unmarshal(bodyBytes, &embeddingResp); err != nil {
|
||
// 输出详细的错误信息
|
||
bodyPreview := string(bodyBytes)
|
||
if len(bodyPreview) > 500 {
|
||
bodyPreview = bodyPreview[:500] + "..."
|
||
}
|
||
return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码:%d, 响应长度:%d字节): %w\n请求体:%s\n响应内容预览:%s",
|
||
requestURL, resp.StatusCode, len(bodyBytes), err, requestBodyPreview, bodyPreview)
|
||
}
|
||
|
||
if embeddingResp.Error != nil {
|
||
return nil, fmt.Errorf("OpenAI API 错误 (状态码:%d): 类型=%s, 消息=%s",
|
||
resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
bodyPreview := string(bodyBytes)
|
||
if len(bodyPreview) > 500 {
|
||
bodyPreview = bodyPreview[:500] + "..."
|
||
}
|
||
return nil, fmt.Errorf("HTTP 请求失败 (URL: %s, 状态码:%d): 响应内容=%s", requestURL, resp.StatusCode, bodyPreview)
|
||
}
|
||
|
||
if len(embeddingResp.Data) == 0 {
|
||
bodyPreview := string(bodyBytes)
|
||
if len(bodyPreview) > 500 {
|
||
bodyPreview = bodyPreview[:500] + "..."
|
||
}
|
||
return nil, fmt.Errorf("未收到嵌入数据 (状态码:%d, 响应长度:%d字节)\n响应内容:%s",
|
||
resp.StatusCode, len(bodyBytes), bodyPreview)
|
||
}
|
||
|
||
// 转换为 float32
|
||
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
|
||
for i, v := range embeddingResp.Data[0].Embedding {
|
||
embedding[i] = float32(v)
|
||
}
|
||
|
||
return embedding, nil
|
||
}
|
||
|
||
// isRetryableError 判断是否是可重试的错误
|
||
func (e *Embedder) isRetryableError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
|
||
errStr := err.Error()
|
||
|
||
// 429 速率限制错误
|
||
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
|
||
return true
|
||
}
|
||
|
||
// 5xx 服务器错误
|
||
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
|
||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
|
||
return true
|
||
}
|
||
|
||
// 网络错误
|
||
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
|
||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// EmbedTexts 批量嵌入文本
|
||
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
|
||
if len(texts) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
embeddings := make([][]float32, len(texts))
|
||
for i, text := range texts {
|
||
embedding, err := e.EmbedText(ctx, text)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("嵌入文本 [%d] 失败:%w", i, err)
|
||
}
|
||
embeddings[i] = embedding
|
||
}
|
||
|
||
return embeddings, nil
|
||
}
|