Add files via upload

This commit is contained in:
公明
2025-11-25 22:48:01 +08:00
committed by GitHub
parent 9254542f3d
commit 8ffdbbae52
4 changed files with 207 additions and 245 deletions

View File

@@ -1,11 +1,9 @@
package agent
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
@@ -14,6 +12,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap"
@@ -21,7 +20,7 @@ import (
// Agent AI代理
type Agent struct {
openAIClient *http.Client
openAIClient *openai.Client
config *config.OpenAIConfig
agentConfig *config.AgentConfig
memoryCompressor *MemoryCompressor
@@ -94,6 +93,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
Timeout: 30 * time.Minute, // 从5分钟增加到30分钟
Transport: transport,
}
llmClient := openai.NewClient(cfg, httpClient, logger)
var memoryCompressor *MemoryCompressor
if cfg != nil {
@@ -112,7 +112,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
}
return &Agent{
openAIClient: httpClient,
openAIClient: llmClient,
config: cfg,
agentConfig: agentCfg,
memoryCompressor: memoryCompressor,
@@ -1016,95 +1016,17 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
reqBody.Tools = tools
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
// 记录请求大小(用于诊断)
requestSize := len(jsonData)
a.logger.Debug("准备发送OpenAI请求",
zap.Int("messagesCount", len(messages)),
zap.Int("requestSizeKB", requestSize/1024),
zap.Int("toolsCount", len(tools)),
)
req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+a.config.APIKey)
// 记录请求开始时间
requestStartTime := time.Now()
resp, err := a.openAIClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 记录响应头接收时间
headerReceiveTime := time.Now()
headerReceiveDuration := headerReceiveTime.Sub(requestStartTime)
a.logger.Debug("收到OpenAI响应头",
zap.Int("statusCode", resp.StatusCode),
zap.Duration("headerReceiveDuration", headerReceiveDuration),
zap.Int64("contentLength", resp.ContentLength),
)
// 使用带超时的读取通过context控制
bodyChan := make(chan []byte, 1)
errChan := make(chan error, 1)
go func() {
body, err := io.ReadAll(resp.Body)
if err != nil {
errChan <- err
return
}
bodyChan <- body
}()
var body []byte
select {
case body = <-bodyChan:
// 读取成功
bodyReceiveTime := time.Now()
bodyReceiveDuration := bodyReceiveTime.Sub(headerReceiveTime)
totalDuration := bodyReceiveTime.Sub(requestStartTime)
a.logger.Debug("完成读取OpenAI响应体",
zap.Int("bodySizeKB", len(body)/1024),
zap.Duration("bodyReceiveDuration", bodyReceiveDuration),
zap.Duration("totalDuration", totalDuration),
)
case err := <-errChan:
return nil, err
case <-ctx.Done():
return nil, fmt.Errorf("读取响应体超时: %w", ctx.Err())
case <-time.After(25 * time.Minute):
// 额外的安全超时25分钟小于30分钟的总超时
return nil, fmt.Errorf("读取响应体超时超过25分钟")
}
// 记录响应内容(用于调试)
if resp.StatusCode != http.StatusOK {
a.logger.Warn("OpenAI API返回非200状态码",
zap.Int("status", resp.StatusCode),
zap.String("body", string(body)),
)
}
var response OpenAIResponse
if err := json.Unmarshal(body, &response); err != nil {
a.logger.Error("解析OpenAI响应失败",
zap.Error(err),
zap.String("body", string(body)),
)
return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body))
if a.openAIClient == nil {
return nil, fmt.Errorf("OpenAI客户端未初始化")
}
if err := a.openAIClient.ChatCompletion(ctx, reqBody, &response); err != nil {
return nil, err
}
return &response, nil

View File

