diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 428d0caf..02ceb668 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -1,11 +1,9 @@ package agent import ( - "bytes" "context" "encoding/json" "fmt" - "io" "net" "net/http" "strings" @@ -14,6 +12,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/storage" "go.uber.org/zap" @@ -21,7 +20,7 @@ import ( // Agent AI代理 type Agent struct { - openAIClient *http.Client + openAIClient *openai.Client config *config.OpenAIConfig agentConfig *config.AgentConfig memoryCompressor *MemoryCompressor @@ -94,6 +93,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer Timeout: 30 * time.Minute, // 从5分钟增加到30分钟 Transport: transport, } + llmClient := openai.NewClient(cfg, httpClient, logger) var memoryCompressor *MemoryCompressor if cfg != nil { @@ -112,7 +112,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer } return &Agent{ - openAIClient: httpClient, + openAIClient: llmClient, config: cfg, agentConfig: agentCfg, memoryCompressor: memoryCompressor, @@ -1016,95 +1016,17 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to reqBody.Tools = tools } - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, err - } - - // 记录请求大小(用于诊断) - requestSize := len(jsonData) a.logger.Debug("准备发送OpenAI请求", zap.Int("messagesCount", len(messages)), - zap.Int("requestSizeKB", requestSize/1024), zap.Int("toolsCount", len(tools)), ) - req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+a.config.APIKey) - - // 记录请求开始时间 - requestStartTime := time.Now() - resp, err := a.openAIClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - // 记录响应头接收时间 - headerReceiveTime := time.Now() - headerReceiveDuration := headerReceiveTime.Sub(requestStartTime) - - a.logger.Debug("收到OpenAI响应头", - zap.Int("statusCode", resp.StatusCode), - zap.Duration("headerReceiveDuration", headerReceiveDuration), - zap.Int64("contentLength", resp.ContentLength), - ) - - // 使用带超时的读取(通过context控制) - bodyChan := make(chan []byte, 1) - errChan := make(chan error, 1) - - go func() { - body, err := io.ReadAll(resp.Body) - if err != nil { - errChan <- err - return - } - bodyChan <- body - }() - - var body []byte - select { - case body = <-bodyChan: - // 读取成功 - bodyReceiveTime := time.Now() - bodyReceiveDuration := bodyReceiveTime.Sub(headerReceiveTime) - totalDuration := bodyReceiveTime.Sub(requestStartTime) - - a.logger.Debug("完成读取OpenAI响应体", - zap.Int("bodySizeKB", len(body)/1024), - zap.Duration("bodyReceiveDuration", bodyReceiveDuration), - zap.Duration("totalDuration", totalDuration), - ) - case err := <-errChan: - return nil, err - case <-ctx.Done(): - return nil, fmt.Errorf("读取响应体超时: %w", ctx.Err()) - case <-time.After(25 * time.Minute): - // 额外的安全超时:25分钟(小于30分钟的总超时) - return nil, fmt.Errorf("读取响应体超时(超过25分钟)") - } - - // 记录响应内容(用于调试) - if resp.StatusCode != http.StatusOK { - a.logger.Warn("OpenAI API返回非200状态码", - zap.Int("status", resp.StatusCode), - zap.String("body", string(body)), - ) - } - var response OpenAIResponse - if err := json.Unmarshal(body, &response); err != nil { - a.logger.Error("解析OpenAI响应失败", - zap.Error(err), - zap.String("body", string(body)), - ) - return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body)) + if a.openAIClient == nil { + return nil, fmt.Errorf("OpenAI客户端未初始化") + } + if err := a.openAIClient.ChatCompletion(ctx, reqBody, &response); err != nil { + return nil, err } return &response, nil diff --git a/internal/agent/memory_compressor.go b/internal/agent/memory_compressor.go index 648e762b..3f1ac66f 100644 --- a/internal/agent/memory_compressor.go +++ b/internal/agent/memory_compressor.go @@ -1,18 +1,16 @@ package agent import ( - "bytes" "context" - "encoding/json" "errors" "fmt" - "io" "net/http" "strings" "sync" "time" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/openai" "github.com/pkoukk/tiktoken-go" "go.uber.org/zap" @@ -143,15 +141,15 @@ func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) { if cfg == nil { return } - + // 更新summaryModel字段 if cfg.Model != "" { mc.summaryModel = cfg.Model } - + // 更新completionClient中的配置(如果是OpenAICompletionClient) if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok { - openAIClient.config = cfg + openAIClient.UpdateConfig(cfg) mc.logger.Info("MemoryCompressor配置已更新", zap.String("model", cfg.Model), ) @@ -410,9 +408,9 @@ type CompletionClient interface { // OpenAICompletionClient 基于 OpenAI Chat Completion。 type OpenAICompletionClient struct { - config *config.OpenAIConfig - httpClient *http.Client - logger *zap.Logger + config *config.OpenAIConfig + client *openai.Client + logger *zap.Logger } // NewOpenAICompletionClient 创建 OpenAICompletionClient。 @@ -421,9 +419,17 @@ func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, lo logger = zap.NewNop() } return &OpenAICompletionClient{ - config: cfg, - httpClient: client, - logger: logger, + config: cfg, + client: openai.NewClient(cfg, client, logger), + logger: logger, + } +} + +// UpdateConfig 更新底层配置。 +func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) { + c.config = cfg + if c.client != nil { + c.client.UpdateConfig(cfg) } } @@ -443,11 +449,6 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro }, } - body, err := json.Marshal(reqBody) - if err != nil { - return "", err - } - requestCtx := ctx var cancel context.CancelFunc if timeout > 0 { @@ -455,57 +456,14 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro defer cancel() } - // 处理 BaseURL 路径拼接:移除末尾斜杠,然后添加 /chat/completions - baseURL := strings.TrimSuffix(c.config.BaseURL, "/") - url := baseURL + "/chat/completions" - - c.logger.Debug("calling completion API", - zap.String("url", url), - zap.String("model", model), - zap.Int("prompt_length", len(prompt))) - - req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.config.APIKey) - - resp, err := c.httpClient.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - // 读取响应体以获取更详细的错误信息 - bodyBytes, readErr := io.ReadAll(resp.Body) - errorDetail := resp.Status - if readErr == nil && len(bodyBytes) > 0 { - // 尝试解析错误响应 - var errorResp struct { - Error struct { - Message string `json:"message"` - Type string `json:"type"` - Code string `json:"code"` - } `json:"error"` - } - if json.Unmarshal(bodyBytes, &errorResp) == nil && errorResp.Error.Message != "" { - errorDetail = fmt.Sprintf("%s: %s", resp.Status, errorResp.Error.Message) - } else { - errorDetail = fmt.Sprintf("%s: %s", resp.Status, string(bodyBytes)) - } - } - c.logger.Warn("completion API request failed", - zap.Int("status_code", resp.StatusCode), - zap.String("url", url), - zap.String("model", model), - zap.String("error", errorDetail)) - return "", fmt.Errorf("openai completion failed, status: %s", errorDetail) - } - var completion OpenAIResponse - if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { + if c.client == nil { + return "", errors.New("openai completion client not initialized") + } + if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil { + if apiErr, ok := err.(*openai.APIError); ok { + return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body) + } return "", err } if completion.Error != nil { diff --git a/internal/attackchain/builder.go b/internal/attackchain/builder.go index 83be2218..394454dc 100644 --- a/internal/attackchain/builder.go +++ b/internal/attackchain/builder.go @@ -1,11 +1,10 @@ package attackchain import ( - "bytes" "context" "encoding/json" + "errors" "fmt" - "io" "net/http" "sort" "strings" @@ -15,6 +14,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/openai" "github.com/google/uuid" "go.uber.org/zap" @@ -24,7 +24,7 @@ import ( type Builder struct { db *database.DB logger *zap.Logger - openAIClient *http.Client + openAIClient *openai.Client openAIConfig *config.OpenAIConfig tokenCounter agent.TokenCounter maxTokens int // 最大tokens限制,默认100000 @@ -49,6 +49,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap. MaxIdleConnsPerHost: 10, IdleConnTimeout: 90 * time.Second, } + httpClient := &http.Client{Timeout: 5 * time.Minute, Transport: transport} maxTokens := 100000 // 默认100k tokens,可以根据模型调整 // 根据模型设置合理的默认值 @@ -66,7 +67,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap. return &Builder{ db: db, logger: logger, - openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport}, + openAIClient: openai.NewClient(openAIConfig, httpClient, logger), openAIConfig: openAIConfig, tokenCounter: agent.NewTikTokenCounter(), maxTokens: maxTokens, @@ -962,30 +963,6 @@ func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk "max_tokens": 4000, // 增加摘要长度,以容纳更详细的内容 } - jsonData, err := json.Marshal(requestBody) - if err != nil { - return "", fmt.Errorf("序列化请求失败: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) - if err != nil { - return "", fmt.Errorf("创建请求失败: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey) - - resp, err := b.openAIClient.Do(req) - if err != nil { - return "", fmt.Errorf("请求失败: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body)) - } - var apiResponse struct { Choices []struct { Message struct { @@ -994,8 +971,11 @@ func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk } `json:"choices"` } - if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { - return "", fmt.Errorf("解析响应失败: %w", err) + if b.openAIClient == nil { + return "", fmt.Errorf("OpenAI客户端未初始化") + } + if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil { + return "", fmt.Errorf("请求失败: %w", err) } if len(apiResponse.Choices) == 0 { @@ -1076,30 +1056,6 @@ AI回复: "max_tokens": 1000, } - jsonData, err := json.Marshal(requestBody) - if err != nil { - return "", fmt.Errorf("序列化请求失败: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) - if err != nil { - return "", fmt.Errorf("创建请求失败: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey) - - resp, err := b.openAIClient.Do(req) - if err != nil { - return "", fmt.Errorf("请求失败: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body)) - } - var apiResponse struct { Choices []struct { Message struct { @@ -1108,8 +1064,11 @@ AI回复: } `json:"choices"` } - if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { - return "", fmt.Errorf("解析响应失败: %w", err) + if b.openAIClient == nil { + return "", fmt.Errorf("OpenAI客户端未初始化") + } + if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil { + return "", fmt.Errorf("请求失败: %w", err) } if len(apiResponse.Choices) == 0 { @@ -1277,39 +1236,6 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) ( "max_tokens": 8000, } - jsonData, err := json.Marshal(requestBody) - if err != nil { - return "", fmt.Errorf("序列化请求失败: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) - if err != nil { - return "", fmt.Errorf("创建请求失败: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey) - - resp, err := b.openAIClient.Do(req) - if err != nil { - // 检查是否是上下文过长错误 - if strings.Contains(err.Error(), "context") || strings.Contains(err.Error(), "length") { - return "", fmt.Errorf("context length exceeded") - } - return "", fmt.Errorf("请求失败: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - bodyStr := string(body) - // 检查是否是上下文过长错误 - if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") { - return "", fmt.Errorf("context length exceeded") - } - return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, bodyStr) - } - var apiResponse struct { Choices []struct { Message struct { @@ -1318,8 +1244,20 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) ( } `json:"choices"` } - if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { - return "", fmt.Errorf("解析响应失败: %w", err) + if b.openAIClient == nil { + return "", fmt.Errorf("OpenAI客户端未初始化") + } + if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil { + var apiErr *openai.APIError + if errors.As(err, &apiErr) { + bodyStr := strings.ToLower(apiErr.Body) + if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") { + return "", fmt.Errorf("context length exceeded") + } + } else if strings.Contains(strings.ToLower(err.Error()), "context") || strings.Contains(strings.ToLower(err.Error()), "length") { + return "", fmt.Errorf("context length exceeded") + } + return "", fmt.Errorf("请求失败: %w", err) } if len(apiResponse.Choices) == 0 { diff --git a/internal/openai/openai.go b/internal/openai/openai.go new file mode 100644 index 00000000..e07f0d61 --- /dev/null +++ b/internal/openai/openai.go @@ -0,0 +1,144 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + + "go.uber.org/zap" +) + +// Client 统一封装与OpenAI兼容模型交互的HTTP客户端。 +type Client struct { + httpClient *http.Client + config *config.OpenAIConfig + logger *zap.Logger +} + +// APIError 表示OpenAI接口返回的非200错误。 +type APIError struct { + StatusCode int + Body string +} + +func (e *APIError) Error() string { + return fmt.Sprintf("openai api error: status=%d body=%s", e.StatusCode, e.Body) +} + +// NewClient 创建一个新的OpenAI客户端。 +func NewClient(cfg *config.OpenAIConfig, httpClient *http.Client, logger *zap.Logger) *Client { + if httpClient == nil { + httpClient = http.DefaultClient + } + if logger == nil { + logger = zap.NewNop() + } + return &Client{ + httpClient: httpClient, + config: cfg, + logger: logger, + } +} + +// UpdateConfig 动态更新OpenAI配置。 +func (c *Client) UpdateConfig(cfg *config.OpenAIConfig) { + c.config = cfg +} + +// ChatCompletion 调用 /chat/completions 接口。 +func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out interface{}) error { + if c == nil { + return fmt.Errorf("openai client is not initialized") + } + if c.config == nil { + return fmt.Errorf("openai config is nil") + } + if strings.TrimSpace(c.config.APIKey) == "" { + return fmt.Errorf("openai api key is empty") + } + + baseURL := strings.TrimSuffix(c.config.BaseURL, "/") + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshal openai payload: %w", err) + } + + c.logger.Debug("sending OpenAI chat completion request", + zap.Int("payloadSizeKB", len(body)/1024)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("build openai request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.APIKey) + + requestStart := time.Now() + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("call openai api: %w", err) + } + defer resp.Body.Close() + + bodyChan := make(chan []byte, 1) + errChan := make(chan error, 1) + go func() { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + errChan <- err + return + } + bodyChan <- responseBody + }() + + var respBody []byte + select { + case respBody = <-bodyChan: + case err := <-errChan: + return fmt.Errorf("read openai response: %w", err) + case <-ctx.Done(): + return fmt.Errorf("read openai response timeout: %w", ctx.Err()) + case <-time.After(25 * time.Minute): + return fmt.Errorf("read openai response timeout (25m)") + } + + c.logger.Debug("received OpenAI response", + zap.Int("status", resp.StatusCode), + zap.Duration("duration", time.Since(requestStart)), + zap.Int("responseSizeKB", len(respBody)/1024), + ) + + if resp.StatusCode != http.StatusOK { + c.logger.Warn("OpenAI chat completion returned non-200", + zap.Int("status", resp.StatusCode), + zap.String("body", string(respBody)), + ) + return &APIError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + } + } + + if out != nil { + if err := json.Unmarshal(respBody, out); err != nil { + c.logger.Error("failed to unmarshal OpenAI response", + zap.Error(err), + zap.String("body", string(respBody)), + ) + return fmt.Errorf("unmarshal openai response: %w", err) + } + } + + return nil +}