Files
CyberStrikeAI/internal/knowledge/embedder.go
2025-12-20 17:36:40 +08:00

206 lines
5.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package knowledge
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"

	"cyberstrike-ai/internal/config"
	"cyberstrike-ai/internal/openai"
	"go.uber.org/zap"
)
// Embedder turns text into embedding vectors by calling an OpenAI-compatible
// embeddings endpoint. Construct it with NewEmbedder.
type Embedder struct {
	// openAIClient must be non-nil for EmbedText to proceed; in the code visible
	// here it is only nil-checked, the HTTP call itself does not go through it.
	openAIClient *openai.Client
	// config supplies the embedding model, base URL and embedding-specific API key.
	config *config.KnowledgeConfig
	openAIConfig *config.OpenAIConfig // fallback source for the API key when config has none
	// logger receives a debug entry for every embedding request.
	logger *zap.Logger
}
// NewEmbedder builds an Embedder from the supplied dependencies.
//
// The arguments are stored as given; nothing is validated or copied here.
// openAIConfig may be nil, in which case EmbedText has no fallback API key.
func NewEmbedder(cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, openAIClient *openai.Client, logger *zap.Logger) *Embedder {
	embedder := new(Embedder)
	embedder.config = cfg
	embedder.openAIConfig = openAIConfig
	embedder.openAIClient = openAIClient
	embedder.logger = logger
	return embedder
}
// EmbeddingRequest is the JSON payload posted to the embeddings endpoint.
type EmbeddingRequest struct {
	// Model is the name of the embedding model to use.
	Model string `json:"model"`
	// Input holds the texts to embed; EmbedText always sends exactly one.
	Input []string `json:"input"`
}
// EmbeddingResponse is the JSON document decoded from the embeddings endpoint.
type EmbeddingResponse struct {
	// Data holds the returned embeddings; EmbedText uses the first entry.
	Data []EmbeddingData `json:"data"`
	// Error is non-nil when the response JSON carries an "error" object.
	Error *EmbeddingError `json:"error,omitempty"`
}
// EmbeddingData is one embedding entry of an EmbeddingResponse.
type EmbeddingData struct {
	// Embedding is the vector as decoded from JSON; EmbedText narrows it to float32.
	Embedding []float64 `json:"embedding"`
	// Index is decoded but not read in this file; presumably the position of
	// the matching input text — confirm against the API documentation.
	Index int `json:"index"`
}
// EmbeddingError is the error object an embeddings response may carry.
// EmbedText includes both fields in the error it returns.
type EmbeddingError struct {
	// Message is the error text reported by the API.
	Message string `json:"message"`
	// Type is the error category string reported by the API.
	Type string `json:"type"`
}
// EmbedText requests an embedding vector for a single piece of text from an
// OpenAI-compatible "/embeddings" endpoint.
//
// The model comes from e.config.Embedding.Model (default
// "text-embedding-3-small") and the endpoint from e.config.Embedding.BaseURL
// (default "https://api.openai.com/v1"). The API key is
// e.config.Embedding.APIKey, falling back to e.openAIConfig.APIKey when that is
// blank. The request is bound to ctx and to a 30-second client timeout.
//
// It returns the first embedding of the response converted to float32. An
// error is returned when the client or API key is missing, the request cannot
// be built or sent, the response body cannot be read or parsed, the response
// contains an error object, the status is not 200, or no data was returned.
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
	if e.openAIClient == nil {
		return nil, fmt.Errorf("OpenAI客户端未初始化")
	}

	// preview shortens s to at most limit bytes for logs and error messages.
	// The cut is moved back to the start of a UTF-8 sequence so that multi-byte
	// characters are never split in half.
	preview := func(s string, limit int) string {
		if len(s) <= limit {
			return s
		}
		cut := limit
		for cut > 0 && !utf8.RuneStart(s[cut]) {
			cut--
		}
		return s[:cut] + "..."
	}

	// Use the configured embedding model, or the default when none is set.
	model := e.config.Embedding.Model
	if model == "" {
		model = "text-embedding-3-small"
	}

	// Normalise the base URL: strip surrounding whitespace and one trailing slash.
	baseURL := strings.TrimSuffix(strings.TrimSpace(e.config.Embedding.BaseURL), "/")
	if baseURL == "" {
		baseURL = "https://api.openai.com/v1"
	}

	// Build the request.
	body, err := json.Marshal(EmbeddingRequest{
		Model: model,
		Input: []string{text},
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	requestURL := baseURL + "/embeddings"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	// Prefer the embedding-specific API key and fall back to the general OpenAI
	// one. Both are trimmed so stray whitespace never ends up in the header.
	apiKey := strings.TrimSpace(e.config.Embedding.APIKey)
	if apiKey == "" && e.openAIConfig != nil {
		apiKey = strings.TrimSpace(e.openAIConfig.APIKey)
	}
	if apiKey == "" {
		return nil, fmt.Errorf("API Key未配置")
	}
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)

	// Send the request.
	httpClient := &http.Client{
		Timeout: 30 * time.Second,
	}
	resp, err := httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}
	defer resp.Body.Close()

	// Read the whole body up front so it can be quoted in error messages. A
	// read failure is reported as such rather than parsed as a partial body.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败 (URL: %s, 状态码: %d): %w", requestURL, resp.StatusCode, err)
	}

	// Log request and response metadata for debugging.
	requestBodyPreview := preview(string(body), 200)
	e.logger.Debug("嵌入API请求",
		zap.String("url", httpReq.URL.String()),
		zap.String("model", model),
		zap.String("requestBody", requestBodyPreview),
		zap.Int("status", resp.StatusCode),
		zap.Int("bodySize", len(respBody)),
		zap.String("contentType", resp.Header.Get("Content-Type")),
	)

	var embeddingResp EmbeddingResponse
	if err := json.Unmarshal(respBody, &embeddingResp); err != nil {
		// Include enough detail to diagnose non-JSON responses.
		return nil, fmt.Errorf("解析响应失败 (URL: %s, 状态码: %d, 响应长度: %d字节): %w\n请求体: %s\n响应内容预览: %s",
			requestURL, resp.StatusCode, len(respBody), err, requestBodyPreview, preview(string(respBody), 500))
	}

	// An error object in the body takes precedence over the bare status code
	// because it carries the more specific message.
	if embeddingResp.Error != nil {
		return nil, fmt.Errorf("OpenAI API错误 (状态码: %d): 类型=%s, 消息=%s",
			resp.StatusCode, embeddingResp.Error.Type, embeddingResp.Error.Message)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP请求失败 (URL: %s, 状态码: %d): 响应内容=%s", requestURL, resp.StatusCode, preview(string(respBody), 500))
	}

	if len(embeddingResp.Data) == 0 {
		return nil, fmt.Errorf("未收到嵌入数据 (状态码: %d, 响应长度: %d字节)\n响应内容: %s",
			resp.StatusCode, len(respBody), preview(string(respBody), 500))
	}

	// Narrow the decoded float64 values to float32.
	source := embeddingResp.Data[0].Embedding
	embedding := make([]float32, len(source))
	for i, v := range source {
		embedding[i] = float32(v)
	}
	return embedding, nil
}
// EmbedTexts embeds every entry of texts, in order, by calling EmbedText once
// per entry.
//
// An empty or nil texts yields (nil, nil). The first failure stops the loop
// and is returned wrapped with the index of the offending text; no partial
// result is returned in that case.
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, nil
	}

	// The API supports batched input, but for simplicity each text is sent in
	// its own request; switching to the batch form would be more efficient.
	vectors := make([][]float32, 0, len(texts))
	for idx := range texts {
		vector, err := e.EmbedText(ctx, texts[idx])
		if err != nil {
			return nil, fmt.Errorf("嵌入文本[%d]失败: %w", idx, err)
		}
		vectors = append(vectors, vector)
	}
	return vectors, nil
}