package openai import ( "bufio" "bytes" "errors" "io" "net/http" "net/http/httptest" "regexp" "strings" "testing" ) // 复现 meguminnnnnnnnn/go-openai 的 SSE 行计数算法 (默认 limit=300): // - 逐行读 // - 非 "data:" 行 (空行 / ":" 注释 / event: / retry:) 累计 emptyMessagesCount // - > 300 抛 ErrTooManyEmptyStreamMessages // - 遇到 data: 行 reset, 返回 payload // // 这一算法与上游 SDK 的 stream_reader.go processLines() 严格一致 (验证依据见 // /Users/temp/go/pkg/mod/github.com/meguminnnnnnnnn/go-openai@v0.1.2/stream_reader.go)。 // 测试中只复刻 "限制触发" 这一行为, 用来回归验证 sanitizer 的根因修复。 var errTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages") func sdkLikeRecvAll(body io.Reader, limit uint) ([]string, error) { headerData := regexp.MustCompile(`^data:\s*`) r := bufio.NewReader(body) var payloads []string for { var emptyMessagesCount uint var payload []byte for { line, err := r.ReadBytes('\n') if err != nil { if err == io.EOF { return payloads, nil } return payloads, err } noSpace := bytes.TrimSpace(line) if !headerData.Match(noSpace) { emptyMessagesCount++ if emptyMessagesCount > limit { return payloads, errTooManyEmptyStreamMessages } continue } payload = headerData.ReplaceAll(noSpace, nil) break } if string(payload) == "[DONE]" { return payloads, nil } payloads = append(payloads, string(payload)) } } func newSSEServer(t *testing.T, body string, contentType string, status int) *httptest.Server { t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if contentType != "" { w.Header().Set("Content-Type", contentType) } w.WriteHeader(status) _, _ = io.WriteString(w, body) })) } func sanitizingClient(base *http.Client) *http.Client { if base == nil { base = &http.Client{} } cloned := *base transport := base.Transport if transport == nil { transport = http.DefaultTransport } cloned.Transport = &einoSSESanitizingRoundTripper{base: transport} return &cloned } func readAll(t *testing.T, body io.ReadCloser) string { t.Helper() defer body.Close() out, err := io.ReadAll(body) if err != nil { t.Fatalf("read body: %v", err) } return string(out) } // 1) 仅 data: 行 → 一字节不改地透传。 func TestSSESanitizer_PassesDataLinesUnchanged(t *testing.T) { body := "data: {\"a\":1}\ndata: {\"b\":2}\ndata: [DONE]\n" srv := newSSEServer(t, body, "text/event-stream", 200) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } got := readAll(t, resp.Body) if got != body { t.Fatalf("body mismatch:\nwant %q\ngot %q", body, got) } } // 2) 心跳/注释/事件类型行被吞掉, 仅保留 data: 行。 func TestSSESanitizer_DropsHeartbeatsAndControlLines(t *testing.T) { body := strings.Join([]string{ ": keepalive", "", "event: ping", "retry: 3000", "id: 42", "data: {\"x\":1}", ": ping", "", "data: {\"x\":2}", "data: [DONE]", "", }, "\n") srv := newSSEServer(t, body, "text/event-stream", 200) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } got := readAll(t, resp.Body) want := "data: {\"x\":1}\ndata: {\"x\":2}\ndata: [DONE]\n" if got != want { t.Fatalf("sanitized body mismatch:\nwant %q\ngot %q", want, got) } } // 3) 根因回归: 上游堆 500 行心跳后才发 data:, 原始 SDK 算法会抛 // ErrTooManyEmptyStreamMessages, sanitize 之后必须能正常拿到所有 data:。 func TestSSESanitizer_ProtectsAgainstTooManyEmptyMessages(t *testing.T) { const heartbeats = 500 var buf bytes.Buffer for i := 0; i < heartbeats; i++ { buf.WriteString(": keepalive\n") } buf.WriteString("data: {\"chunk\":1}\n") buf.WriteString("data: {\"chunk\":2}\n") buf.WriteString("data: [DONE]\n") t.Run("baseline_without_sanitizer_must_fail", func(t *testing.T) { _, err := sdkLikeRecvAll(bytes.NewReader(buf.Bytes()), 300) if !errors.Is(err, errTooManyEmptyStreamMessages) { t.Fatalf("expected ErrTooManyEmptyStreamMessages, got %v", err) } }) t.Run("with_sanitizer_must_succeed", func(t *testing.T) { srv := newSSEServer(t, buf.String(), "text/event-stream", 200) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } defer resp.Body.Close() payloads, err := sdkLikeRecvAll(resp.Body, 300) if err != nil { t.Fatalf("sdk-like recv after sanitize: %v", err) } want := []string{`{"chunk":1}`, `{"chunk":2}`} if len(payloads) != len(want) { t.Fatalf("payload count mismatch: want %d got %d (%v)", len(want), len(payloads), payloads) } for i, w := range want { if payloads[i] != w { t.Fatalf("payload[%d] mismatch: want %q got %q", i, w, payloads[i]) } } }) } // 4) 心跳穿插在 data: 之间也能正确清洗 (思考型模型 prefill 期间常见)。 func TestSSESanitizer_HeartbeatsInterleavedWithData(t *testing.T) { var buf bytes.Buffer buf.WriteString("data: {\"chunk\":1}\n") for i := 0; i < 400; i++ { buf.WriteString(": keepalive\n") } buf.WriteString("data: {\"chunk\":2}\n") buf.WriteString("data: [DONE]\n") srv := newSSEServer(t, buf.String(), "text/event-stream", 200) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } defer resp.Body.Close() payloads, err := sdkLikeRecvAll(resp.Body, 300) if err != nil { t.Fatalf("sdk-like recv: %v", err) } if got, want := len(payloads), 2; got != want { t.Fatalf("payload count: want %d got %d", want, got) } } // 5) 非 SSE 响应 (例如非流式 JSON) 不应被 sanitizer 介入。 func TestSSESanitizer_PassesNonSSEResponseUntouched(t *testing.T) { body := `{"id":"x","object":"chat.completion","choices":[]}` srv := newSSEServer(t, body, "application/json", 200) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } got := readAll(t, resp.Body) if got != body { t.Fatalf("non-SSE body must be untouched:\nwant %q\ngot %q", body, got) } } // 6) 错误响应 (4xx/5xx) 不应被 sanitize, 即使 Content-Type 是 SSE 也不动, // 避免吞掉类似 "data: " 之外的错误正文。 func TestSSESanitizer_PassesNon200Untouched(t *testing.T) { body := `{"error":{"message":"rate limit"}}` srv := newSSEServer(t, body, "text/event-stream", 429) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } got := readAll(t, resp.Body) if got != body { t.Fatalf("error body must be untouched:\nwant %q\ngot %q", body, got) } } // 7) data: 行末尾若缺 \n (异常上游) sanitizer 也补齐, 保证下游按行解析。 func TestSSESanitizer_AppendsTrailingNewlineIfMissing(t *testing.T) { body := "data: {\"a\":1}" srv := newSSEServer(t, body, "text/event-stream", 200) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } got := readAll(t, resp.Body) want := "data: {\"a\":1}\n" if got != want { t.Fatalf("trailing newline:\nwant %q\ngot %q", want, got) } } // 8) 大 chunk (一行数十 KB) 也能完整透传, 不被切断。 func TestSSESanitizer_LargeDataLinePassesIntact(t *testing.T) { huge := strings.Repeat("x", 80*1024) body := "data: {\"big\":\"" + huge + "\"}\ndata: [DONE]\n" srv := newSSEServer(t, body, "text/event-stream", 200) defer srv.Close() resp, err := sanitizingClient(nil).Get(srv.URL) if err != nil { t.Fatalf("get: %v", err) } got := readAll(t, resp.Body) if got != body { t.Fatalf("large body length mismatch: want %d got %d", len(body), len(got)) } } // 9) isPassThroughSSELine 单元覆盖。 func TestIsPassThroughSSELine(t *testing.T) { cases := []struct { line string want bool }{ {"data: {\"a\":1}\n", true}, {"DATA: x\n", true}, {" data: x\n", true}, {"data:\n", true}, {"\n", false}, {"\r\n", false}, {": keepalive\n", false}, {":\n", false}, {"event: ping\n", false}, {"retry: 3000\n", false}, {"id: 42\n", false}, {"datax: y\n", false}, {"da", false}, } for _, c := range cases { if got := isPassThroughSSELine([]byte(c.line)); got != c.want { t.Errorf("isPassThroughSSELine(%q) = %v, want %v", c.line, got, c.want) } } }