diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 5407acaa..384298c3 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -130,6 +130,10 @@ func (w *mockSSERespWriter) Write(data []byte) (int, error) { return len(data), nil } +func (w *mockSSERespWriter) WriteString(str string) (int, error) { + return w.Write([]byte(str)) +} + func (w *mockSSERespWriter) ReadBytes() []byte { return <-w.writeNotification } @@ -156,7 +160,6 @@ func TestProxySingleOrigin(t *testing.T) { t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) - t.Run("testProxySSEAllData", testProxySSEAllData(proxy)) cancel() } @@ -276,17 +279,15 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { // Regression test to guarantee that we always write the contents downstream even if EOF is reached without // hitting the delimiter -func testProxySSEAllData(proxy *Proxy) func(t *testing.T) { - return func(t *testing.T) { - eyeballReader := io.NopCloser(strings.NewReader("data\r\r")) - responseWriter := newMockSSERespWriter() +func TestProxySSEAllData(t *testing.T) { + eyeballReader := io.NopCloser(strings.NewReader("data\r\r")) + responseWriter := newMockSSERespWriter() - // responseWriter uses an unbuffered channel, so we call in a different go-routine - go cfio.Copy(responseWriter, eyeballReader) + // responseWriter uses an unbuffered channel, so we call in a different go-routine + go cfio.Copy(responseWriter, eyeballReader) - result := string(<-responseWriter.writeNotification) - require.Equal(t, "data\r\r", result) - } + result := string(<-responseWriter.writeNotification) + require.Equal(t, "data\r\r", result) } func TestProxyMultipleOrigins(t *testing.T) {