package ingress

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/cloudflare/cloudflared/carrier"
	"github.com/cloudflare/cloudflared/config"
	"github.com/cloudflare/cloudflared/websocket"
)

// TestRawTCPServiceEstablishConnection checks that rawTCPService reports an
// error when the origin it is asked to dial is no longer listening.
func TestRawTCPServiceEstablishConnection(t *testing.T) {
	originListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	listenerClosed := make(chan struct{})
	tcpListenRoutine(originListener, listenerClosed)

	service := &rawTCPService{name: ServiceWarpRouting}

	// Shut the origin down and wait for the accept loop to observe it, so the
	// dial below targets an address nobody is listening on.
	originListener.Close()
	<-listenerClosed

	// The listener still reports the address it was bound to after Close.
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
	require.NoError(t, err)

	// Origin not listening for new connection, should return an error
	_, err = service.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
	require.Error(t, err)
}

// TestTCPOverWSServiceEstablishConnection checks that tcpOverWSService (both
// the fixed-destination and bastion variants) returns an error for invalid
// bastion requests and when the origin is no longer listening.
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
	originListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	listenerClosed := make(chan struct{})
	tcpListenRoutine(originListener, listenerClosed)

	originURL := &url.URL{
		Scheme: "tcp",
		Host:   originListener.Addr().String(),
	}

	baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
	require.NoError(t, err)
	baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")

	// bastionReq is baseReq plus the bastion destination header pointing at the
	// local origin listener.
	bastionReq := baseReq.Clone(context.Background())
	carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())

	tests := []struct {
		testCase    string
		service     *tcpOverWSService
		req         *http.Request
		bastionMode bool
		expectErr   bool
	}{
		{
			testCase: "specific TCP service",
			service:  newTCPOverWSService(originURL),
			req:      baseReq,
		},
		{
			testCase:    "bastion service",
			service:     newBastionService(),
			bastionMode: true,
			req:         bastionReq,
		},
		{
			testCase:    "invalid bastion request",
			service:     newBastionService(),
			bastionMode: true,
			req:         baseReq,
			expectErr:   true,
		},
		{
			testCase:    "bastion service",
			service:     newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
			req:         bastionReq,
			bastionMode: true,
		},
		{
			testCase:    "bastion service",
			service:     newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
			req:         baseReq,
			bastionMode: true,
			expectErr:   true,
		},
	}

	for _, test := range tests {
		t.Run(test.testCase, func(t *testing.T) {
			// NOTE(review): only the expectErr cases exercise EstablishConnection;
			// the success cases above are declared but assert nothing.
			if test.expectErr {
				// The resolution error is deliberately dropped: the assertion is on
				// EstablishConnection (presumably fed an unusable destination here —
				// confirm against carrier.ResolveBastionDest).
				bastionHost, _ := carrier.ResolveBastionDest(test.req, test.bastionMode, config.BastionFlag)
				_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
				assert.Error(t, err)
			}
		})
	}

	// Shut the origin down and wait for the accept loop to observe it.
	originListener.Close()
	<-listenerClosed

	for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
		// Origin not listening for new connection, should return an error
		// The resolution error is deliberately dropped; only the dial result matters here.
		bastionHost, _ := carrier.ResolveBastionDest(bastionReq, false, "bastion")
		_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
		assert.Error(t, err)
	}
}

// TestHTTPServiceHostHeaderOverride checks that the origin sees the Host
// header configured through OriginRequestConfig.HTTPHostHeader, and that the
// X-Forwarded-Host header it receives carries the origin URL's host.
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
	cfg := OriginRequestConfig{
		HTTPHostHeader: t.Name(),
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, so use assert rather than
		// require: FailNow may only be called from the test goroutine.
		assert.Equal(t, t.Name(), r.Host)
		if websocket.IsWebSocketUpgrade(r) {
			respHeaders := websocket.NewResponseHeader(r)
			for k, v := range respHeaders {
				w.Header().Set(k, v[0])
			}
			w.WriteHeader(http.StatusSwitchingProtocols)
			return
		}
		// return the X-Forwarded-Host header for assertions
		// as the httptest Server URL isn't available here yet
		_, _ = w.Write([]byte(r.Header.Get("X-Forwarded-Host")))
	}
	origin := httptest.NewServer(http.HandlerFunc(handler))
	defer origin.Close()

	originURL, err := url.Parse(origin.URL)
	require.NoError(t, err)

	httpService := &httpService{
		url: originURL,
	}
	shutdownC := make(chan struct{})
	require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))

	req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
	require.NoError(t, err)

	resp, err := httpService.RoundTrip(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)

	respBody, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, originURL.Host, string(respBody))
}

// TestHTTPServiceUsesIngressRuleScheme makes sure httpService uses scheme defined in ingress rule and not by eyeball request
func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, so use assert rather than
		// require: FailNow may only be called from the test goroutine.
		assert.NotNil(t, r.TLS)
		// Echo the X-Forwarded-Proto header for assertions
		_, _ = w.Write([]byte(r.Header.Get("X-Forwarded-Proto")))
	}
	origin := httptest.NewTLSServer(http.HandlerFunc(handler))
	defer origin.Close()

	originURL, err := url.Parse(origin.URL)
	require.NoError(t, err)
	require.Equal(t, "https", originURL.Scheme)

	// The httptest TLS server uses a certificate the default pool does not
	// trust, so verification is disabled for this origin.
	cfg := OriginRequestConfig{
		NoTLSVerify: true,
	}
	httpService := &httpService{
		url: originURL,
	}
	shutdownC := make(chan struct{})
	require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))

	// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
	protos := []string{"https", "http", "dne"}
	for _, p := range protos {
		// One subtest per proto so each response body is closed by a defer as
		// soon as that iteration finishes.
		t.Run(p, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
			require.NoError(t, err)
			req.Header.Add("X-Forwarded-Proto", p)

			resp, err := httpService.RoundTrip(req)
			require.NoError(t, err)
			defer resp.Body.Close()
			require.Equal(t, http.StatusOK, resp.StatusCode)

			respBody, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, p, string(respBody))
		})
	}
}

// tcpListenRoutine starts a goroutine that accepts connections on listener and
// closes each one immediately. When Accept returns an error (as the tests
// trigger by closing the listener), it closes closeChan and exits, letting the
// caller wait for the accept loop to stop.
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				close(closeChan)
				return
			}
			// Close immediately, this test is not about testing read/write on connection
			conn.Close()
		}
	}()
}