mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 16:20:28 +02:00
Update agent.go
This commit is contained in:
@@ -35,14 +35,14 @@ func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logge
|
||||
// 配置HTTP Transport,优化连接管理和超时设置
|
||||
transport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Timeout: 300 * time.Second,
|
||||
KeepAlive: 300 * time.Second,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 5 * time.Minute, // 响应头超时
|
||||
TLSHandshakeTimeout: 30 * time.Second,
|
||||
ResponseHeaderTimeout: 60 * time.Minute, // 响应头超时:增加到15分钟,应对大响应
|
||||
DisableKeepAlives: false, // 启用连接复用
|
||||
}
|
||||
|
||||
@@ -666,8 +666,89 @@ func (a *Agent) convertToOpenAIType(configType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// callOpenAI 调用OpenAI API
|
||||
// isRetryableError 判断错误是否可重试
|
||||
func (a *Agent) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := err.Error()
|
||||
// 网络相关错误,可以重试
|
||||
retryableErrors := []string{
|
||||
"connection reset",
|
||||
"connection reset by peer",
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"i/o timeout",
|
||||
"context deadline exceeded",
|
||||
"no such host",
|
||||
"network is unreachable",
|
||||
"broken pipe",
|
||||
"EOF",
|
||||
"read tcp",
|
||||
"write tcp",
|
||||
"dial tcp",
|
||||
}
|
||||
for _, retryable := range retryableErrors {
|
||||
if strings.Contains(strings.ToLower(errStr), retryable) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// callOpenAI 调用OpenAI API(带重试机制)
|
||||
func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
|
||||
maxRetries := 3
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
response, err := a.callOpenAISingle(ctx, messages, tools)
|
||||
if err == nil {
|
||||
if attempt > 0 {
|
||||
a.logger.Info("OpenAI API调用重试成功",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// 如果不是可重试的错误,直接返回
|
||||
if !a.isRetryableError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果不是最后一次重试,等待后重试
|
||||
if attempt < maxRetries-1 {
|
||||
// 指数退避:2s, 4s, 8s...
|
||||
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second // 最大30秒
|
||||
}
|
||||
a.logger.Warn("OpenAI API调用失败,准备重试",
|
||||
zap.Error(err),
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
zap.Duration("backoff", backoff),
|
||||
)
|
||||
|
||||
// 检查上下文是否已取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("上下文已取消: %w", ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
// 继续重试
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// callOpenAISingle 单次调用OpenAI API(不包含重试逻辑)
|
||||
func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
|
||||
reqBody := OpenAIRequest{
|
||||
Model: a.config.Model,
|
||||
Messages: messages,
|
||||
@@ -682,6 +763,14 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录请求大小(用于诊断)
|
||||
requestSize := len(jsonData)
|
||||
a.logger.Debug("准备发送OpenAI请求",
|
||||
zap.Int("messagesCount", len(messages)),
|
||||
zap.Int("requestSizeKB", requestSize/1024),
|
||||
zap.Int("toolsCount", len(tools)),
|
||||
)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -690,15 +779,57 @@ func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.config.APIKey)
|
||||
|
||||
// 记录请求开始时间
|
||||
requestStartTime := time.Now()
|
||||
resp, err := a.openAIClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
// 记录响应头接收时间
|
||||
headerReceiveTime := time.Now()
|
||||
headerReceiveDuration := headerReceiveTime.Sub(requestStartTime)
|
||||
|
||||
a.logger.Debug("收到OpenAI响应头",
|
||||
zap.Int("statusCode", resp.StatusCode),
|
||||
zap.Duration("headerReceiveDuration", headerReceiveDuration),
|
||||
zap.Int64("contentLength", resp.ContentLength),
|
||||
)
|
||||
|
||||
// 使用带超时的读取(通过context控制)
|
||||
bodyChan := make(chan []byte, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
bodyChan <- body
|
||||
}()
|
||||
|
||||
var body []byte
|
||||
select {
|
||||
case body = <-bodyChan:
|
||||
// 读取成功
|
||||
bodyReceiveTime := time.Now()
|
||||
bodyReceiveDuration := bodyReceiveTime.Sub(headerReceiveTime)
|
||||
totalDuration := bodyReceiveTime.Sub(requestStartTime)
|
||||
|
||||
a.logger.Debug("完成读取OpenAI响应体",
|
||||
zap.Int("bodySizeKB", len(body)/1024),
|
||||
zap.Duration("bodyReceiveDuration", bodyReceiveDuration),
|
||||
zap.Duration("totalDuration", totalDuration),
|
||||
)
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("读取响应体超时: %w", ctx.Err())
|
||||
case <-time.After(25 * time.Minute):
|
||||
// 额外的安全超时:25分钟(小于30分钟的总超时)
|
||||
return nil, fmt.Errorf("读取响应体超时(超过25分钟)")
|
||||
}
|
||||
|
||||
// 记录响应内容(用于调试)
|
||||
|
||||
Reference in New Issue
Block a user