Add files via upload

This commit is contained in:
公明
2026-04-29 22:38:14 +08:00
committed by GitHub
parent f6bb455313
commit 2558be3d7d
3 changed files with 468 additions and 8 deletions
+16 -8
View File
@@ -752,25 +752,33 @@ func isClaudeProvider(cfg *config.OpenAIConfig) bool {
// Eino HTTP Client Bridge
// ============================================================
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个支持 Claude 自动桥接的 http.Client。
// 当 cfg.Provider 为 claude 时,会拦截 /chat/completions 请求,透明转换为 Anthropic Messages API。
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含两层 transport 包装:
// 1. 当 cfg.Provider 为 claude 时,最内层套 claudeRoundTripper,把 OpenAI /chat/completions 透明
// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。
// 2. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai
// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。
//
// 两层都对调用方完全透明:普通 JSON 响应原样透传,仅当响应 Content-Type 为 text/event-stream 时
// sanitizer 才会接管 bodydata: payload (含 [DONE]、{"error":...}) 一字节不改。
func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client {
if base == nil {
base = http.DefaultClient
}
if !isClaudeProvider(cfg) {
return base
}
cloned := *base
transport := base.Transport
if transport == nil {
transport = http.DefaultTransport
}
cloned.Transport = &claudeRoundTripper{
base: transport,
config: cfg,
if isClaudeProvider(cfg) {
transport = &claudeRoundTripper{
base: transport,
config: cfg,
}
}
transport = &einoSSESanitizingRoundTripper{base: transport}
cloned.Transport = transport
return &cloned
}
+149
View File
@@ -0,0 +1,149 @@
package openai
// eino_sse_sanitizer.go 解决 Eino 走 meguminnnnnnnnn/go-openai SDK 时,
// 中转站心跳/SSE 控制行累计 > 300 行触发 ErrTooManyEmptyStreamMessages
// (报错文案: "stream has sent too many empty messages")的问题。
//
// 触发链路:
// einoopenai.NewChatModel
// → eino-ext/libs/acl/openai → meguminnnnnnnnn/go-openai
// → streamReader.processLines() 对所有非 "data:" 行计数, > 300 即抛错。
//
// 中转站常见的非 data: 行(合法 SSE 但 SDK 不接受):
// ":" / ": keepalive" / ": ping" / "event: ping" / "retry: 3000"
// 以及思考型模型 prefill 期间穿插的大量心跳。
//
// 兜底策略: 在 HTTP transport 层把响应 Body 包一层 reader, 只放行 "data:"
// 开头的行, 把心跳/注释/事件类型行就地吞掉。下游 SDK 永远见不到非 data: 行,
// 计数器始终为 0, 该错误不可能再发生。
//
// 该层对调用方完全透明:
// - 仅当响应 Content-Type 是 text/event-stream 时介入;普通 JSON 响应原样透传
// - data: payload (含 [DONE] 与 {"error":...}) 一字节不改
// - 上游真断流 (EOF / connection reset / context cancel) 原样透传
import (
"bufio"
"bytes"
"io"
"net/http"
"strings"
)
const (
// einoSSEReaderBufSize 给 bufio 一个较大的初始缓冲, 避免单行大 JSON chunk
// (含工具调用 arguments / reasoning_content) 频繁触发缓冲区扩容。
einoSSEReaderBufSize = 64 * 1024
)
// einoSSESanitizingRoundTripper 包装下游 RoundTripper, 对 SSE 响应做行级清洗。
type einoSSESanitizingRoundTripper struct {
base http.RoundTripper
}
func (rt *einoSSESanitizingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := rt.base.RoundTrip(req)
if err != nil || resp == nil {
return resp, err
}
if !isSSEResponse(resp) {
return resp, nil
}
resp.Body = newEinoSSESanitizingBody(resp.Body)
return resp, nil
}
// isSSEResponse 仅对 200 + text/event-stream 的响应做清洗;
// 错误响应 (4xx/5xx 通常是 application/json) 不动, 由 SDK 走原错误路径。
func isSSEResponse(resp *http.Response) bool {
if resp.StatusCode != http.StatusOK {
return false
}
ct := resp.Header.Get("Content-Type")
if ct == "" {
return false
}
ct = strings.ToLower(strings.TrimSpace(ct))
// 兼容 "text/event-stream", "text/event-stream; charset=utf-8" 等。
return strings.HasPrefix(ct, "text/event-stream")
}
// einoSSESanitizingBody 是包装后的响应体: 只放行 data: 行, 其它行吞掉。
type einoSSESanitizingBody struct {
upstream io.ReadCloser
reader *bufio.Reader
pending []byte // 已清洗、待返回给下游的字节 (永远以 \n 结尾的完整 data: 行)
err error // upstream 终态错误 (io.EOF 或网络错误)
}
func newEinoSSESanitizingBody(body io.ReadCloser) *einoSSESanitizingBody {
return &einoSSESanitizingBody{
upstream: body,
reader: bufio.NewReaderSize(body, einoSSEReaderBufSize),
}
}
func (b *einoSSESanitizingBody) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
if len(b.pending) > 0 {
n := copy(p, b.pending)
b.pending = b.pending[n:]
return n, nil
}
// 从上游读, 直到攒出一行 data: 或拿到终态。
// 单次循环可能丢弃任意多行心跳, 但只放行至多一行 data: 后退出,
// 避免一次 Read 阻塞过久 / pending 缓冲过大。
for b.err == nil {
line, err := b.reader.ReadBytes('\n')
if len(line) > 0 {
if isPassThroughSSELine(line) {
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
b.pending = line
if err != nil {
b.err = err
}
break
}
// 非 data: 行 (空行 / ":" 注释 / event: / retry: / id: / 任何裸文本)
// 全部吞掉, 不向下游透出, 继续循环读下一行。
}
if err != nil {
b.err = err
break
}
}
if len(b.pending) > 0 {
n := copy(p, b.pending)
b.pending = b.pending[n:]
return n, nil
}
return 0, b.err
}
func (b *einoSSESanitizingBody) Close() error {
return b.upstream.Close()
}
// isPassThroughSSELine 判定该行是否需要原样放行给下游 SDK。
// 仅 "data:" (大小写不敏感, 可有任意前导空白) 开头的行需要保留。
// 注意: 不能用 TrimSpace 去尾部换行后再判, 否则 " data: x" 会被误判;
// 我们只 trim 前导空白, 与 SDK 内部 TrimSpace 后再正则 ^data:\s* 的语义一致。
func isPassThroughSSELine(line []byte) bool {
trimmed := bytes.TrimLeft(line, " \t")
if len(trimmed) < 5 {
return false
}
// 大小写不敏感比较前 5 字节是否为 "data:"。SSE 规范要求字段名小写,
// 但宽松匹配可以兼容个别中转站的非规范实现。
return (trimmed[0] == 'd' || trimmed[0] == 'D') &&
(trimmed[1] == 'a' || trimmed[1] == 'A') &&
(trimmed[2] == 't' || trimmed[2] == 'T') &&
(trimmed[3] == 'a' || trimmed[3] == 'A') &&
trimmed[4] == ':'
}
+303
View File
@@ -0,0 +1,303 @@
package openai
import (
"bufio"
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
)
// 复现 meguminnnnnnnnn/go-openai 的 SSE 行计数算法 (默认 limit=300):
// - 逐行读
// - 非 "data:" 行 (空行 / ":" 注释 / event: / retry:) 累计 emptyMessagesCount
// - > 300 抛 ErrTooManyEmptyStreamMessages
// - 遇到 data: 行 reset, 返回 payload
//
// 这一算法与上游 SDK 的 stream_reader.go processLines() 严格一致 (验证依据见
// /Users/temp/go/pkg/mod/github.com/meguminnnnnnnnn/go-openai@v0.1.2/stream_reader.go)。
// 测试中只复刻 "限制触发" 这一行为, 用来回归验证 sanitizer 的根因修复。
var errTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
func sdkLikeRecvAll(body io.Reader, limit uint) ([]string, error) {
headerData := regexp.MustCompile(`^data:\s*`)
r := bufio.NewReader(body)
var payloads []string
for {
var emptyMessagesCount uint
var payload []byte
for {
line, err := r.ReadBytes('\n')
if err != nil {
if err == io.EOF {
return payloads, nil
}
return payloads, err
}
noSpace := bytes.TrimSpace(line)
if !headerData.Match(noSpace) {
emptyMessagesCount++
if emptyMessagesCount > limit {
return payloads, errTooManyEmptyStreamMessages
}
continue
}
payload = headerData.ReplaceAll(noSpace, nil)
break
}
if string(payload) == "[DONE]" {
return payloads, nil
}
payloads = append(payloads, string(payload))
}
}
func newSSEServer(t *testing.T, body string, contentType string, status int) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
if contentType != "" {
w.Header().Set("Content-Type", contentType)
}
w.WriteHeader(status)
_, _ = io.WriteString(w, body)
}))
}
func sanitizingClient(base *http.Client) *http.Client {
if base == nil {
base = &http.Client{}
}
cloned := *base
transport := base.Transport
if transport == nil {
transport = http.DefaultTransport
}
cloned.Transport = &einoSSESanitizingRoundTripper{base: transport}
return &cloned
}
func readAll(t *testing.T, body io.ReadCloser) string {
t.Helper()
defer body.Close()
out, err := io.ReadAll(body)
if err != nil {
t.Fatalf("read body: %v", err)
}
return string(out)
}
// 1) 仅 data: 行 → 一字节不改地透传。
func TestSSESanitizer_PassesDataLinesUnchanged(t *testing.T) {
body := "data: {\"a\":1}\ndata: {\"b\":2}\ndata: [DONE]\n"
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("body mismatch:\nwant %q\ngot %q", body, got)
}
}
// 2) 心跳/注释/事件类型行被吞掉, 仅保留 data: 行。
func TestSSESanitizer_DropsHeartbeatsAndControlLines(t *testing.T) {
body := strings.Join([]string{
": keepalive",
"",
"event: ping",
"retry: 3000",
"id: 42",
"data: {\"x\":1}",
": ping",
"",
"data: {\"x\":2}",
"data: [DONE]",
"",
}, "\n")
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
want := "data: {\"x\":1}\ndata: {\"x\":2}\ndata: [DONE]\n"
if got != want {
t.Fatalf("sanitized body mismatch:\nwant %q\ngot %q", want, got)
}
}
// 3) 根因回归: 上游堆 500 行心跳后才发 data:, 原始 SDK 算法会抛
// ErrTooManyEmptyStreamMessages, sanitize 之后必须能正常拿到所有 data:。
func TestSSESanitizer_ProtectsAgainstTooManyEmptyMessages(t *testing.T) {
const heartbeats = 500
var buf bytes.Buffer
for i := 0; i < heartbeats; i++ {
buf.WriteString(": keepalive\n")
}
buf.WriteString("data: {\"chunk\":1}\n")
buf.WriteString("data: {\"chunk\":2}\n")
buf.WriteString("data: [DONE]\n")
t.Run("baseline_without_sanitizer_must_fail", func(t *testing.T) {
_, err := sdkLikeRecvAll(bytes.NewReader(buf.Bytes()), 300)
if !errors.Is(err, errTooManyEmptyStreamMessages) {
t.Fatalf("expected ErrTooManyEmptyStreamMessages, got %v", err)
}
})
t.Run("with_sanitizer_must_succeed", func(t *testing.T) {
srv := newSSEServer(t, buf.String(), "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
defer resp.Body.Close()
payloads, err := sdkLikeRecvAll(resp.Body, 300)
if err != nil {
t.Fatalf("sdk-like recv after sanitize: %v", err)
}
want := []string{`{"chunk":1}`, `{"chunk":2}`}
if len(payloads) != len(want) {
t.Fatalf("payload count mismatch: want %d got %d (%v)", len(want), len(payloads), payloads)
}
for i, w := range want {
if payloads[i] != w {
t.Fatalf("payload[%d] mismatch: want %q got %q", i, w, payloads[i])
}
}
})
}
// 4) 心跳穿插在 data: 之间也能正确清洗 (思考型模型 prefill 期间常见)。
func TestSSESanitizer_HeartbeatsInterleavedWithData(t *testing.T) {
var buf bytes.Buffer
buf.WriteString("data: {\"chunk\":1}\n")
for i := 0; i < 400; i++ {
buf.WriteString(": keepalive\n")
}
buf.WriteString("data: {\"chunk\":2}\n")
buf.WriteString("data: [DONE]\n")
srv := newSSEServer(t, buf.String(), "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
defer resp.Body.Close()
payloads, err := sdkLikeRecvAll(resp.Body, 300)
if err != nil {
t.Fatalf("sdk-like recv: %v", err)
}
if got, want := len(payloads), 2; got != want {
t.Fatalf("payload count: want %d got %d", want, got)
}
}
// 5) 非 SSE 响应 (例如非流式 JSON) 不应被 sanitizer 介入。
func TestSSESanitizer_PassesNonSSEResponseUntouched(t *testing.T) {
body := `{"id":"x","object":"chat.completion","choices":[]}`
srv := newSSEServer(t, body, "application/json", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("non-SSE body must be untouched:\nwant %q\ngot %q", body, got)
}
}
// 6) 错误响应 (4xx/5xx) 不应被 sanitize, 即使 Content-Type 是 SSE 也不动,
// 避免吞掉类似 "data: " 之外的错误正文。
func TestSSESanitizer_PassesNon200Untouched(t *testing.T) {
body := `{"error":{"message":"rate limit"}}`
srv := newSSEServer(t, body, "text/event-stream", 429)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("error body must be untouched:\nwant %q\ngot %q", body, got)
}
}
// 7) data: 行末尾若缺 \n (异常上游) sanitizer 也补齐, 保证下游按行解析。
func TestSSESanitizer_AppendsTrailingNewlineIfMissing(t *testing.T) {
body := "data: {\"a\":1}"
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
want := "data: {\"a\":1}\n"
if got != want {
t.Fatalf("trailing newline:\nwant %q\ngot %q", want, got)
}
}
// 8) 大 chunk (一行数十 KB) 也能完整透传, 不被切断。
func TestSSESanitizer_LargeDataLinePassesIntact(t *testing.T) {
huge := strings.Repeat("x", 80*1024)
body := "data: {\"big\":\"" + huge + "\"}\ndata: [DONE]\n"
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("large body length mismatch: want %d got %d", len(body), len(got))
}
}
// 9) isPassThroughSSELine 单元覆盖。
func TestIsPassThroughSSELine(t *testing.T) {
cases := []struct {
line string
want bool
}{
{"data: {\"a\":1}\n", true},
{"DATA: x\n", true},
{" data: x\n", true},
{"data:\n", true},
{"\n", false},
{"\r\n", false},
{": keepalive\n", false},
{":\n", false},
{"event: ping\n", false},
{"retry: 3000\n", false},
{"id: 42\n", false},
{"datax: y\n", false},
{"da", false},
}
for _, c := range cases {
if got := isPassThroughSSELine([]byte(c.line)); got != c.want {
t.Errorf("isPassThroughSSELine(%q) = %v, want %v", c.line, got, c.want)
}
}
}