mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
Add files via upload
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -14,6 +12,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -21,7 +20,7 @@ import (
|
||||
|
||||
// Agent AI代理
|
||||
type Agent struct {
|
||||
openAIClient *http.Client
|
||||
openAIClient *openai.Client
|
||||
config *config.OpenAIConfig
|
||||
agentConfig *config.AgentConfig
|
||||
memoryCompressor *MemoryCompressor
|
||||
@@ -94,6 +93,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
|
||||
Timeout: 30 * time.Minute, // 从5分钟增加到30分钟
|
||||
Transport: transport,
|
||||
}
|
||||
llmClient := openai.NewClient(cfg, httpClient, logger)
|
||||
|
||||
var memoryCompressor *MemoryCompressor
|
||||
if cfg != nil {
|
||||
@@ -112,7 +112,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
|
||||
}
|
||||
|
||||
return &Agent{
|
||||
openAIClient: httpClient,
|
||||
openAIClient: llmClient,
|
||||
config: cfg,
|
||||
agentConfig: agentCfg,
|
||||
memoryCompressor: memoryCompressor,
|
||||
@@ -1016,95 +1016,17 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
|
||||
reqBody.Tools = tools
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录请求大小(用于诊断)
|
||||
requestSize := len(jsonData)
|
||||
a.logger.Debug("准备发送OpenAI请求",
|
||||
zap.Int("messagesCount", len(messages)),
|
||||
zap.Int("requestSizeKB", requestSize/1024),
|
||||
zap.Int("toolsCount", len(tools)),
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.config.APIKey)
|
||||
|
||||
// 记录请求开始时间
|
||||
requestStartTime := time.Now()
|
||||
resp, err := a.openAIClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 记录响应头接收时间
|
||||
headerReceiveTime := time.Now()
|
||||
headerReceiveDuration := headerReceiveTime.Sub(requestStartTime)
|
||||
|
||||
a.logger.Debug("收到OpenAI响应头",
|
||||
zap.Int("statusCode", resp.StatusCode),
|
||||
zap.Duration("headerReceiveDuration", headerReceiveDuration),
|
||||
zap.Int64("contentLength", resp.ContentLength),
|
||||
)
|
||||
|
||||
// 使用带超时的读取(通过context控制)
|
||||
bodyChan := make(chan []byte, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
bodyChan <- body
|
||||
}()
|
||||
|
||||
var body []byte
|
||||
select {
|
||||
case body = <-bodyChan:
|
||||
// 读取成功
|
||||
bodyReceiveTime := time.Now()
|
||||
bodyReceiveDuration := bodyReceiveTime.Sub(headerReceiveTime)
|
||||
totalDuration := bodyReceiveTime.Sub(requestStartTime)
|
||||
|
||||
a.logger.Debug("完成读取OpenAI响应体",
|
||||
zap.Int("bodySizeKB", len(body)/1024),
|
||||
zap.Duration("bodyReceiveDuration", bodyReceiveDuration),
|
||||
zap.Duration("totalDuration", totalDuration),
|
||||
)
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("读取响应体超时: %w", ctx.Err())
|
||||
case <-time.After(25 * time.Minute):
|
||||
// 额外的安全超时:25分钟(小于30分钟的总超时)
|
||||
return nil, fmt.Errorf("读取响应体超时(超过25分钟)")
|
||||
}
|
||||
|
||||
// 记录响应内容(用于调试)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
a.logger.Warn("OpenAI API返回非200状态码",
|
||||
zap.Int("status", resp.StatusCode),
|
||||
zap.String("body", string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
var response OpenAIResponse
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
a.logger.Error("解析OpenAI响应失败",
|
||||
zap.Error(err),
|
||||
zap.String("body", string(body)),
|
||||
)
|
||||
return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body))
|
||||
if a.openAIClient == nil {
|
||||
return nil, fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
if err := a.openAIClient.ChatCompletion(ctx, reqBody, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"go.uber.org/zap"
|
||||
@@ -143,15 +141,15 @@ func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 更新summaryModel字段
|
||||
if cfg.Model != "" {
|
||||
mc.summaryModel = cfg.Model
|
||||
}
|
||||
|
||||
|
||||
// 更新completionClient中的配置(如果是OpenAICompletionClient)
|
||||
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
|
||||
openAIClient.config = cfg
|
||||
openAIClient.UpdateConfig(cfg)
|
||||
mc.logger.Info("MemoryCompressor配置已更新",
|
||||
zap.String("model", cfg.Model),
|
||||
)
|
||||
@@ -410,9 +408,9 @@ type CompletionClient interface {
|
||||
|
||||
// OpenAICompletionClient 基于 OpenAI Chat Completion。
|
||||
type OpenAICompletionClient struct {
|
||||
config *config.OpenAIConfig
|
||||
httpClient *http.Client
|
||||
logger *zap.Logger
|
||||
config *config.OpenAIConfig
|
||||
client *openai.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewOpenAICompletionClient 创建 OpenAICompletionClient。
|
||||
@@ -421,9 +419,17 @@ func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, lo
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &OpenAICompletionClient{
|
||||
config: cfg,
|
||||
httpClient: client,
|
||||
logger: logger,
|
||||
config: cfg,
|
||||
client: openai.NewClient(cfg, client, logger),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 更新底层配置。
|
||||
func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
c.config = cfg
|
||||
if c.client != nil {
|
||||
c.client.UpdateConfig(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -443,11 +449,6 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestCtx := ctx
|
||||
var cancel context.CancelFunc
|
||||
if timeout > 0 {
|
||||
@@ -455,57 +456,14 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// 处理 BaseURL 路径拼接:移除末尾斜杠,然后添加 /chat/completions
|
||||
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
||||
url := baseURL + "/chat/completions"
|
||||
|
||||
c.logger.Debug("calling completion API",
|
||||
zap.String("url", url),
|
||||
zap.String("model", model),
|
||||
zap.Int("prompt_length", len(prompt)))
|
||||
|
||||
req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// 读取响应体以获取更详细的错误信息
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
errorDetail := resp.Status
|
||||
if readErr == nil && len(bodyBytes) > 0 {
|
||||
// 尝试解析错误响应
|
||||
var errorResp struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &errorResp) == nil && errorResp.Error.Message != "" {
|
||||
errorDetail = fmt.Sprintf("%s: %s", resp.Status, errorResp.Error.Message)
|
||||
} else {
|
||||
errorDetail = fmt.Sprintf("%s: %s", resp.Status, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
c.logger.Warn("completion API request failed",
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.String("url", url),
|
||||
zap.String("model", model),
|
||||
zap.String("error", errorDetail))
|
||||
return "", fmt.Errorf("openai completion failed, status: %s", errorDetail)
|
||||
}
|
||||
|
||||
var completion OpenAIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil {
|
||||
if c.client == nil {
|
||||
return "", errors.New("openai completion client not initialized")
|
||||
}
|
||||
if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil {
|
||||
if apiErr, ok := err.(*openai.APIError); ok {
|
||||
return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if completion.Error != nil {
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package attackchain
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -15,6 +14,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
type Builder struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
openAIClient *http.Client
|
||||
openAIClient *openai.Client
|
||||
openAIConfig *config.OpenAIConfig
|
||||
tokenCounter agent.TokenCounter
|
||||
maxTokens int // 最大tokens限制,默认100000
|
||||
@@ -49,6 +49,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
httpClient := &http.Client{Timeout: 5 * time.Minute, Transport: transport}
|
||||
|
||||
maxTokens := 100000 // 默认100k tokens,可以根据模型调整
|
||||
// 根据模型设置合理的默认值
|
||||
@@ -66,7 +67,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
||||
return &Builder{
|
||||
db: db,
|
||||
logger: logger,
|
||||
openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport},
|
||||
openAIClient: openai.NewClient(openAIConfig, httpClient, logger),
|
||||
openAIConfig: openAIConfig,
|
||||
tokenCounter: agent.NewTikTokenCounter(),
|
||||
maxTokens: maxTokens,
|
||||
@@ -962,30 +963,6 @@ func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk
|
||||
"max_tokens": 4000, // 增加摘要长度,以容纳更详细的内容
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||||
|
||||
resp, err := b.openAIClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
@@ -994,8 +971,11 @@ func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
if b.openAIClient == nil {
|
||||
return "", fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
@@ -1076,30 +1056,6 @@ AI回复:
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||||
|
||||
resp, err := b.openAIClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
@@ -1108,8 +1064,11 @@ AI回复:
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
if b.openAIClient == nil {
|
||||
return "", fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
@@ -1277,39 +1236,6 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
|
||||
"max_tokens": 8000,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
|
||||
|
||||
resp, err := b.openAIClient.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是上下文过长错误
|
||||
if strings.Contains(err.Error(), "context") || strings.Contains(err.Error(), "length") {
|
||||
return "", fmt.Errorf("context length exceeded")
|
||||
}
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
bodyStr := string(body)
|
||||
// 检查是否是上下文过长错误
|
||||
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
|
||||
return "", fmt.Errorf("context length exceeded")
|
||||
}
|
||||
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, bodyStr)
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
@@ -1318,8 +1244,20 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
if b.openAIClient == nil {
|
||||
return "", fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
|
||||
var apiErr *openai.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
bodyStr := strings.ToLower(apiErr.Body)
|
||||
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
|
||||
return "", fmt.Errorf("context length exceeded")
|
||||
}
|
||||
} else if strings.Contains(strings.ToLower(err.Error()), "context") || strings.Contains(strings.ToLower(err.Error()), "length") {
|
||||
return "", fmt.Errorf("context length exceeded")
|
||||
}
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
|
||||
144
internal/openai/openai.go
Normal file
144
internal/openai/openai.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Client 统一封装与OpenAI兼容模型交互的HTTP客户端。
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
config *config.OpenAIConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// APIError 表示OpenAI接口返回的非200错误。
|
||||
type APIError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
return fmt.Sprintf("openai api error: status=%d body=%s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// NewClient 创建一个新的OpenAI客户端。
|
||||
func NewClient(cfg *config.OpenAIConfig, httpClient *http.Client, logger *zap.Logger) *Client {
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 动态更新OpenAI配置。
|
||||
func (c *Client) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
c.config = cfg
|
||||
}
|
||||
|
||||
// ChatCompletion 调用 /chat/completions 接口。
|
||||
func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out interface{}) error {
|
||||
if c == nil {
|
||||
return fmt.Errorf("openai client is not initialized")
|
||||
}
|
||||
if c.config == nil {
|
||||
return fmt.Errorf("openai config is nil")
|
||||
}
|
||||
if strings.TrimSpace(c.config.APIKey) == "" {
|
||||
return fmt.Errorf("openai api key is empty")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal openai payload: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Debug("sending OpenAI chat completion request",
|
||||
zap.Int("payloadSizeKB", len(body)/1024))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build openai request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
||||
|
||||
requestStart := time.Now()
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("call openai api: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyChan := make(chan []byte, 1)
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
bodyChan <- responseBody
|
||||
}()
|
||||
|
||||
var respBody []byte
|
||||
select {
|
||||
case respBody = <-bodyChan:
|
||||
case err := <-errChan:
|
||||
return fmt.Errorf("read openai response: %w", err)
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("read openai response timeout: %w", ctx.Err())
|
||||
case <-time.After(25 * time.Minute):
|
||||
return fmt.Errorf("read openai response timeout (25m)")
|
||||
}
|
||||
|
||||
c.logger.Debug("received OpenAI response",
|
||||
zap.Int("status", resp.StatusCode),
|
||||
zap.Duration("duration", time.Since(requestStart)),
|
||||
zap.Int("responseSizeKB", len(respBody)/1024),
|
||||
)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
c.logger.Warn("OpenAI chat completion returned non-200",
|
||||
zap.Int("status", resp.StatusCode),
|
||||
zap.String("body", string(respBody)),
|
||||
)
|
||||
return &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
}
|
||||
}
|
||||
|
||||
if out != nil {
|
||||
if err := json.Unmarshal(respBody, out); err != nil {
|
||||
c.logger.Error("failed to unmarshal OpenAI response",
|
||||
zap.Error(err),
|
||||
zap.String("body", string(respBody)),
|
||||
)
|
||||
return fmt.Errorf("unmarshal openai response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user