diff --git a/config.yaml b/config.yaml index 4ea29edc..eab30ef0 100644 --- a/config.yaml +++ b/config.yaml @@ -17,7 +17,7 @@ server: port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080 # 认证配置 auth: - password: # Web 登录密码,请修改为强密码 + password: fWZWQVO5PgtZ49Vo-wYF5q9g # Web 登录密码,请修改为强密码 session_duration_hours: 12 # 登录有效期(小时),超时后需重新登录 # 日志配置 log: @@ -34,7 +34,9 @@ log: # - DeepSeek: https://api.deepseek.com/v1 # - 其他兼容 OpenAI 协议的 API # 常用模型: gpt-4, gpt-3.5-turbo, deepseek-chat, claude-3-opus 等 +# provider: 可选值 openai(默认) | claude(自动桥接到 Anthropic Claude Messages API) openai: + provider: openai # API 提供商: openai(默认,兼容OpenAI协议) | claude(自动桥接到Anthropic Claude Messages API) base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 # API 基础 URL(必填) api_key: sk-xxxxxx # API 密钥(必填) model: qwen3-max # 模型名称(必填) diff --git a/internal/config/config.go b/internal/config/config.go index f5a2f0b7..fac227ae 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -129,6 +129,7 @@ type MCPConfig struct { } type OpenAIConfig struct { + Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` // API 提供商: "openai"(默认) 或 "claude",claude 时自动桥接为 Anthropic Messages API APIKey string `yaml:"api_key" json:"api_key"` BaseURL string `yaml:"base_url" json:"base_url"` Model string `yaml:"model" json:"model"` diff --git a/internal/handler/config.go b/internal/handler/config.go index d9f2140c..fffc33b1 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -3,9 +3,7 @@ package handler import ( "bytes" "context" - "encoding/json" "fmt" - "io" "net/http" "os" "path/filepath" @@ -18,6 +16,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/security" "github.com/gin-gonic/gin" @@ -795,9 +794,10 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { // TestOpenAIRequest 测试OpenAI连接请求 type TestOpenAIRequest struct { - BaseURL string `json:"base_url"` - APIKey string `json:"api_key"` - Model string `json:"model"` + Provider string `json:"provider"` + BaseURL string `json:"base_url"` + APIKey string `json:"api_key"` + Model string `json:"model"` } // TestOpenAI 测试OpenAI API连接是否可用 @@ -819,7 +819,11 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) { baseURL := strings.TrimSuffix(strings.TrimSpace(req.BaseURL), "/") if baseURL == "" { - baseURL = "https://api.openai.com/v1" + if strings.EqualFold(strings.TrimSpace(req.Provider), "claude") { + baseURL = "https://api.anthropic.com" + } else { + baseURL = "https://api.openai.com/v1" + } } // 构造一个最小的 chat completion 请求 @@ -831,57 +835,19 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) { "max_tokens": 5, } - body, err := json.Marshal(payload) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "构造请求失败"}) - return + // 使用内部 openai Client 进行测试,若 provider 为 claude 会自动走桥接层 + tmpCfg := &config.OpenAIConfig{ + Provider: req.Provider, + BaseURL: baseURL, + APIKey: strings.TrimSpace(req.APIKey), + Model: req.Model, } + client := openai.NewClient(tmpCfg, nil, h.logger) ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second) defer cancel() - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body)) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "构造HTTP请求失败: " + err.Error()}) - return - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(req.APIKey)) - start := time.Now() - resp, err := http.DefaultClient.Do(httpReq) - latency := time.Since(start) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "error": "连接失败: " + err.Error(), - }) - return - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - // 尝试提取错误信息 - var errResp struct { - Error struct { - Message string `json:"message"` - } `json:"error"` - } - errMsg := string(respBody) - if json.Unmarshal(respBody, &errResp) == nil && errResp.Error.Message != "" { - errMsg = errResp.Error.Message - } - c.JSON(http.StatusOK, gin.H{ - "success": false, - "error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", resp.StatusCode, errMsg), - "status_code": resp.StatusCode, - }) - return - } - - // 解析响应并严格验证是否为有效的 chat completion 响应 var chatResp struct { ID string `json:"id"` Object string `json:"object"` @@ -893,10 +859,21 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) { } `json:"message"` } `json:"choices"` } - if err := json.Unmarshal(respBody, &chatResp); err != nil { + err := client.ChatCompletion(ctx, payload, &chatResp) + latency := time.Since(start) + + if err != nil { + if apiErr, ok := err.(*openai.APIError); ok { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body), + "status_code": apiErr.StatusCode, + }) + return + } c.JSON(http.StatusOK, gin.H{ "success": false, - "error": "API 响应不是有效的 JSON,请检查 Base URL 是否正确", + "error": "连接失败: " + err.Error(), }) return } @@ -905,14 +882,14 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) { if len(chatResp.Choices) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, - "error": "API 响应缺少 choices 字段,请检查 Base URL 路径是否正确(通常以 /v1 结尾)", + "error": "API 响应缺少 choices 字段,请检查 Base URL 路径是否正确", }) return } if chatResp.ID == "" && chatResp.Model == "" { c.JSON(http.StatusOK, gin.H{ "success": false, - "error": "API 响应格式不符合 OpenAI 规范,请检查 Base URL 是否正确", + "error": "API 响应格式不符合预期,请检查 Base URL 是否正确", }) return } @@ -1272,6 +1249,9 @@ func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) { func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) { root := doc.Content[0] openaiNode := ensureMap(root, "openai") + if cfg.Provider != "" { + setStringInMap(openaiNode, "provider", cfg.Provider) + } setStringInMap(openaiNode, "api_key", cfg.APIKey) setStringInMap(openaiNode, "base_url", cfg.BaseURL) setStringInMap(openaiNode, "model", cfg.Model) diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index cda89530..68864618 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -19,6 +19,7 @@ import ( "cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/einomcp" + "cyberstrike-ai/internal/openai" einoopenai "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/adk" @@ -141,6 +142,9 @@ func RunDeepAgent( }, } + // 若配置为 Claude provider,注入自动桥接 transport,对 Eino 透明走 Anthropic Messages API + httpClient = openai.NewEinoHTTPClient(&appCfg.OpenAI, httpClient) + baseModelCfg := &einoopenai.ChatModelConfig{ APIKey: appCfg.OpenAI.APIKey, BaseURL: strings.TrimSuffix(appCfg.OpenAI.BaseURL, "/"), diff --git a/internal/openai/claude_bridge.go b/internal/openai/claude_bridge.go new file mode 100644 index 00000000..b6e75d51 --- /dev/null +++ b/internal/openai/claude_bridge.go @@ -0,0 +1,1073 @@ +package openai + +// claude_bridge.go 将 OpenAI 格式的请求/响应自动转换为 Anthropic Claude Messages API 格式。 +// 当 config.Provider == "claude" 时,Client 自动走此桥接层,对上层调用方完全透明。 +// +// 转换规则: +// Request: OpenAI /chat/completions → Claude /v1/messages +// Response: Claude /v1/messages → OpenAI /chat/completions 格式 +// Stream: Claude SSE (event: content_block_delta / message_delta) → OpenAI SSE 格式 +// Auth: Bearer → x-api-key +// Tools: OpenAI tools[] → Claude tools[] (input_schema) + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + + "go.uber.org/zap" +) + +// ============================================================ +// Claude Request Types +// ============================================================ + +// claudeRequest 表示 Anthropic Messages API 的请求体。 +type claudeRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []claudeMessage `json:"messages"` + Tools []claudeTool `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type claudeMessage struct { + Role string `json:"role"` + Content claudeMessageContent `json:"content"` +} + +// claudeMessageContent 可以是纯字符串或 content block 数组。 +// MarshalJSON / UnmarshalJSON 自动处理两种形式。 +type claudeMessageContent struct { + Text string // 纯文本形式(简写) + Blocks []claudeContentBlock // 多 block 形式(tool_use / tool_result 必须用这种) +} + +func (c claudeMessageContent) MarshalJSON() ([]byte, error) { + if len(c.Blocks) > 0 { + return json.Marshal(c.Blocks) + } + return json.Marshal(c.Text) +} + +func (c *claudeMessageContent) UnmarshalJSON(data []byte) error { + // 尝试字符串 + var s string + if err := json.Unmarshal(data, &s); err == nil { + c.Text = s + return nil + } + // 尝试数组 + return json.Unmarshal(data, &c.Blocks) +} + +type claudeContentBlock struct { + Type string `json:"type"` + + // text block + Text string `json:"text,omitempty"` + + // tool_use block (assistant 返回) + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + + // tool_result block (user 提交) + ToolUseID string `json:"tool_use_id,omitempty"` + Content string `json:"content,omitempty"` + IsError bool `json:"is_error,omitempty"` +} + +type claudeTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]interface{} `json:"input_schema"` +} + +// ============================================================ +// Claude Response Types +// ============================================================ + +type claudeResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []claudeContentBlock `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage *claudeUsage `json:"usage,omitempty"` + Error *claudeError `json:"error,omitempty"` +} + +type claudeUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type claudeError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +// ============================================================ +// Conversion: OpenAI Request → Claude Request +// ============================================================ + +// convertOpenAIToClaude 将任意 OpenAI payload (map 或 struct) 转换为 claudeRequest。 +func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) { + // 先统一序列化为 JSON,再以 map 反序列化,方便处理各种输入形式 + raw, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("claude bridge: marshal payload: %w", err) + } + + var oai map[string]interface{} + if err := json.Unmarshal(raw, &oai); err != nil { + return nil, fmt.Errorf("claude bridge: unmarshal payload: %w", err) + } + + req := &claudeRequest{} + + // model + if m, ok := oai["model"].(string); ok { + req.Model = m + } + + // max_tokens (Claude 必需) + if mt, ok := oai["max_tokens"].(float64); ok && mt > 0 { + req.MaxTokens = int(mt) + } else { + req.MaxTokens = 8192 // Claude 默认最大输出(兼容 Haiku/Sonnet/Opus) + } + + // stream + if s, ok := oai["stream"].(bool); ok { + req.Stream = s + } + + // messages + msgs, _ := oai["messages"].([]interface{}) + for i := 0; i < len(msgs); i++ { + mm, ok := msgs[i].(map[string]interface{}) + if !ok { + continue + } + role, _ := mm["role"].(string) + content, _ := mm["content"].(string) + + // system message → 提取到顶级 system 字段 + if role == "system" { + if req.System != "" { + req.System += "\n\n" + } + req.System += content + continue + } + + // tool_calls (assistant 消息中包含工具调用) + if role == "assistant" { + var blocks []claudeContentBlock + if content != "" { + blocks = append(blocks, claudeContentBlock{Type: "text", Text: content}) + } + + if tcs, ok := mm["tool_calls"].([]interface{}); ok { + for _, tc := range tcs { + tcMap, ok := tc.(map[string]interface{}) + if !ok { + continue + } + tcID, _ := tcMap["id"].(string) + fn, _ := tcMap["function"].(map[string]interface{}) + fnName, _ := fn["name"].(string) + fnArgs, _ := fn["arguments"] + + // 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝 + if strings.TrimSpace(fnName) == "" { + fnName = "unknown_function" + } + if strings.TrimSpace(tcID) == "" { + tcID = fmt.Sprintf("call_%d", time.Now().UnixNano()) + } + + var inputRaw json.RawMessage + switch v := fnArgs.(type) { + case string: + inputRaw = json.RawMessage(v) + default: + inputRaw, _ = json.Marshal(v) + } + // 防止空字符串/非法 JSON 导致 Marshal 失败 + if len(inputRaw) == 0 || !json.Valid(inputRaw) { + inputRaw = json.RawMessage("{}") + } + blocks = append(blocks, claudeContentBlock{ + Type: "tool_use", + ID: tcID, + Name: fnName, + Input: inputRaw, + }) + } + } + + if len(blocks) > 0 { + req.Messages = append(req.Messages, claudeMessage{ + Role: "assistant", + Content: claudeMessageContent{Blocks: blocks}, + }) + } + continue + } + + // tool result (role == "tool" in OpenAI) + // Claude 要求同一轮的多个 tool_result 合并为一个 user 消息(多 block), + // 否则违反 user/assistant 交替规则。 + if role == "tool" { + var toolBlocks []claudeContentBlock + // 收集当前及后续连续的 tool 消息 + for ; i < len(msgs); i++ { + tmm, ok := msgs[i].(map[string]interface{}) + if !ok { + break + } + tr, _ := tmm["role"].(string) + if tr != "tool" { + break + } + tcID, _ := tmm["tool_call_id"].(string) + tcContent, _ := tmm["content"].(string) + toolBlocks = append(toolBlocks, claudeContentBlock{ + Type: "tool_result", + ToolUseID: tcID, + Content: tcContent, + }) + } + i-- // 外层 for 会 i++,回退一步 + req.Messages = append(req.Messages, claudeMessage{ + Role: "user", + Content: claudeMessageContent{Blocks: toolBlocks}, + }) + continue + } + + // 普通 user/assistant 消息 + req.Messages = append(req.Messages, claudeMessage{ + Role: role, + Content: claudeMessageContent{Text: content}, + }) + } + + // tools + if tools, ok := oai["tools"].([]interface{}); ok { + for _, t := range tools { + tMap, ok := t.(map[string]interface{}) + if !ok { + continue + } + fn, ok := tMap["function"].(map[string]interface{}) + if !ok { + continue + } + ct := claudeTool{} + ct.Name, _ = fn["name"].(string) + ct.Description, _ = fn["description"].(string) + if params, ok := fn["parameters"].(map[string]interface{}); ok { + ct.InputSchema = params + } else { + ct.InputSchema = map[string]interface{}{"type": "object", "properties": map[string]interface{}{}} + } + req.Tools = append(req.Tools, ct) + } + } + + return req, nil +} + +// ============================================================ +// Conversion: Claude Response → OpenAI Response (non-streaming) +// ============================================================ + +// claudeToOpenAIResponseJSON 将 Claude 响应 JSON 转为 OpenAI 兼容的 JSON。 +func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) { + var cr claudeResponse + if err := json.Unmarshal(claudeBody, &cr); err != nil { + return nil, fmt.Errorf("claude bridge: unmarshal response: %w", err) + } + + if cr.Error != nil { + return nil, fmt.Errorf("claude api error: [%s] %s", cr.Error.Type, cr.Error.Message) + } + + // 构建 OpenAI 格式的 response + oaiResp := map[string]interface{}{ + "id": cr.ID, + "object": "chat.completion", + "model": cr.Model, + "choices": []interface{}{}, + } + + var textContent string + var toolCalls []interface{} + + for _, block := range cr.Content { + switch block.Type { + case "text": + textContent += block.Text + case "tool_use": + argsStr := string(block.Input) + toolCalls = append(toolCalls, map[string]interface{}{ + "id": block.ID, + "type": "function", + "function": map[string]interface{}{ + "name": block.Name, + "arguments": argsStr, + }, + }) + } + } + + finishReason := claudeStopReasonToOpenAI(cr.StopReason) + message := map[string]interface{}{ + "role": "assistant", + "content": textContent, + } + if len(toolCalls) > 0 { + message["tool_calls"] = toolCalls + } + + choice := map[string]interface{}{ + "index": 0, + "message": message, + "finish_reason": finishReason, + } + + oaiResp["choices"] = []interface{}{choice} + + if cr.Usage != nil { + oaiResp["usage"] = map[string]interface{}{ + "prompt_tokens": cr.Usage.InputTokens, + "completion_tokens": cr.Usage.OutputTokens, + "total_tokens": cr.Usage.InputTokens + cr.Usage.OutputTokens, + } + } + + return json.Marshal(oaiResp) +} + +func claudeStopReasonToOpenAI(reason string) string { + switch reason { + case "end_turn": + return "stop" + case "tool_use": + return "tool_calls" + case "max_tokens": + return "length" + case "stop_sequence": + return "stop" + default: + return "stop" + } +} + +// ============================================================ +// Claude HTTP Calls (non-streaming & streaming) +// ============================================================ + +// claudeChatCompletion 执行非流式 Claude API 调用,返回转换后的 OpenAI 格式 JSON。 +func (c *Client) claudeChatCompletion(ctx context.Context, payload interface{}, out interface{}) error { + claudeReq, err := convertOpenAIToClaude(payload) + if err != nil { + return err + } + claudeReq.Stream = false + + body, err := json.Marshal(claudeReq) + if err != nil { + return fmt.Errorf("claude bridge: marshal: %w", err) + } + + baseURL := strings.TrimSuffix(c.config.BaseURL, "/") + if baseURL == "" { + baseURL = "https://api.anthropic.com" + } + + c.logger.Debug("sending Claude chat completion request", + zap.String("model", claudeReq.Model), + zap.Int("payloadSizeKB", len(body)/1024)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/v1/messages", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("claude bridge: build request: %w", err) + } + c.setClaudeHeaders(req) + + requestStart := time.Now() + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("claude bridge: call api: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("claude bridge: read response: %w", err) + } + + c.logger.Debug("received Claude response", + zap.Int("status", resp.StatusCode), + zap.Duration("duration", time.Since(requestStart)), + zap.Int("responseSizeKB", len(respBody)/1024), + ) + + if resp.StatusCode != http.StatusOK { + c.logger.Warn("Claude chat completion returned non-200", + zap.Int("status", resp.StatusCode), + zap.String("body", string(respBody)), + ) + return &APIError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + } + } + + // 转换为 OpenAI 格式 + oaiJSON, err := claudeToOpenAIResponseJSON(respBody) + if err != nil { + return err + } + + if out != nil { + if err := json.Unmarshal(oaiJSON, out); err != nil { + return fmt.Errorf("claude bridge: unmarshal converted response: %w", err) + } + } + + return nil +} + +// claudeChatCompletionStream 流式调用 Claude API,将 Claude SSE 转换为 OpenAI 兼容的 delta 回调。 +func (c *Client) claudeChatCompletionStream(ctx context.Context, payload interface{}, onDelta func(delta string) error) (string, error) { + claudeReq, err := convertOpenAIToClaude(payload) + if err != nil { + return "", err + } + claudeReq.Stream = true + + body, err := json.Marshal(claudeReq) + if err != nil { + return "", fmt.Errorf("claude bridge: marshal: %w", err) + } + + baseURL := strings.TrimSuffix(c.config.BaseURL, "/") + if baseURL == "" { + baseURL = "https://api.anthropic.com" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/v1/messages", bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("claude bridge: build request: %w", err) + } + c.setClaudeHeaders(req) + + requestStart := time.Now() + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("claude bridge: call api: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", &APIError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + } + } + + reader := bufio.NewReader(resp.Body) + var full strings.Builder + + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + if readErr == io.EOF { + break + } + return full.String(), fmt.Errorf("claude bridge: read stream: %w", readErr) + } + trimmed := strings.TrimSpace(line) + if trimmed == "" || !strings.HasPrefix(trimmed, "data:") { + continue + } + dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if dataStr == "[DONE]" { + break + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + eventType, _ := event["type"].(string) + + switch eventType { + case "content_block_delta": + delta, _ := event["delta"].(map[string]interface{}) + deltaType, _ := delta["type"].(string) + if deltaType == "text_delta" { + text, _ := delta["text"].(string) + if text != "" { + full.WriteString(text) + if onDelta != nil { + if err := onDelta(text); err != nil { + return full.String(), err + } + } + } + } + case "error": + errData, _ := event["error"].(map[string]interface{}) + msg, _ := errData["message"].(string) + return full.String(), fmt.Errorf("claude stream error: %s", msg) + } + } + + c.logger.Debug("received Claude stream completion", + zap.Duration("duration", time.Since(requestStart)), + zap.Int("contentLen", full.Len()), + ) + + return full.String(), nil +} + +// claudeChatCompletionStreamWithToolCalls 流式调用 Claude API,同时处理 content delta 和 tool_calls, +// 返回值与 OpenAI 版本完全一致:(content, toolCalls, finishReason, error)。 +func (c *Client) claudeChatCompletionStreamWithToolCalls( + ctx context.Context, + payload interface{}, + onContentDelta func(delta string) error, +) (string, []StreamToolCall, string, error) { + claudeReq, err := convertOpenAIToClaude(payload) + if err != nil { + return "", nil, "", err + } + claudeReq.Stream = true + + body, err := json.Marshal(claudeReq) + if err != nil { + return "", nil, "", fmt.Errorf("claude bridge: marshal: %w", err) + } + + baseURL := strings.TrimSuffix(c.config.BaseURL, "/") + if baseURL == "" { + baseURL = "https://api.anthropic.com" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/v1/messages", bytes.NewReader(body)) + if err != nil { + return "", nil, "", fmt.Errorf("claude bridge: build request: %w", err) + } + c.setClaudeHeaders(req) + + requestStart := time.Now() + resp, err := c.httpClient.Do(req) + if err != nil { + return "", nil, "", fmt.Errorf("claude bridge: call api: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", nil, "", &APIError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + } + } + + reader := bufio.NewReader(resp.Body) + var full strings.Builder + finishReason := "" + + // 追踪当前正在构建的 content blocks + type toolAccum struct { + id string + name string + args strings.Builder + index int + } + var currentToolCalls []toolAccum + currentBlockIndex := -1 + currentBlockType := "" + + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + if readErr == io.EOF { + break + } + return full.String(), nil, finishReason, fmt.Errorf("claude bridge: read stream: %w", readErr) + } + trimmed := strings.TrimSpace(line) + if trimmed == "" || !strings.HasPrefix(trimmed, "data:") { + continue + } + dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if dataStr == "[DONE]" { + break + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + eventType, _ := event["type"].(string) + + switch eventType { + case "content_block_start": + idx, _ := event["index"].(float64) + currentBlockIndex = int(idx) + cb, _ := event["content_block"].(map[string]interface{}) + blockType, _ := cb["type"].(string) + currentBlockType = blockType + + if blockType == "tool_use" { + id, _ := cb["id"].(string) + name, _ := cb["name"].(string) + currentToolCalls = append(currentToolCalls, toolAccum{ + id: id, + name: name, + index: currentBlockIndex, + }) + } + + case "content_block_delta": + delta, _ := event["delta"].(map[string]interface{}) + deltaType, _ := delta["type"].(string) + + if deltaType == "text_delta" { + text, _ := delta["text"].(string) + if text != "" { + full.WriteString(text) + if onContentDelta != nil { + if err := onContentDelta(text); err != nil { + return full.String(), nil, finishReason, err + } + } + } + } else if deltaType == "input_json_delta" { + partialJSON, _ := delta["partial_json"].(string) + if partialJSON != "" && currentBlockType == "tool_use" && len(currentToolCalls) > 0 { + currentToolCalls[len(currentToolCalls)-1].args.WriteString(partialJSON) + } + } + + case "content_block_stop": + // block 完成,不需要特殊处理 + + case "message_delta": + delta, _ := event["delta"].(map[string]interface{}) + if sr, ok := delta["stop_reason"].(string); ok { + finishReason = claudeStopReasonToOpenAI(sr) + } + + case "message_stop": + // 消息完成 + + case "error": + errData, _ := event["error"].(map[string]interface{}) + msg, _ := errData["message"].(string) + return full.String(), nil, finishReason, fmt.Errorf("claude stream error: %s", msg) + } + } + + // 转换 tool calls 为 OpenAI 格式的 StreamToolCall + var toolCalls []StreamToolCall + for i, tc := range currentToolCalls { + toolCalls = append(toolCalls, StreamToolCall{ + Index: i, + ID: tc.id, + Type: "function", + FunctionName: tc.name, + FunctionArgsStr: tc.args.String(), + }) + } + + if finishReason == "" { + finishReason = "stop" + } + + c.logger.Debug("received Claude stream completion (tool_calls)", + zap.Duration("duration", time.Since(requestStart)), + zap.Int("contentLen", full.Len()), + zap.Int("toolCalls", len(toolCalls)), + zap.String("finishReason", finishReason), + ) + + return full.String(), toolCalls, finishReason, nil +} + +// ============================================================ +// Helpers +// ============================================================ + +// setClaudeHeaders 设置 Anthropic API 要求的请求头。 +func (c *Client) setClaudeHeaders(req *http.Request) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", c.config.APIKey) + req.Header.Set("anthropic-version", "2023-06-01") +} + +// isClaude 判断当前配置是否为 Claude provider。 +func (c *Client) isClaude() bool { + return isClaudeProvider(c.config) +} + +func isClaudeProvider(cfg *config.OpenAIConfig) bool { + if cfg == nil { + return false + } + return strings.EqualFold(strings.TrimSpace(cfg.Provider), "claude") || + strings.EqualFold(strings.TrimSpace(cfg.Provider), "anthropic") +} + +// ============================================================ +// Eino HTTP Client Bridge +// ============================================================ + +// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个支持 Claude 自动桥接的 http.Client。 +// 当 cfg.Provider 为 claude 时,会拦截 /chat/completions 请求,透明转换为 Anthropic Messages API。 +func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client { + if base == nil { + base = http.DefaultClient + } + if !isClaudeProvider(cfg) { + return base + } + + cloned := *base + transport := base.Transport + if transport == nil { + transport = http.DefaultTransport + } + cloned.Transport = &claudeRoundTripper{ + base: transport, + config: cfg, + } + return &cloned +} + +// claudeRoundTripper 是一个 http.RoundTripper,用于将 OpenAI 协议透明桥接到 Claude API。 +type claudeRoundTripper struct { + base http.RoundTripper + config *config.OpenAIConfig +} + +func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // 只拦截 chat completions + if !strings.HasSuffix(req.URL.Path, "/chat/completions") { + return rt.base.RoundTrip(req) + } + + // 读取原请求体 + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("claude bridge: read request body: %w", err) + } + _ = req.Body.Close() + + var payload interface{} + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("claude bridge: unmarshal request: %w", err) + } + + // 转换为 Claude 请求 + claudeReq, err := convertOpenAIToClaude(payload) + if err != nil { + return nil, err + } + + // 构造 Claude 请求 + baseURL := strings.TrimSuffix(rt.config.BaseURL, "/") + if baseURL == "" { + baseURL = "https://api.anthropic.com" + } + + claudeBody, err := json.Marshal(claudeReq) + if err != nil { + return nil, fmt.Errorf("claude bridge: marshal claude request: %w", err) + } + + newReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, baseURL+"/v1/messages", bytes.NewReader(claudeBody)) + if err != nil { + return nil, fmt.Errorf("claude bridge: build request: %w", err) + } + newReq.Header.Set("Content-Type", "application/json") + newReq.Header.Set("x-api-key", rt.config.APIKey) + newReq.Header.Set("anthropic-version", "2023-06-01") + + resp, err := rt.base.RoundTrip(newReq) + if err != nil { + return nil, err + } + + // 非 200:尝试把 Claude 错误格式转成 OpenAI 错误格式,便于 Eino 解析 + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + resp.Body.Close() + converted := rt.tryConvertClaudeErrorToOpenAI(bodyBytes) + return &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(converted)), + ContentLength: int64(len(converted)), + Request: req, + }, nil + } + + // 非流式:一次性转换响应体 + if !claudeReq.Stream { + respBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + oaiJSON, err := claudeToOpenAIResponseJSON(respBody) + if err != nil { + return nil, err + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader(oaiJSON)), + ContentLength: int64(len(oaiJSON)), + Request: req, + }, nil + } + + // 流式:通过 pipe 实时转换 SSE + pr, pw := io.Pipe() + + // writeLine 将数据写入 pipe,返回 false 表示 pipe 已关闭(消费端断开),应立即退出。 + writeLine := func(data string) bool { + _, err := pw.Write([]byte(data)) + return err == nil + } + + go func() { + defer resp.Body.Close() + + reader := bufio.NewReader(resp.Body) + blockToToolIndex := make(map[int]int) + nextToolIndex := 0 + + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + if readErr == io.EOF { + writeLine("data: [DONE]\n\n") + } else { + // 非 EOF 错误:写入错误事件并通知消费端 + oaiErr := map[string]interface{}{ + "error": map[string]interface{}{ + "message": readErr.Error(), + "type": "claude_stream_read_error", + }, + } + b, _ := json.Marshal(oaiErr) + writeLine("data: " + string(b) + "\n\n") + writeLine("data: [DONE]\n\n") + } + pw.Close() + return + } + trimmed := strings.TrimSpace(line) + if trimmed == "" || !strings.HasPrefix(trimmed, "data:") { + continue + } + dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if dataStr == "[DONE]" { + writeLine("data: [DONE]\n\n") + pw.Close() + return + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + eventType, _ := event["type"].(string) + + switch eventType { + case "content_block_start": + blockIdxFlt, _ := event["index"].(float64) + blockIdx := int(blockIdxFlt) + cb, _ := event["content_block"].(map[string]interface{}) + bt, _ := cb["type"].(string) + + if bt == "tool_use" { + id, _ := cb["id"].(string) + name, _ := cb["name"].(string) + blockToToolIndex[blockIdx] = nextToolIndex + toolIdx := nextToolIndex + nextToolIndex++ + + oaiChunk := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{ + { + "index": toolIdx, + "id": id, + "type": "function", + "function": map[string]interface{}{ + "name": name, + }, + }, + }, + }, + }, + }, + } + b, _ := json.Marshal(oaiChunk) + if !writeLine("data: " + string(b) + "\n\n") { + pw.Close() + return + } + } + + case "content_block_delta": + blockIdxFlt, _ := event["index"].(float64) + blockIdx := int(blockIdxFlt) + delta, _ := event["delta"].(map[string]interface{}) + dt, _ := delta["type"].(string) + + if dt == "text_delta" { + text, _ := delta["text"].(string) + oaiChunk := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "delta": map[string]interface{}{ + "content": text, + }, + }, + }, + } + b, _ := json.Marshal(oaiChunk) + if !writeLine("data: " + string(b) + "\n\n") { + pw.Close() + return + } + } else if dt == "input_json_delta" { + partial, _ := delta["partial_json"].(string) + if partial != "" { + if toolIdx, ok := blockToToolIndex[blockIdx]; ok { + oaiChunk := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "delta": map[string]interface{}{ + "tool_calls": []map[string]interface{}{ + { + "index": toolIdx, + "function": map[string]interface{}{ + "arguments": partial, + }, + }, + }, + }, + }, + }, + } + b, _ := json.Marshal(oaiChunk) + if !writeLine("data: " + string(b) + "\n\n") { + pw.Close() + return + } + } + } + } + + case "message_delta": + d, _ := event["delta"].(map[string]interface{}) + if sr, ok := d["stop_reason"].(string); ok { + finishReason := claudeStopReasonToOpenAI(sr) + oaiChunk := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "delta": map[string]interface{}{}, + "finish_reason": finishReason, + }, + }, + } + b, _ := json.Marshal(oaiChunk) + if !writeLine("data: " + string(b) + "\n\n") { + pw.Close() + return + } + } + + case "message_stop": + writeLine("data: [DONE]\n\n") + pw.Close() + return + + case "error": + errData, _ := event["error"].(map[string]interface{}) + msg, _ := errData["message"].(string) + oaiChunk := map[string]interface{}{ + "error": map[string]interface{}{ + "message": msg, + "type": "claude_stream_error", + }, + } + b, _ := json.Marshal(oaiChunk) + writeLine("data: " + string(b) + "\n\n") + writeLine("data: [DONE]\n\n") + pw.Close() + return + } + } + }() + + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: pr, + Request: req, + }, nil +} + +// tryConvertClaudeErrorToOpenAI 尝试把 Claude 错误格式转换为 OpenAI 错误格式 JSON。 +func (rt *claudeRoundTripper) tryConvertClaudeErrorToOpenAI(body []byte) []byte { + var ce struct { + Type string `json:"type"` + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(body, &ce); err != nil || ce.Error.Message == "" { + return body + } + oaiErr := map[string]interface{}{ + "error": map[string]interface{}{ + "message": ce.Error.Message, + "type": ce.Error.Type, + "code": ce.Type, + }, + } + b, _ := json.Marshal(oaiErr) + return b +} diff --git a/internal/openai/openai.go b/internal/openai/openai.go index 637ee48b..2c675e5f 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -64,6 +64,9 @@ func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out in if strings.TrimSpace(c.config.APIKey) == "" { return fmt.Errorf("openai api key is empty") } + if c.isClaude() { + return c.claudeChatCompletion(ctx, payload, out) + } baseURL := strings.TrimSuffix(c.config.BaseURL, "/") if baseURL == "" { @@ -156,6 +159,9 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{}, if strings.TrimSpace(c.config.APIKey) == "" { return "", fmt.Errorf("openai api key is empty") } + if c.isClaude() { + return c.claudeChatCompletionStream(ctx, payload, onDelta) + } baseURL := strings.TrimSuffix(c.config.BaseURL, "/") if baseURL == "" { @@ -294,6 +300,9 @@ func (c *Client) ChatCompletionStreamWithToolCalls( if strings.TrimSpace(c.config.APIKey) == "" { return "", nil, "", fmt.Errorf("openai api key is empty") } + if c.isClaude() { + return c.claudeChatCompletionStreamWithToolCalls(ctx, payload, onContentDelta) + } baseURL := strings.TrimSuffix(c.config.BaseURL, "/") if baseURL == "" { diff --git a/web/static/js/settings.js b/web/static/js/settings.js index 40df9d54..70d4cb8c 100644 --- a/web/static/js/settings.js +++ b/web/static/js/settings.js @@ -102,6 +102,10 @@ async function loadConfig(loadTools = true) { currentConfig = await response.json(); // 填充OpenAI配置 + const providerEl = document.getElementById('openai-provider'); + if (providerEl) { + providerEl.value = currentConfig.openai.provider || 'openai'; + } document.getElementById('openai-api-key').value = currentConfig.openai.api_key || ''; document.getElementById('openai-base-url').value = currentConfig.openai.base_url || ''; document.getElementById('openai-model').value = currentConfig.openai.model || ''; @@ -753,6 +757,7 @@ async function applySettings() { }); // 验证必填字段 + const provider = document.getElementById('openai-provider')?.value || 'openai'; const apiKey = document.getElementById('openai-api-key').value.trim(); const baseUrl = document.getElementById('openai-base-url').value.trim(); const model = document.getElementById('openai-model').value.trim(); @@ -821,6 +826,7 @@ async function applySettings() { const wecomAgentIdVal = document.getElementById('robot-wecom-agent-id')?.value.trim(); const config = { openai: { + provider: provider, api_key: apiKey, base_url: baseUrl, model: model @@ -981,6 +987,7 @@ async function testOpenAIConnection() { const btn = document.getElementById('test-openai-btn'); const resultEl = document.getElementById('test-openai-result'); + const provider = document.getElementById('openai-provider')?.value || 'openai'; const baseUrl = document.getElementById('openai-base-url').value.trim(); const apiKey = document.getElementById('openai-api-key').value.trim(); const model = document.getElementById('openai-model').value.trim(); @@ -1001,6 +1008,7 @@ async function testOpenAIConnection() { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ + provider: provider, base_url: baseUrl, api_key: apiKey, model: model diff --git a/web/templates/index.html b/web/templates/index.html index ae19629c..7f5825a5 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -1365,6 +1365,13 @@

OpenAI 配置

+
+ + +