// Source: cloudflared-mirror/validation/validation_test.go (374 lines, 12 KiB, Go)

package validation
import (
"bytes"
"fmt"
"io/ioutil"
"testing"
"time"
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"github.com/stretchr/testify/assert"
)
func TestValidateHostname(t *testing.T) {
var inputHostname string
hostname, err := ValidateHostname(inputHostname)
assert.Equal(t, err, nil)
assert.Empty(t, hostname)
inputHostname = "hello.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "hello.example.com", hostname)
inputHostname = "http://hello.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "hello.example.com", hostname)
inputHostname = "bücher.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
inputHostname = "http://bücher.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
inputHostname = "http%3A%2F%2Fhello.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "hello.example.com", hostname)
}
// TestValidateUrl checks that ValidateUrl normalizes a wide range of origin
// inputs. It defaults the scheme to http and drops path and userinfo from
// absolute URLs. It punycode-encodes IDN hosts and decodes percent-encoded
// schemes. It also rejects an empty input and an unsupported (ftp) scheme.
func TestValidateUrl(t *testing.T) {
	type testCase struct {
		input          string
		expectedOutput string
	}
	testCases := []testCase{
		{"http://localhost", "http://localhost"},
		{"http://localhost/", "http://localhost"},
		{"http://localhost/api", "http://localhost"},
		{"http://localhost/api/", "http://localhost"},
		{"https://localhost", "https://localhost"},
		{"https://localhost/", "https://localhost"},
		{"https://localhost/api", "https://localhost"},
		{"https://localhost/api/", "https://localhost"},
		{"https://localhost:8080", "https://localhost:8080"},
		{"https://localhost:8080/", "https://localhost:8080"},
		{"https://localhost:8080/api", "https://localhost:8080"},
		{"https://localhost:8080/api/", "https://localhost:8080"},
		{"localhost", "http://localhost"},
		{"localhost/", "http://localhost/"},
		{"localhost/api", "http://localhost/api"},
		{"localhost/api/", "http://localhost/api/"},
		{"localhost:8080", "http://localhost:8080"},
		{"localhost:8080/", "http://localhost:8080/"},
		{"localhost:8080/api", "http://localhost:8080/api"},
		{"localhost:8080/api/", "http://localhost:8080/api/"},
		{"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
		{"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
		{"127.0.0.1:8080", "http://127.0.0.1:8080"},
		{"127.0.0.1", "http://127.0.0.1"},
		{"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
		{"[::1]:8080", "http://[::1]:8080"},
		{"http://[::1]", "http://[::1]"},
		{"http://[::1]:8080", "http://[::1]:8080"},
		{"[::1]", "http://[::1]"},
		{"https://example.com", "https://example.com"},
		{"example.com", "http://example.com"},
		{"http://hello.example.com", "http://hello.example.com"},
		{"hello.example.com", "http://hello.example.com"},
		{"hello.example.com:8080", "http://hello.example.com:8080"},
		{"https://hello.example.com:8080", "https://hello.example.com:8080"},
		{"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
		{"bücher.example.com", "http://xn--bcher-kva.example.com"},
		{"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
		{"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
	}
	for i, tc := range testCases {
		validURL, err := ValidateUrl(tc.input)
		// Only inspect the URL on success: calling String() on a nil *url.URL
		// would panic and hide which case actually failed.
		if assert.NoError(t, err, "test case %v (%q)", i, tc.input) {
			assert.Equal(t, tc.expectedOutput, validURL.String(), "test case %v (%q)", i, tc.input)
		}
	}

	validURL, err := ValidateUrl("")
	assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
	assert.Empty(t, validURL)

	validURL, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
	// EqualError fails cleanly when err is nil instead of panicking on err.Error().
	assert.EqualError(t, err, "Currently Cloudflare Tunnel does not support ftp protocol.")
	assert.Empty(t, validURL)
}
// TestToggleProtocol checks that toggleProtocol swaps "http" and "https" and
// passes any other value (including the empty string) through unchanged.
func TestToggleProtocol(t *testing.T) {
	for _, pair := range []struct{ in, want string }{
		{in: "http", want: "https"},
		{in: "https", want: "http"},
		{in: "random", want: "random"},
		{in: "", want: ""},
	} {
		assert.Equal(t, pair.want, toggleProtocol(pair.in))
	}
}
// TestValidateHTTPService_HTTP2HTTP covers happy path 1: originURL is HTTP, and
// HTTP connections work. Validation must succeed for both a 200 and a 503
// response, and must never retry over HTTPS.
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
	originURL := "http://127.0.0.1/"
	hostname := "example.com"
	for _, status := range []int{http.StatusOK, http.StatusServiceUnavailable} {
		err := ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			switch req.URL.Scheme {
			case "http":
				return emptyResponse(status), nil
			case "https":
				t.Fatal("http works, shouldn't have tried with https")
			}
			panic("Shouldn't reach here")
		}))
		assert.NoError(t, err, "status %d", status)
	}
}
// TestValidateHTTPService_HTTPS2HTTPS covers happy path 2: originURL is HTTPS,
// and HTTPS connections work. Validation must succeed for both a 200 and a 503
// response, and must never retry over plain HTTP.
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
	originURL := "https://127.0.0.1:1234/"
	hostname := "example.com"
	for _, status := range []int{http.StatusOK, http.StatusServiceUnavailable} {
		err := ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			switch req.URL.Scheme {
			case "http":
				t.Fatal("https works, shouldn't have tried with http")
			case "https":
				return emptyResponse(status), nil
			}
			panic("Shouldn't reach here")
		}))
		assert.NoError(t, err, "status %d", status)
	}
}
// TestValidateHTTPService_HTTPS2HTTP covers error path 1: originURL is HTTPS,
// but only HTTP connections work. HTTPS attempts fail with a transport error
// while HTTP answers (with 200 or 503), and validation must report an error.
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
	originURL := "https://127.0.0.1:1234/"
	hostname := "example.com"
	for _, status := range []int{http.StatusOK, http.StatusServiceUnavailable} {
		err := ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			switch req.URL.Scheme {
			case "http":
				return emptyResponse(status), nil
			case "https":
				return nil, assert.AnError
			}
			panic("Shouldn't reach here")
		}))
		assert.Error(t, err, "status %d", status)
	}
}
// TestValidateHTTPService_HTTP2HTTPS covers error path 2: originURL is HTTP,
// but only HTTPS connections work. HTTP attempts fail with a transport error
// while HTTPS answers (with 200 or 503), and validation must report an error.
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
	originURL := "http://127.0.0.1:1234/"
	hostname := "example.com"
	for _, status := range []int{http.StatusOK, http.StatusServiceUnavailable} {
		err := ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, hostname, req.Host)
			switch req.URL.Scheme {
			case "http":
				return nil, assert.AnError
			case "https":
				return emptyResponse(status), nil
			}
			panic("Shouldn't reach here")
		}))
		assert.Error(t, err, "status %d", status)
	}
}
// TestValidateHTTPService_NoFollowRedirects ensures the client does not follow
// 302 responses. The TLS mock origin always redirects to /followedRedirect, and
// validation must succeed without ever requesting that path.
func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
	hostname := "example.com"
	redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/followedRedirect" {
			// Handlers run on a server goroutine, where t.Fatal/FailNow must not
			// be called; record the failure and stop handling instead.
			t.Error("shouldn't have followed the 302")
			return
		}
		if r.Method == http.MethodConnect {
			assert.Equal(t, "127.0.0.1:443", r.Host)
		} else {
			assert.Equal(t, hostname, r.Host)
		}
		w.Header().Set("Location", "/followedRedirect")
		w.WriteHeader(http.StatusFound)
	}))
	if !assert.NoError(t, err) {
		// redirectServer is nil on error; stop before dereferencing it.
		t.FailNow()
	}
	defer redirectServer.Close()
	assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport))
}
// TestValidateHTTPService_NonResponsiveOrigin ensures validation times out when
// the origin URL is nonresponsive. validationTimeout is temporarily shortened
// to 500ms (and restored afterwards), while the mock origin stalls for up to
// one second. ValidateHTTPService must then fail with a net.Error whose
// Timeout() reports true.
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
	originURL := "http://127.0.0.1/"
	hostname := "example.com"
	oldValidationTimeout := validationTimeout
	defer func() {
		validationTimeout = oldValidationTimeout
	}()
	validationTimeout = 500 * time.Millisecond
	// Use createMockServerAndClient, not createSecureMockServerAndClient.
	// The latter will bail with HTTP 400 immediately on an http:// request,
	// which defeats the purpose of a 'nonresponsive origin' test.
	server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodConnect {
			assert.Equal(t, "127.0.0.1:443", r.Host)
		} else {
			assert.Equal(t, hostname, r.Host)
		}
		// Stall for longer than validationTimeout, but stop early once the
		// client gives up so server.Close does not wait out the full second.
		select {
		case <-r.Context().Done():
		case <-time.After(1 * time.Second):
		}
		w.WriteHeader(http.StatusOK)
	}))
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	defer server.Close()
	err = ValidateHTTPService(originURL, hostname, client.Transport)
	// t.Logf keeps the output attached to this test (shown on failure or -v).
	t.Logf("ValidateHTTPService error: %v", err)
	if netErr, ok := err.(net.Error); assert.True(t, ok, "expected a net.Error, got %T", err) {
		assert.True(t, netErr.Timeout())
	}
}
// TestNewAccessValidatorOk checks that NewAccessValidator succeeds for a
// cloudflareaccess.com domain. It also checks that the resulting validator
// rejects an empty token, a malformed token, and a request carrying a sample
// JWT in the accessJwtHeader header.
func TestNewAccessValidatorOk(t *testing.T) {
	ctx := context.Background()
	// Named accessURL rather than url so the imported net/url package is not shadowed.
	accessURL := "test.cloudflareaccess.com"
	access, err := NewAccessValidator(ctx, accessURL, accessURL, "")
	assert.NoError(t, err)
	if !assert.NotNil(t, access) {
		// Stop here: the method calls below would likely panic on a nil validator.
		t.FailNow()
	}
	assert.Error(t, access.Validate(ctx, ""))
	assert.Error(t, access.Validate(ctx, "invalid"))
	req := httptest.NewRequest("GET", "https://test.cloudflareaccess.com", nil)
	// A well-formed sample token (its header decodes to alg HS256); it is
	// expected not to verify against this validator.
	req.Header.Set(accessJwtHeader, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
	assert.Error(t, access.ValidateRequest(ctx, req))
}
// TestNewAccessValidatorErr checks that NewAccessValidator returns an error and
// a nil validator for each rejected input: an empty string, an ftp:// URL and a
// wss:// URL on a non-cloudflareaccess domain.
func TestNewAccessValidatorErr(t *testing.T) {
	ctx := context.Background()
	for _, rawURL := range []string{
		"",
		"ftp://test.cloudflareaccess.com",
		"wss://cloudflarenone.com",
	} {
		validator, err := NewAccessValidator(ctx, rawURL, rawURL, "")
		assert.Error(t, err, rawURL)
		assert.Nil(t, validator)
	}
}
// testRoundTripper adapts a plain function to the http.RoundTripper interface,
// so tests can script transport responses without performing any network I/O.
type testRoundTripper func(req *http.Request) (*http.Response, error)

// RoundTrip hands the request to the wrapped function and returns its
// response and error unchanged.
func (rt testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	handle := (func(*http.Request) (*http.Response, error))(rt)
	return handle(r)
}
// emptyResponse builds a minimal *http.Response carrying statusCode, an empty
// but non-nil and closable body, and an empty (non-nil) header map. It is meant
// to be returned from a testRoundTripper.
func emptyResponse(statusCode int) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = statusCode
	resp.Header = http.Header{}
	resp.Body = ioutil.NopCloser(bytes.NewReader(nil))
	return resp
}
// createMockServerAndClient starts a plain-HTTP httptest server running handler
// and returns it with a dedicated client. That client's transport sends every
// request through the server as an HTTP proxy, so tests can target arbitrary
// origin URLs (e.g. http://127.0.0.1/) while handler sees each request. The
// caller must Close the returned server. The error is non-nil only if the
// server URL fails to parse, in which case the server has already been closed.
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
	server := httptest.NewServer(handler)
	proxyURL, err := url.Parse(server.URL)
	if err != nil {
		server.Close()
		return nil, nil, err
	}
	// Build a fresh client instead of mutating http.DefaultClient: the default
	// client is shared process-wide and would otherwise keep proxying to this
	// server after it has been closed.
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
		},
	}
	return server, client, nil
}
func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
client := http.DefaultClient
server := httptest.NewTLSServer(handler)
cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0])
if err != nil {
server.Close()
return nil, nil, err
}
certpool := x509.NewCertPool()
certpool.AddCert(cert)
client.Transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("tcp", server.URL[strings.LastIndex(server.URL, "/")+1:])
},
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
}
return server, client, nil
}