From 2558be3d7d8bdeb3e71c07706cc03791b0310749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:38:14 +0800 Subject: [PATCH] Add files via upload --- internal/openai/claude_bridge.go | 24 +- internal/openai/eino_sse_sanitizer.go | 149 ++++++++++ internal/openai/eino_sse_sanitizer_test.go | 303 +++++++++++++++++++++ 3 files changed, 468 insertions(+), 8 deletions(-) create mode 100644 internal/openai/eino_sse_sanitizer.go create mode 100644 internal/openai/eino_sse_sanitizer_test.go diff --git a/internal/openai/claude_bridge.go b/internal/openai/claude_bridge.go index f61e642d..947ba70d 100644 --- a/internal/openai/claude_bridge.go +++ b/internal/openai/claude_bridge.go @@ -752,25 +752,33 @@ func isClaudeProvider(cfg *config.OpenAIConfig) bool { // Eino HTTP Client Bridge // ============================================================ -// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个支持 Claude 自动桥接的 http.Client。 -// 当 cfg.Provider 为 claude 时,会拦截 /chat/completions 请求,透明转换为 Anthropic Messages API。 +// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含两层 transport 包装: +// 1. 当 cfg.Provider 为 claude 时,最内层套 claudeRoundTripper,把 OpenAI /chat/completions 透明 +// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。 +// 2. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行 +// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai +// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。 +// +// 两层都对调用方完全透明:普通 JSON 响应原样透传,仅当响应 Content-Type 为 text/event-stream 时 +// sanitizer 才会接管 body;data: payload (含 [DONE]、{"error":...}) 一字节不改。 func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client { if base == nil { base = http.DefaultClient } - if !isClaudeProvider(cfg) { - return base - } cloned := *base transport := base.Transport if transport == nil { transport = http.DefaultTransport } - cloned.Transport = &claudeRoundTripper{ - base: transport, - config: cfg, + if isClaudeProvider(cfg) { + transport = &claudeRoundTripper{ + base: transport, + config: cfg, + } } + transport = &einoSSESanitizingRoundTripper{base: transport} + cloned.Transport = transport return &cloned } diff --git a/internal/openai/eino_sse_sanitizer.go b/internal/openai/eino_sse_sanitizer.go new file mode 100644 index 00000000..43e07d5b --- /dev/null +++ b/internal/openai/eino_sse_sanitizer.go @@ -0,0 +1,149 @@ +package openai + +// eino_sse_sanitizer.go 解决 Eino 走 meguminnnnnnnnn/go-openai SDK 时, +// 中转站心跳/SSE 控制行累计 > 300 行触发 ErrTooManyEmptyStreamMessages +// (报错文案: "stream has sent too many empty messages")的问题。 +// +// 触发链路: +// einoopenai.NewChatModel +// → eino-ext/libs/acl/openai → meguminnnnnnnnn/go-openai +// → streamReader.processLines() 对所有非 "data:" 行计数, > 300 即抛错。 +// +// 中转站常见的非 data: 行(合法 SSE 但 SDK 不接受): +// ":" / ": keepalive" / ": ping" / "event: ping" / "retry: 3000" +// 以及思考型模型 prefill 期间穿插的大量心跳。 +// +// 兜底策略: 在 HTTP transport 层把响应 Body 包一层 reader, 只放行 "data:" +// 开头的行, 把心跳/注释/事件类型行就地吞掉。下游 SDK 永远见不到非 data: 行, +// 计数器始终为 0, 该错误不可能再发生。 +// +// 该层对调用方完全透明: +// - 仅当响应 Content-Type 是 text/event-stream 时介入;普通 JSON 响应原样透传 +// - data: payload (含 [DONE] 与 {"error":...}) 一字节不改 +// - 上游真断流 (EOF / connection reset / context cancel) 原样透传 + +import ( + "bufio" + "bytes" + "io" + "net/http" + "strings" +) + +const ( + // einoSSEReaderBufSize 给 bufio 一个较大的初始缓冲, 避免单行大 JSON chunk + // (含工具调用 arguments / reasoning_content) 频繁触发缓冲区扩容。 + einoSSEReaderBufSize = 64 * 1024 +) + +// einoSSESanitizingRoundTripper 包装下游 RoundTripper, 对 SSE 响应做行级清洗。 +type einoSSESanitizingRoundTripper struct { + base http.RoundTripper +} + +func (rt *einoSSESanitizingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := rt.base.RoundTrip(req) + if err != nil || resp == nil { + return resp, err + } + if !isSSEResponse(resp) { + return resp, nil + } + resp.Body = newEinoSSESanitizingBody(resp.Body) + return resp, nil +} + +// isSSEResponse 仅对 200 + text/event-stream 的响应做清洗; +// 错误响应 (4xx/5xx 通常是 application/json) 不动, 由 SDK 走原错误路径。 +func isSSEResponse(resp *http.Response) bool { + if resp.StatusCode != http.StatusOK { + return false + } + ct := resp.Header.Get("Content-Type") + if ct == "" { + return false + } + ct = strings.ToLower(strings.TrimSpace(ct)) + // 兼容 "text/event-stream", "text/event-stream; charset=utf-8" 等。 + return strings.HasPrefix(ct, "text/event-stream") +} + +// einoSSESanitizingBody 是包装后的响应体: 只放行 data: 行, 其它行吞掉。 +type einoSSESanitizingBody struct { + upstream io.ReadCloser + reader *bufio.Reader + pending []byte // 已清洗、待返回给下游的字节 (永远以 \n 结尾的完整 data: 行) + err error // upstream 终态错误 (io.EOF 或网络错误) +} + +func newEinoSSESanitizingBody(body io.ReadCloser) *einoSSESanitizingBody { + return &einoSSESanitizingBody{ + upstream: body, + reader: bufio.NewReaderSize(body, einoSSEReaderBufSize), + } +} + +func (b *einoSSESanitizingBody) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if len(b.pending) > 0 { + n := copy(p, b.pending) + b.pending = b.pending[n:] + return n, nil + } + + // 从上游读, 直到攒出一行 data: 或拿到终态。 + // 单次循环可能丢弃任意多行心跳, 但只放行至多一行 data: 后退出, + // 避免一次 Read 阻塞过久 / pending 缓冲过大。 + for b.err == nil { + line, err := b.reader.ReadBytes('\n') + if len(line) > 0 { + if isPassThroughSSELine(line) { + if line[len(line)-1] != '\n' { + line = append(line, '\n') + } + b.pending = line + if err != nil { + b.err = err + } + break + } + // 非 data: 行 (空行 / ":" 注释 / event: / retry: / id: / 任何裸文本) + // 全部吞掉, 不向下游透出, 继续循环读下一行。 + } + if err != nil { + b.err = err + break + } + } + + if len(b.pending) > 0 { + n := copy(p, b.pending) + b.pending = b.pending[n:] + return n, nil + } + return 0, b.err +} + +func (b *einoSSESanitizingBody) Close() error { + return b.upstream.Close() +} + +// isPassThroughSSELine 判定该行是否需要原样放行给下游 SDK。 +// 仅 "data:" (大小写不敏感, 可有任意前导空白) 开头的行需要保留。 +// 注意: 不能用 TrimSpace 去尾部换行后再判, 否则 " data: x" 会被误判; +// 我们只 trim 前导空白, 与 SDK 内部 TrimSpace 后再正则 ^data:\s* 的语义一致。 +func isPassThroughSSELine(line []byte) bool { + trimmed := bytes.TrimLeft(line, " \t") + if len(trimmed) < 5 { + return false + } + // 大小写不敏感比较前 5 字节是否为 "data:"。SSE 规范要求字段名小写, + // 但宽松匹配可以兼容个别中转站的非规范实现。 + return (trimmed[0] == 'd' || trimmed[0] == 'D') && + (trimmed[1] == 'a' || trimmed[1] == 'A') && + (trimmed[2] == 't' || trimmed[2] == 'T') && + (trimmed[3] == 'a' || trimmed[3] == 'A') && + trimmed[4] == ':' +} diff --git a/internal/openai/eino_sse_sanitizer_test.go b/internal/openai/eino_sse_sanitizer_test.go new file mode 100644 index 00000000..ef52db39 --- /dev/null +++ b/internal/openai/eino_sse_sanitizer_test.go @@ -0,0 +1,303 @@ +package openai + +import ( + "bufio" + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "testing" +) + +// 复现 meguminnnnnnnnn/go-openai 的 SSE 行计数算法 (默认 limit=300): +// - 逐行读 +// - 非 "data:" 行 (空行 / ":" 注释 / event: / retry:) 累计 emptyMessagesCount +// - > 300 抛 ErrTooManyEmptyStreamMessages +// - 遇到 data: 行 reset, 返回 payload +// +// 这一算法与上游 SDK 的 stream_reader.go processLines() 严格一致 (验证依据见 +// /Users/temp/go/pkg/mod/github.com/meguminnnnnnnnn/go-openai@v0.1.2/stream_reader.go)。 +// 测试中只复刻 "限制触发" 这一行为, 用来回归验证 sanitizer 的根因修复。 +var errTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages") + +func sdkLikeRecvAll(body io.Reader, limit uint) ([]string, error) { + headerData := regexp.MustCompile(`^data:\s*`) + r := bufio.NewReader(body) + var payloads []string + for { + var emptyMessagesCount uint + var payload []byte + for { + line, err := r.ReadBytes('\n') + if err != nil { + if err == io.EOF { + return payloads, nil + } + return payloads, err + } + noSpace := bytes.TrimSpace(line) + if !headerData.Match(noSpace) { + emptyMessagesCount++ + if emptyMessagesCount > limit { + return payloads, errTooManyEmptyStreamMessages + } + continue + } + payload = headerData.ReplaceAll(noSpace, nil) + break + } + if string(payload) == "[DONE]" { + return payloads, nil + } + payloads = append(payloads, string(payload)) + } +} + +func newSSEServer(t *testing.T, body string, contentType string, status int) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if contentType != "" { + w.Header().Set("Content-Type", contentType) + } + w.WriteHeader(status) + _, _ = io.WriteString(w, body) + })) +} + +func sanitizingClient(base *http.Client) *http.Client { + if base == nil { + base = &http.Client{} + } + cloned := *base + transport := base.Transport + if transport == nil { + transport = http.DefaultTransport + } + cloned.Transport = &einoSSESanitizingRoundTripper{base: transport} + return &cloned +} + +func readAll(t *testing.T, body io.ReadCloser) string { + t.Helper() + defer body.Close() + out, err := io.ReadAll(body) + if err != nil { + t.Fatalf("read body: %v", err) + } + return string(out) +} + +// 1) 仅 data: 行 → 一字节不改地透传。 +func TestSSESanitizer_PassesDataLinesUnchanged(t *testing.T) { + body := "data: {\"a\":1}\ndata: {\"b\":2}\ndata: [DONE]\n" + srv := newSSEServer(t, body, "text/event-stream", 200) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + got := readAll(t, resp.Body) + if got != body { + t.Fatalf("body mismatch:\nwant %q\ngot %q", body, got) + } +} + +// 2) 心跳/注释/事件类型行被吞掉, 仅保留 data: 行。 +func TestSSESanitizer_DropsHeartbeatsAndControlLines(t *testing.T) { + body := strings.Join([]string{ + ": keepalive", + "", + "event: ping", + "retry: 3000", + "id: 42", + "data: {\"x\":1}", + ": ping", + "", + "data: {\"x\":2}", + "data: [DONE]", + "", + }, "\n") + srv := newSSEServer(t, body, "text/event-stream", 200) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + got := readAll(t, resp.Body) + want := "data: {\"x\":1}\ndata: {\"x\":2}\ndata: [DONE]\n" + if got != want { + t.Fatalf("sanitized body mismatch:\nwant %q\ngot %q", want, got) + } +} + +// 3) 根因回归: 上游堆 500 行心跳后才发 data:, 原始 SDK 算法会抛 +// ErrTooManyEmptyStreamMessages, sanitize 之后必须能正常拿到所有 data:。 +func TestSSESanitizer_ProtectsAgainstTooManyEmptyMessages(t *testing.T) { + const heartbeats = 500 + var buf bytes.Buffer + for i := 0; i < heartbeats; i++ { + buf.WriteString(": keepalive\n") + } + buf.WriteString("data: {\"chunk\":1}\n") + buf.WriteString("data: {\"chunk\":2}\n") + buf.WriteString("data: [DONE]\n") + + t.Run("baseline_without_sanitizer_must_fail", func(t *testing.T) { + _, err := sdkLikeRecvAll(bytes.NewReader(buf.Bytes()), 300) + if !errors.Is(err, errTooManyEmptyStreamMessages) { + t.Fatalf("expected ErrTooManyEmptyStreamMessages, got %v", err) + } + }) + + t.Run("with_sanitizer_must_succeed", func(t *testing.T) { + srv := newSSEServer(t, buf.String(), "text/event-stream", 200) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + + payloads, err := sdkLikeRecvAll(resp.Body, 300) + if err != nil { + t.Fatalf("sdk-like recv after sanitize: %v", err) + } + want := []string{`{"chunk":1}`, `{"chunk":2}`} + if len(payloads) != len(want) { + t.Fatalf("payload count mismatch: want %d got %d (%v)", len(want), len(payloads), payloads) + } + for i, w := range want { + if payloads[i] != w { + t.Fatalf("payload[%d] mismatch: want %q got %q", i, w, payloads[i]) + } + } + }) +} + +// 4) 心跳穿插在 data: 之间也能正确清洗 (思考型模型 prefill 期间常见)。 +func TestSSESanitizer_HeartbeatsInterleavedWithData(t *testing.T) { + var buf bytes.Buffer + buf.WriteString("data: {\"chunk\":1}\n") + for i := 0; i < 400; i++ { + buf.WriteString(": keepalive\n") + } + buf.WriteString("data: {\"chunk\":2}\n") + buf.WriteString("data: [DONE]\n") + + srv := newSSEServer(t, buf.String(), "text/event-stream", 200) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + + payloads, err := sdkLikeRecvAll(resp.Body, 300) + if err != nil { + t.Fatalf("sdk-like recv: %v", err) + } + if got, want := len(payloads), 2; got != want { + t.Fatalf("payload count: want %d got %d", want, got) + } +} + +// 5) 非 SSE 响应 (例如非流式 JSON) 不应被 sanitizer 介入。 +func TestSSESanitizer_PassesNonSSEResponseUntouched(t *testing.T) { + body := `{"id":"x","object":"chat.completion","choices":[]}` + srv := newSSEServer(t, body, "application/json", 200) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + got := readAll(t, resp.Body) + if got != body { + t.Fatalf("non-SSE body must be untouched:\nwant %q\ngot %q", body, got) + } +} + +// 6) 错误响应 (4xx/5xx) 不应被 sanitize, 即使 Content-Type 是 SSE 也不动, +// 避免吞掉类似 "data: " 之外的错误正文。 +func TestSSESanitizer_PassesNon200Untouched(t *testing.T) { + body := `{"error":{"message":"rate limit"}}` + srv := newSSEServer(t, body, "text/event-stream", 429) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + got := readAll(t, resp.Body) + if got != body { + t.Fatalf("error body must be untouched:\nwant %q\ngot %q", body, got) + } +} + +// 7) data: 行末尾若缺 \n (异常上游) sanitizer 也补齐, 保证下游按行解析。 +func TestSSESanitizer_AppendsTrailingNewlineIfMissing(t *testing.T) { + body := "data: {\"a\":1}" + srv := newSSEServer(t, body, "text/event-stream", 200) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + got := readAll(t, resp.Body) + want := "data: {\"a\":1}\n" + if got != want { + t.Fatalf("trailing newline:\nwant %q\ngot %q", want, got) + } +} + +// 8) 大 chunk (一行数十 KB) 也能完整透传, 不被切断。 +func TestSSESanitizer_LargeDataLinePassesIntact(t *testing.T) { + huge := strings.Repeat("x", 80*1024) + body := "data: {\"big\":\"" + huge + "\"}\ndata: [DONE]\n" + srv := newSSEServer(t, body, "text/event-stream", 200) + defer srv.Close() + + resp, err := sanitizingClient(nil).Get(srv.URL) + if err != nil { + t.Fatalf("get: %v", err) + } + got := readAll(t, resp.Body) + if got != body { + t.Fatalf("large body length mismatch: want %d got %d", len(body), len(got)) + } +} + +// 9) isPassThroughSSELine 单元覆盖。 +func TestIsPassThroughSSELine(t *testing.T) { + cases := []struct { + line string + want bool + }{ + {"data: {\"a\":1}\n", true}, + {"DATA: x\n", true}, + {" data: x\n", true}, + {"data:\n", true}, + {"\n", false}, + {"\r\n", false}, + {": keepalive\n", false}, + {":\n", false}, + {"event: ping\n", false}, + {"retry: 3000\n", false}, + {"id: 42\n", false}, + {"datax: y\n", false}, + {"da", false}, + } + for _, c := range cases { + if got := isPassThroughSSELine([]byte(c.line)); got != c.want { + t.Errorf("isPassThroughSSELine(%q) = %v, want %v", c.line, got, c.want) + } + } +}