package carrier

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"math/rand"
	"testing"
	"time"

	gws "github.com/gorilla/websocket"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/websocket"

	"github.com/cloudflare/cloudflared/hello"
	"github.com/cloudflare/cloudflared/tlsconfig"
	cfwebsocket "github.com/cloudflare/cloudflared/websocket"
)

// websocketClientTLSConfig returns a client TLS configuration whose root pool
// contains only the certificate returned by tlsconfig.GetHelloCertificateX509.
// The test is aborted immediately if that certificate cannot be loaded.
func websocketClientTLSConfig(t *testing.T) *tls.Config {
	t.Helper()

	helloCert, err := tlsconfig.GetHelloCertificateX509()
	// CertPool.AddCert panics on a nil certificate, so a load failure must stop
	// the test here instead of merely being recorded.
	require.NoError(t, err)

	certPool := x509.NewCertPool()
	certPool.AddCert(helloCert)
	assert.NotNil(t, certPool)
	return &tls.Config{RootCAs: certPool}
}

// TestWebsocketHeaders checks that websocketHeaders leaves every header named
// in stripWebsocketHeaders empty while still exposing the request's User-Agent.
func TestWebsocketHeaders(t *testing.T) {
	// NOTE(review): testRequest is defined elsewhere; it presumably sets the
	// "curl/7.59.0" User-Agent asserted below.
	req := testRequest(t, "http://example.com", nil)
	wsHeaders := websocketHeaders(req)
	for _, header := range stripWebsocketHeaders {
		assert.Empty(t, wsHeaders[header])
	}
	assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
}

// TestServe starts the hello-world server on an ephemeral port, opens a
// websocket connection to its /ws endpoint and verifies that 1000 random binary
// messages of 1 to 2048 bytes are each echoed back unchanged.
func TestServe(t *testing.T) {
	log := zerolog.Nop()

	// Port 0 lets the OS pick a free port, so the test cannot collide with
	// another process (or a parallel test run) already bound to a fixed port.
	listener, err := hello.CreateTLSListener("localhost:0")
	require.NoError(t, err)
	defer listener.Close()

	shutdownC := make(chan struct{})
	errC := make(chan error)
	go func() {
		errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
	}()
	// Deferred so the server goroutine is stopped and drained even when a
	// require below aborts the test early.
	defer func() {
		close(shutdownC)
		<-errC
	}()

	testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
	req := testRequest(t, testAddr, nil)

	tlsConfig := websocketClientTLSConfig(t)
	assert.NotNil(t, tlsConfig)
	d := gws.Dialer{TLSClientConfig: tlsConfig}
	conn, resp, err := clientConnect(req, &d)
	// conn and resp are unusable after a failed dial; stop instead of
	// dereferencing them.
	require.NoError(t, err)
	// Registered last so it runs first: the client closes before the server
	// is asked to shut down.
	defer func() { _ = conn.Close() }()
	assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))

	for i := 0; i < 1000; i++ {
		messageSize := rand.Int()%2048 + 1
		clientMessage := make([]byte, messageSize)
		// rand.Read always returns len(clientMessage) and a nil error
		rand.Read(clientMessage)

		// websocket.BinaryFrame comes from golang.org/x/net/websocket but is
		// passed to a gorilla connection; both libraries use the RFC 6455
		// binary opcode value for this constant.
		err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
		// A broken connection will not recover, so abort rather than keep
		// hammering it for the remaining iterations.
		require.NoError(t, err)

		messageType, message, err := conn.ReadMessage()
		require.NoError(t, err)
		assert.Equal(t, websocket.BinaryFrame, messageType)
		assert.Equal(t, clientMessage, message)
	}
}

// TestWebsocketWrapper connects to the hello-world server's /ws endpoint and
// exercises cfwebsocket.GorillaConn: a full read of an echoed message, then a
// partial read where one message is consumed across two Read calls.
func TestWebsocketWrapper(t *testing.T) {
	listener, err := hello.CreateTLSListener("localhost:0")
	require.NoError(t, err)

	serverErrorChan := make(chan error)
	helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
	// Deferred calls run in reverse order: cancel the server first, then wait
	// for its goroutine to report back.
	defer func() { <-serverErrorChan }()
	defer cancelHelloSvr()
	go func() {
		log := zerolog.Nop()
		serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
	}()

	tlsConfig := websocketClientTLSConfig(t)
	d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
	testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
	req := testRequest(t, testAddr, nil)
	conn, resp, err := clientConnect(req, &d)
	require.NoError(t, err)
	// Release the client socket before the server is cancelled.
	defer func() { _ = conn.Close() }()
	assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))

	// Websocket now connected to test server so lets check our wrapper
	wrapper := cfwebsocket.GorillaConn{Conn: conn}
	buf := make([]byte, 100)

	_, err = wrapper.Write([]byte("abc"))
	require.NoError(t, err)
	n, err := wrapper.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 3, n)
	require.Equal(t, "abc", string(buf[:n]))

	// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
	_, err = wrapper.Write([]byte("abc"))
	require.NoError(t, err)

	buf = buf[:1]
	n, err = wrapper.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 1, n)
	require.Equal(t, "a", string(buf[:n]))

	buf = buf[:cap(buf)]
	n, err = wrapper.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 2, n)
	require.Equal(t, "bc", string(buf[:n]))
}