@@ -1,18 +1,16 @@
package agent
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/openai"
"github.com/pkoukk/tiktoken-go"
"go.uber.org/zap"
@@ -143,15 +141,15 @@ func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
if cfg == nil {
return
}
// 更新summaryModel字段
if cfg.Model != "" {
mc.summaryModel = cfg.Model
}
// 更新completionClient中的配置如果是OpenAICompletionClient
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
openAIClient.config = cfg
openAIClient.UpdateConfig(cfg)
mc.logger.Info("MemoryCompressor配置已更新",
zap.String("model", cfg.Model),
)
@@ -410,9 +408,9 @@ type CompletionClient interface {
// OpenAICompletionClient 基于 OpenAI Chat Completion。
type OpenAICompletionClient struct {
config *config.OpenAIConfig
httpClient *http.Client
logger *zap.Logger
config *config.OpenAIConfig
client *openai.Client
logger *zap.Logger
}
// NewOpenAICompletionClient 创建 OpenAICompletionClient。
@@ -421,9 +419,17 @@ func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, lo
logger = zap.NewNop()
}
return &OpenAICompletionClient{
config: cfg,
httpClient: client,
logger: logger,
config: cfg,
client: openai.NewClient(cfg, client, logger),
logger: logger,
}
}
// UpdateConfig 更新底层配置。
func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) {
c.config = cfg
if c.client != nil {
c.client.UpdateConfig(cfg)
}
}
@@ -443,11 +449,6 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro
},
}
body, err := json.Marshal(reqBody)
if err != nil {
return "", err
}
requestCtx := ctx
var cancel context.CancelFunc
if timeout > 0 {
@@ -455,57 +456,14 @@ func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, pro
defer cancel()
}
// 处理 BaseURL 路径拼接:移除末尾斜杠,然后添加 /chat/completions
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
url := baseURL + "/chat/completions"
c.logger.Debug("calling completion API",
zap.String("url", url),
zap.String("model", model),
zap.Int("prompt_length", len(prompt)))
req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
// 读取响应体以获取更详细的错误信息
bodyBytes, readErr := io.ReadAll(resp.Body)
errorDetail := resp.Status
if readErr == nil && len(bodyBytes) > 0 {
// 尝试解析错误响应
var errorResp struct {
Error struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"`
} `json:"error"`
}
if json.Unmarshal(bodyBytes, &errorResp) == nil && errorResp.Error.Message != "" {
errorDetail = fmt.Sprintf("%s: %s", resp.Status, errorResp.Error.Message)
} else {
errorDetail = fmt.Sprintf("%s: %s", resp.Status, string(bodyBytes))
}
}
c.logger.Warn("completion API request failed",
zap.Int("status_code", resp.StatusCode),
zap.String("url", url),
zap.String("model", model),
zap.String("error", errorDetail))
return "", fmt.Errorf("openai completion failed, status: %s", errorDetail)
}
var completion OpenAIResponse
if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil {
if c.client == nil {
return "", errors.New("openai completion client not initialized")
}
if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil {
if apiErr, ok := err.(*openai.APIError); ok {
return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body)
}
return "", err
}
if completion.Error != nil {

View File

@@ -1,11 +1,10 @@
package attackchain
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sort"
"strings"
@@ -15,6 +14,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/openai"
"github.com/google/uuid"
"go.uber.org/zap"
@@ -24,7 +24,7 @@ import (
type Builder struct {
db *database.DB
logger *zap.Logger
openAIClient *http.Client
openAIClient *openai.Client
openAIConfig *config.OpenAIConfig
tokenCounter agent.TokenCounter
maxTokens int // 最大tokens限制默认100000
@@ -49,6 +49,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
}
httpClient := &http.Client{Timeout: 5 * time.Minute, Transport: transport}
maxTokens := 100000 // 默认100k tokens可以根据模型调整
// 根据模型设置合理的默认值
@@ -66,7 +67,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
return &Builder{
db: db,
logger: logger,
openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport},
openAIClient: openai.NewClient(openAIConfig, httpClient, logger),
openAIConfig: openAIConfig,
tokenCounter: agent.NewTikTokenCounter(),
maxTokens: maxTokens,
@@ -962,30 +963,6 @@ func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk
"max_tokens": 4000, // 增加摘要长度,以容纳更详细的内容
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
resp, err := b.openAIClient.Do(req)
if err != nil {
return "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
}
var apiResponse struct {
Choices []struct {
Message struct {
@@ -994,8 +971,11 @@ func (b *Builder) summarizeContextChunk(ctx context.Context, chunk *ContextChunk
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
if b.openAIClient == nil {
return "", fmt.Errorf("OpenAI客户端未初始化")
}
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
return "", fmt.Errorf("请求失败: %w", err)
}
if len(apiResponse.Choices) == 0 {
@@ -1076,30 +1056,6 @@ AI回复
"max_tokens": 1000,
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
resp, err := b.openAIClient.Do(req)
if err != nil {
return "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body))
}
var apiResponse struct {
Choices []struct {
Message struct {
@@ -1108,8 +1064,11 @@ AI回复
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
if b.openAIClient == nil {
return "", fmt.Errorf("OpenAI客户端未初始化")
}
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
return "", fmt.Errorf("请求失败: %w", err)
}
if len(apiResponse.Choices) == 0 {
@@ -1277,39 +1236,6 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
"max_tokens": 8000,
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey)
resp, err := b.openAIClient.Do(req)
if err != nil {
// 检查是否是上下文过长错误
if strings.Contains(err.Error(), "context") || strings.Contains(err.Error(), "length") {
return "", fmt.Errorf("context length exceeded")
}
return "", fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
bodyStr := string(body)
// 检查是否是上下文过长错误
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
return "", fmt.Errorf("context length exceeded")
}
return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, bodyStr)
}
var apiResponse struct {
Choices []struct {
Message struct {
@@ -1318,8 +1244,20 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
if b.openAIClient == nil {
return "", fmt.Errorf("OpenAI客户端未初始化")
}
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
var apiErr *openai.APIError
if errors.As(err, &apiErr) {
bodyStr := strings.ToLower(apiErr.Body)
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
return "", fmt.Errorf("context length exceeded")
}
} else if strings.Contains(strings.ToLower(err.Error()), "context") || strings.Contains(strings.ToLower(err.Error()), "length") {
return "", fmt.Errorf("context length exceeded")
}
return "", fmt.Errorf("请求失败: %w", err)
}
if len(apiResponse.Choices) == 0 {

144
internal/openai/openai.go Normal file
View File

@@ -0,0 +1,144 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
// Client 统一封装与OpenAI兼容模型交互的HTTP客户端。
type Client struct {
httpClient *http.Client
config *config.OpenAIConfig
logger *zap.Logger
}
// APIError 表示OpenAI接口返回的非200错误。
type APIError struct {
StatusCode int
Body string
}
func (e *APIError) Error() string {
return fmt.Sprintf("openai api error: status=%d body=%s", e.StatusCode, e.Body)
}
// NewClient 创建一个新的OpenAI客户端。
func NewClient(cfg *config.OpenAIConfig, httpClient *http.Client, logger *zap.Logger) *Client {
if httpClient == nil {
httpClient = http.DefaultClient
}
if logger == nil {
logger = zap.NewNop()
}
return &Client{
httpClient: httpClient,
config: cfg,
logger: logger,
}
}
// UpdateConfig 动态更新OpenAI配置。
func (c *Client) UpdateConfig(cfg *config.OpenAIConfig) {
c.config = cfg
}
// ChatCompletion 调用 /chat/completions 接口。
func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out interface{}) error {
if c == nil {
return fmt.Errorf("openai client is not initialized")
}
if c.config == nil {
return fmt.Errorf("openai config is nil")
}
if strings.TrimSpace(c.config.APIKey) == "" {
return fmt.Errorf("openai api key is empty")
}
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal openai payload: %w", err)
}
c.logger.Debug("sending OpenAI chat completion request",
zap.Int("payloadSizeKB", len(body)/1024))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build openai request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
requestStart := time.Now()
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("call openai api: %w", err)
}
defer resp.Body.Close()
bodyChan := make(chan []byte, 1)
errChan := make(chan error, 1)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
errChan <- err
return
}
bodyChan <- responseBody
}()
var respBody []byte
select {
case respBody = <-bodyChan:
case err := <-errChan:
return fmt.Errorf("read openai response: %w", err)
case <-ctx.Done():
return fmt.Errorf("read openai response timeout: %w", ctx.Err())
case <-time.After(25 * time.Minute):
return fmt.Errorf("read openai response timeout (25m)")
}
c.logger.Debug("received OpenAI response",
zap.Int("status", resp.StatusCode),
zap.Duration("duration", time.Since(requestStart)),
zap.Int("responseSizeKB", len(respBody)/1024),
)
if resp.StatusCode != http.StatusOK {
c.logger.Warn("OpenAI chat completion returned non-200",
zap.Int("status", resp.StatusCode),
zap.String("body", string(respBody)),
)
return &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
}
}
if out != nil {
if err := json.Unmarshal(respBody, out); err != nil {
c.logger.Error("failed to unmarshal OpenAI response",
zap.Error(err),
zap.String("body", string(respBody)),
)
return fmt.Errorf("unmarshal openai response: %w", err)
}
}
return nil
}