mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 18:26:38 +02:00
4442e7de30
When provider is set to "claude" in config, all OpenAI-compatible API calls are automatically bridged to Anthropic Claude Messages API, including: - Non-streaming and streaming chat completions - Tool calls (function calling) with full bidirectional conversion - Eino multi-agent via HTTP transport hook (claudeRoundTripper) - System message extraction, auth header conversion (Bearer → x-api-key) - SSE stream format conversion (content_block_delta → OpenAI delta) - TestOpenAI handler support for Claude connectivity testing Zero impact when provider is "openai" or empty (default behavior unchanged). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
494 lines
13 KiB
Go
494 lines
13 KiB
Go
package openai
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"cyberstrike-ai/internal/config"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Client 统一封装与OpenAI兼容模型交互的HTTP客户端。
|
|
type Client struct {
|
|
httpClient *http.Client
|
|
config *config.OpenAIConfig
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// APIError reports a non-200 HTTP response from the chat completions endpoint.
type APIError struct {
	// StatusCode is the HTTP status code returned by the server.
	StatusCode int
	// Body is the raw response body, kept verbatim for diagnostics.
	Body string
}

// Error implements the error interface, rendering the status code and the raw
// body as "openai api error: status=<code> body=<body>".
func (e *APIError) Error() string {
	const format = "openai api error: status=%d body=%s"
	return fmt.Sprintf(format, e.StatusCode, e.Body)
}
|
|
|
|
// NewClient 创建一个新的OpenAI客户端。
|
|
func NewClient(cfg *config.OpenAIConfig, httpClient *http.Client, logger *zap.Logger) *Client {
|
|
if httpClient == nil {
|
|
httpClient = http.DefaultClient
|
|
}
|
|
if logger == nil {
|
|
logger = zap.NewNop()
|
|
}
|
|
return &Client{
|
|
httpClient: httpClient,
|
|
config: cfg,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// UpdateConfig 动态更新OpenAI配置。
|
|
func (c *Client) UpdateConfig(cfg *config.OpenAIConfig) {
|
|
c.config = cfg
|
|
}
|
|
|
|
// ChatCompletion 调用 /chat/completions 接口。
|
|
func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out interface{}) error {
|
|
if c == nil {
|
|
return fmt.Errorf("openai client is not initialized")
|
|
}
|
|
if c.config == nil {
|
|
return fmt.Errorf("openai config is nil")
|
|
}
|
|
if strings.TrimSpace(c.config.APIKey) == "" {
|
|
return fmt.Errorf("openai api key is empty")
|
|
}
|
|
if c.isClaude() {
|
|
return c.claudeChatCompletion(ctx, payload, out)
|
|
}
|
|
|
|
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
|
if baseURL == "" {
|
|
baseURL = "https://api.openai.com/v1"
|
|
}
|
|
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal openai payload: %w", err)
|
|
}
|
|
|
|
c.logger.Debug("sending OpenAI chat completion request",
|
|
zap.Int("payloadSizeKB", len(body)/1024))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
|
|
if err != nil {
|
|
return fmt.Errorf("build openai request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
|
|
requestStart := time.Now()
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("call openai api: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
bodyChan := make(chan []byte, 1)
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
bodyChan <- responseBody
|
|
}()
|
|
|
|
var respBody []byte
|
|
select {
|
|
case respBody = <-bodyChan:
|
|
case err := <-errChan:
|
|
return fmt.Errorf("read openai response: %w", err)
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("read openai response timeout: %w", ctx.Err())
|
|
case <-time.After(25 * time.Minute):
|
|
return fmt.Errorf("read openai response timeout (25m)")
|
|
}
|
|
|
|
c.logger.Debug("received OpenAI response",
|
|
zap.Int("status", resp.StatusCode),
|
|
zap.Duration("duration", time.Since(requestStart)),
|
|
zap.Int("responseSizeKB", len(respBody)/1024),
|
|
)
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
c.logger.Warn("OpenAI chat completion returned non-200",
|
|
zap.Int("status", resp.StatusCode),
|
|
zap.String("body", string(respBody)),
|
|
)
|
|
return &APIError{
|
|
StatusCode: resp.StatusCode,
|
|
Body: string(respBody),
|
|
}
|
|
}
|
|
|
|
if out != nil {
|
|
if err := json.Unmarshal(respBody, out); err != nil {
|
|
c.logger.Error("failed to unmarshal OpenAI response",
|
|
zap.Error(err),
|
|
zap.String("body", string(respBody)),
|
|
)
|
|
return fmt.Errorf("unmarshal openai response: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ChatCompletionStream 调用 /chat/completions 的流式模式(stream=true),并在每个 delta 到达时回调 onDelta。
|
|
// 返回最终拼接的 content(只拼 content delta;工具调用 delta 未做处理)。
|
|
func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{}, onDelta func(delta string) error) (string, error) {
|
|
if c == nil {
|
|
return "", fmt.Errorf("openai client is not initialized")
|
|
}
|
|
if c.config == nil {
|
|
return "", fmt.Errorf("openai config is nil")
|
|
}
|
|
if strings.TrimSpace(c.config.APIKey) == "" {
|
|
return "", fmt.Errorf("openai api key is empty")
|
|
}
|
|
if c.isClaude() {
|
|
return c.claudeChatCompletionStream(ctx, payload, onDelta)
|
|
}
|
|
|
|
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
|
if baseURL == "" {
|
|
baseURL = "https://api.openai.com/v1"
|
|
}
|
|
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal openai payload: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
|
|
if err != nil {
|
|
return "", fmt.Errorf("build openai request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
|
|
requestStart := time.Now()
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("call openai api: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// 非200:读完 body 返回
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return "", &APIError{
|
|
StatusCode: resp.StatusCode,
|
|
Body: string(respBody),
|
|
}
|
|
}
|
|
|
|
type streamDelta struct {
|
|
// OpenAI 兼容流式通常使用 content;但部分兼容实现可能用 text。
|
|
Content string `json:"content,omitempty"`
|
|
Text string `json:"text,omitempty"`
|
|
}
|
|
type streamChoice struct {
|
|
Delta streamDelta `json:"delta"`
|
|
FinishReason *string `json:"finish_reason,omitempty"`
|
|
}
|
|
type streamResponse struct {
|
|
ID string `json:"id,omitempty"`
|
|
Choices []streamChoice `json:"choices"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
Type string `json:"type"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
var full strings.Builder
|
|
|
|
// 典型 SSE 结构:
|
|
// data: {...}\n\n
|
|
// data: [DONE]\n\n
|
|
for {
|
|
line, readErr := reader.ReadString('\n')
|
|
if readErr != nil {
|
|
if readErr == io.EOF {
|
|
break
|
|
}
|
|
return full.String(), fmt.Errorf("read openai stream: %w", readErr)
|
|
}
|
|
trimmed := strings.TrimSpace(line)
|
|
if trimmed == "" {
|
|
continue
|
|
}
|
|
if !strings.HasPrefix(trimmed, "data:") {
|
|
continue
|
|
}
|
|
dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
|
if dataStr == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chunk streamResponse
|
|
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
|
|
// 解析失败跳过(兼容各种兼容层的差异)
|
|
continue
|
|
}
|
|
if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" {
|
|
return full.String(), fmt.Errorf("openai stream error: %s", chunk.Error.Message)
|
|
}
|
|
if len(chunk.Choices) == 0 {
|
|
continue
|
|
}
|
|
|
|
delta := chunk.Choices[0].Delta.Content
|
|
if delta == "" {
|
|
delta = chunk.Choices[0].Delta.Text
|
|
}
|
|
if delta == "" {
|
|
continue
|
|
}
|
|
|
|
full.WriteString(delta)
|
|
if onDelta != nil {
|
|
if err := onDelta(delta); err != nil {
|
|
return full.String(), err
|
|
}
|
|
}
|
|
}
|
|
|
|
c.logger.Debug("received OpenAI stream completion",
|
|
zap.Duration("duration", time.Since(requestStart)),
|
|
zap.Int("contentLen", full.Len()),
|
|
)
|
|
|
|
return full.String(), nil
|
|
}
|
|
|
|
// StreamToolCall is one tool call assembled from streamed tool_call deltas.
// The argument fragments are concatenated verbatim into FunctionArgsStr and
// left unparsed; the caller is expected to decode them as JSON.
type StreamToolCall struct {
	Index           int    // the tool call's "index" as reported in the deltas
	ID              string // tool call id (last non-empty value seen)
	Type            string // tool call type (last non-empty value seen)
	FunctionName    string // function name (last non-empty value seen)
	FunctionArgsStr string // concatenated "arguments" fragments, not parsed
}
|
|
|
|
// ChatCompletionStreamWithToolCalls 流式模式:同时把 content delta 实时回调,并在结束后返回 tool_calls 和 finish_reason。
|
|
func (c *Client) ChatCompletionStreamWithToolCalls(
|
|
ctx context.Context,
|
|
payload interface{},
|
|
onContentDelta func(delta string) error,
|
|
) (string, []StreamToolCall, string, error) {
|
|
if c == nil {
|
|
return "", nil, "", fmt.Errorf("openai client is not initialized")
|
|
}
|
|
if c.config == nil {
|
|
return "", nil, "", fmt.Errorf("openai config is nil")
|
|
}
|
|
if strings.TrimSpace(c.config.APIKey) == "" {
|
|
return "", nil, "", fmt.Errorf("openai api key is empty")
|
|
}
|
|
if c.isClaude() {
|
|
return c.claudeChatCompletionStreamWithToolCalls(ctx, payload, onContentDelta)
|
|
}
|
|
|
|
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
|
if baseURL == "" {
|
|
baseURL = "https://api.openai.com/v1"
|
|
}
|
|
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return "", nil, "", fmt.Errorf("marshal openai payload: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
|
|
if err != nil {
|
|
return "", nil, "", fmt.Errorf("build openai request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
|
|
|
requestStart := time.Now()
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return "", nil, "", fmt.Errorf("call openai api: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return "", nil, "", &APIError{
|
|
StatusCode: resp.StatusCode,
|
|
Body: string(respBody),
|
|
}
|
|
}
|
|
|
|
// delta tool_calls 的增量结构
|
|
type toolCallFunctionDelta struct {
|
|
Name string `json:"name,omitempty"`
|
|
Arguments string `json:"arguments,omitempty"`
|
|
}
|
|
type toolCallDelta struct {
|
|
Index int `json:"index,omitempty"`
|
|
ID string `json:"id,omitempty"`
|
|
Type string `json:"type,omitempty"`
|
|
Function toolCallFunctionDelta `json:"function,omitempty"`
|
|
}
|
|
type streamDelta2 struct {
|
|
Content string `json:"content,omitempty"`
|
|
Text string `json:"text,omitempty"`
|
|
ToolCalls []toolCallDelta `json:"tool_calls,omitempty"`
|
|
}
|
|
type streamChoice2 struct {
|
|
Delta streamDelta2 `json:"delta"`
|
|
FinishReason *string `json:"finish_reason,omitempty"`
|
|
}
|
|
type streamResponse2 struct {
|
|
Choices []streamChoice2 `json:"choices"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
Type string `json:"type"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
type toolCallAccum struct {
|
|
id string
|
|
typ string
|
|
name string
|
|
args strings.Builder
|
|
}
|
|
toolCallAccums := make(map[int]*toolCallAccum)
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
var full strings.Builder
|
|
finishReason := ""
|
|
|
|
for {
|
|
line, readErr := reader.ReadString('\n')
|
|
if readErr != nil {
|
|
if readErr == io.EOF {
|
|
break
|
|
}
|
|
return full.String(), nil, finishReason, fmt.Errorf("read openai stream: %w", readErr)
|
|
}
|
|
trimmed := strings.TrimSpace(line)
|
|
if trimmed == "" {
|
|
continue
|
|
}
|
|
if !strings.HasPrefix(trimmed, "data:") {
|
|
continue
|
|
}
|
|
dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
|
if dataStr == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chunk streamResponse2
|
|
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
|
|
// 兼容:解析失败跳过
|
|
continue
|
|
}
|
|
if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" {
|
|
return full.String(), nil, finishReason, fmt.Errorf("openai stream error: %s", chunk.Error.Message)
|
|
}
|
|
if len(chunk.Choices) == 0 {
|
|
continue
|
|
}
|
|
|
|
choice := chunk.Choices[0]
|
|
if choice.FinishReason != nil && strings.TrimSpace(*choice.FinishReason) != "" {
|
|
finishReason = strings.TrimSpace(*choice.FinishReason)
|
|
}
|
|
|
|
delta := choice.Delta
|
|
|
|
content := delta.Content
|
|
if content == "" {
|
|
content = delta.Text
|
|
}
|
|
if content != "" {
|
|
full.WriteString(content)
|
|
if onContentDelta != nil {
|
|
if err := onContentDelta(content); err != nil {
|
|
return full.String(), nil, finishReason, err
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(delta.ToolCalls) > 0 {
|
|
for _, tc := range delta.ToolCalls {
|
|
acc, ok := toolCallAccums[tc.Index]
|
|
if !ok {
|
|
acc = &toolCallAccum{}
|
|
toolCallAccums[tc.Index] = acc
|
|
}
|
|
if tc.ID != "" {
|
|
acc.id = tc.ID
|
|
}
|
|
if tc.Type != "" {
|
|
acc.typ = tc.Type
|
|
}
|
|
if tc.Function.Name != "" {
|
|
acc.name = tc.Function.Name
|
|
}
|
|
if tc.Function.Arguments != "" {
|
|
acc.args.WriteString(tc.Function.Arguments)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 组装 tool calls
|
|
indices := make([]int, 0, len(toolCallAccums))
|
|
for idx := range toolCallAccums {
|
|
indices = append(indices, idx)
|
|
}
|
|
// 手写简单排序(避免额外 import)
|
|
for i := 0; i < len(indices); i++ {
|
|
for j := i + 1; j < len(indices); j++ {
|
|
if indices[j] < indices[i] {
|
|
indices[i], indices[j] = indices[j], indices[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
toolCalls := make([]StreamToolCall, 0, len(indices))
|
|
for _, idx := range indices {
|
|
acc := toolCallAccums[idx]
|
|
tc := StreamToolCall{
|
|
Index: idx,
|
|
ID: acc.id,
|
|
Type: acc.typ,
|
|
FunctionName: acc.name,
|
|
FunctionArgsStr: acc.args.String(),
|
|
}
|
|
toolCalls = append(toolCalls, tc)
|
|
}
|
|
|
|
c.logger.Debug("received OpenAI stream completion (tool_calls)",
|
|
zap.Duration("duration", time.Since(requestStart)),
|
|
zap.Int("contentLen", full.Len()),
|
|
zap.Int("toolCalls", len(toolCalls)),
|
|
zap.String("finishReason", finishReason),
|
|
)
|
|
|
|
if strings.TrimSpace(finishReason) == "" {
|
|
finishReason = "stop"
|
|
}
|
|
|
|
return full.String(), toolCalls, finishReason, nil
|
|
}
|