package validation

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestValidateHostname checks that ValidateHostname accepts empty input,
// strips a scheme (plain or percent-encoded) and converts internationalized
// names to their punycode form.
func TestValidateHostname(t *testing.T) {
	testCases := []struct {
		input    string
		expected string
	}{
		// The zero-value string is accepted and yields an empty hostname.
		{"", ""},
		{"hello.example.com", "hello.example.com"},
		{"http://hello.example.com", "hello.example.com"},
		{"bücher.example.com", "xn--bcher-kva.example.com"},
		{"http://bücher.example.com", "xn--bcher-kva.example.com"},
		{"http%3A%2F%2Fhello.example.com", "hello.example.com"},
	}

	for _, tc := range testCases {
		hostname, err := ValidateHostname(tc.input)
		assert.NoError(t, err, "input %q", tc.input)
		assert.Equal(t, tc.expected, hostname, "input %q", tc.input)
	}
}

// TestValidateUrl checks the normalized form ValidateUrl produces for a range
// of inputs, and that empty input and unsupported schemes are rejected.
func TestValidateUrl(t *testing.T) {
	type testCase struct {
		input          string
		expectedOutput string
	}
	testCases := []testCase{
		// With an explicit scheme, the path is expected to be dropped.
		{"http://localhost", "http://localhost"},
		{"http://localhost/", "http://localhost"},
		{"http://localhost/api", "http://localhost"},
		{"http://localhost/api/", "http://localhost"},
		{"https://localhost", "https://localhost"},
		{"https://localhost/", "https://localhost"},
		{"https://localhost/api", "https://localhost"},
		{"https://localhost/api/", "https://localhost"},
		{"https://localhost:8080", "https://localhost:8080"},
		{"https://localhost:8080/", "https://localhost:8080"},
		{"https://localhost:8080/api", "https://localhost:8080"},
		{"https://localhost:8080/api/", "https://localhost:8080"},
		// Without a scheme, http is expected and the path is retained.
		{"localhost", "http://localhost"},
		{"localhost/", "http://localhost/"},
		{"localhost/api", "http://localhost/api"},
		{"localhost/api/", "http://localhost/api/"},
		{"localhost:8080", "http://localhost:8080"},
		{"localhost:8080/", "http://localhost:8080/"},
		{"localhost:8080/api", "http://localhost:8080/api"},
		{"localhost:8080/api/", "http://localhost:8080/api/"},
		{"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
		// IPv4 and IPv6 literals.
		{"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
		{"127.0.0.1:8080", "http://127.0.0.1:8080"},
		{"127.0.0.1", "http://127.0.0.1"},
		{"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
		{"[::1]:8080", "http://[::1]:8080"},
		{"http://[::1]", "http://[::1]"},
		{"http://[::1]:8080", "http://[::1]:8080"},
		{"[::1]", "http://[::1]"},
		// Regular and internationalized domain names.
		{"https://example.com", "https://example.com"},
		{"example.com", "http://example.com"},
		{"http://hello.example.com", "http://hello.example.com"},
		{"hello.example.com", "http://hello.example.com"},
		{"hello.example.com:8080", "http://hello.example.com:8080"},
		{"https://hello.example.com:8080", "https://hello.example.com:8080"},
		{"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
		{"bücher.example.com", "http://xn--bcher-kva.example.com"},
		// Percent-encoded input and embedded credentials.
		{"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
		{"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
	}

	for i, tc := range testCases {
		validURL, err := ValidateUrl(tc.input)
		// Skip the value check on error so a failing case is reported as a
		// failure instead of calling String on an unusable result.
		if !assert.NoError(t, err, "test case %v", i) {
			continue
		}
		assert.Equal(t, tc.expectedOutput, validURL.String(), "test case %v", i)
	}

	validURL, err := ValidateUrl("")
	assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
	assert.Empty(t, validURL)

	validURL, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
	// EqualError fails cleanly on a nil error, whereas err.Error() would panic.
	assert.EqualError(t, err, "Currently Cloudflare Tunnel does not support ftp protocol.")
	assert.Empty(t, validURL)
}

// TestNewAccessValidatorOk checks that a validator can be built for a valid
// domain and that it rejects empty, malformed and unverifiable tokens.
func TestNewAccessValidatorOk(t *testing.T) {
	ctx := context.Background()
	domain := "test.cloudflareaccess.com"

	access, err := NewAccessValidator(ctx, domain, domain, "")
	assert.NoError(t, err)
	// Stop here if construction failed; the calls below would dereference nil.
	if !assert.NotNil(t, access) {
		return
	}

	assert.Error(t, access.Validate(ctx, ""))
	assert.Error(t, access.Validate(ctx, "invalid"))

	req := httptest.NewRequest("GET", "https://test.cloudflareaccess.com", nil)
	req.Header.Set(accessJwtHeader, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
	assert.Error(t, access.ValidateRequest(ctx, req))
}

// TestNewAccessValidatorErr checks that NewAccessValidator returns an error
// and a nil validator for empty input and for unsupported schemes.
func TestNewAccessValidatorErr(t *testing.T) {
	ctx := context.Background()

	rawURLs := []string{
		"",
		"ftp://test.cloudflareaccess.com",
		"wss://cloudflarenone.com",
	}

	for _, rawURL := range rawURLs {
		access, err := NewAccessValidator(ctx, rawURL, rawURL, "")
		assert.Error(t, err, rawURL)
		assert.Nil(t, access)
	}
}

// testRoundTripper adapts a plain function to the http.RoundTripper interface.
type testRoundTripper func(req *http.Request) (*http.Response, error)

// RoundTrip calls the underlying function with req.
func (f testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}

// emptyResponse returns a response with the given status code, an empty body
// and an initialized, empty header map.
func emptyResponse(statusCode int) *http.Response {
	return &http.Response{
		StatusCode: statusCode,
		Body:       io.NopCloser(bytes.NewReader(nil)),
		Header:     make(http.Header),
	}
}

// createMockServerAndClient starts a plain HTTP test server for handler and
// returns it with a client that proxies every request through that server.
// The caller is responsible for closing the returned server.
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewServer(handler)

	// Parse once up front so a bad URL is reported here rather than on every
	// request.
	proxyURL, err := url.Parse(server.URL)
	if err != nil {
		server.Close()
		return nil, nil, err
	}

	// Use a dedicated client: assigning to http.DefaultClient.Transport would
	// change behavior for every other user of the default client.
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
		},
	}

	return server, client, nil
}

// createSecureMockServerAndClient starts a TLS test server for handler and
// returns it with a client that trusts the server's certificate and dials the
// server no matter which address a request names. The caller is responsible
// for closing the returned server.
func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewTLSServer(handler)

	cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0])
	if err != nil {
		server.Close()
		return nil, nil, err
	}

	certpool := x509.NewCertPool()
	certpool.AddCert(cert)

	// Everything after the last "/" of the server URL is its host:port.
	serverAddr := server.URL[strings.LastIndex(server.URL, "/")+1:]
	dialer := &net.Dialer{}

	// Use a dedicated client rather than mutating http.DefaultClient.
	client := &http.Client{
		Transport: &http.Transport{
			// Ignore the requested address and always connect to the mock
			// server, while still honoring cancellation of ctx.
			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, "tcp", serverAddr)
			},
			TLSClientConfig: &tls.Config{
				RootCAs: certpool,
			},
		},
	}

	return server, client, nil
}