mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
Update memory_compressor.go
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -105,6 +106,9 @@ func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error)
|
||||
if cfg.SummaryModel == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" {
|
||||
cfg.SummaryModel = cfg.OpenAIConfig.Model
|
||||
}
|
||||
if cfg.SummaryModel == "" {
|
||||
return nil, errors.New("summary model is required (either SummaryModel or OpenAIConfig.Model must be set)")
|
||||
}
|
||||
if cfg.TokenCounter == nil {
|
||||
cfg.TokenCounter = NewTikTokenCounter()
|
||||
}
|
||||
@@ -393,6 +397,9 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro
|
||||
if c.config == nil {
|
||||
return "", errors.New("openai config is required")
|
||||
}
|
||||
if model == "" {
|
||||
return "", errors.New("model name is required")
|
||||
}
|
||||
|
||||
reqBody := OpenAIRequest{
|
||||
Model: model,
|
||||
@@ -413,7 +420,16 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, c.config.BaseURL+"/chat/completions", bytes.NewReader(body))
|
||||
// 处理 BaseURL 路径拼接:移除末尾斜杠,然后添加 /chat/completions
|
||||
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
||||
url := baseURL + "/chat/completions"
|
||||
|
||||
c.logger.Debug("calling completion API",
|
||||
zap.String("url", url),
|
||||
zap.String("model", model),
|
||||
zap.Int("prompt_length", len(prompt)))
|
||||
|
||||
req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -427,7 +443,30 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("openai completion failed, status: %s", resp.Status)
|
||||
// 读取响应体以获取更详细的错误信息
|
||||
bodyBytes, readErr := io.ReadAll(resp.Body)
|
||||
errorDetail := resp.Status
|
||||
if readErr == nil && len(bodyBytes) > 0 {
|
||||
// 尝试解析错误响应
|
||||
var errorResp struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(bodyBytes, &errorResp) == nil && errorResp.Error.Message != "" {
|
||||
errorDetail = fmt.Sprintf("%s: %s", resp.Status, errorResp.Error.Message)
|
||||
} else {
|
||||
errorDetail = fmt.Sprintf("%s: %s", resp.Status, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
c.logger.Warn("completion API request failed",
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.String("url", url),
|
||||
zap.String("model", model),
|
||||
zap.String("error", errorDetail))
|
||||
return "", fmt.Errorf("openai completion failed, status: %s", errorDetail)
|
||||
}
|
||||
|
||||
var completion OpenAIResponse
|
||||
|
||||
Reference in New Issue
Block a user