package dbconnect import ( "context" "fmt" "io" "io/ioutil" "net" "net/http" "net/http/httptest" "strings" "testing" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" ) func TestNewInsecureProxy(t *testing.T) { origins := []string{ "", ":/", "http://localhost", "tcp://localhost:9000?debug=true", "mongodb://127.0.0.1", } for _, origin := range origins { proxy, err := NewInsecureProxy(context.Background(), origin) assert.Error(t, err) assert.Empty(t, proxy) } } func TestProxyIsAllowed(t *testing.T) { proxy := helperNewProxy(t) req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil) assert.True(t, proxy.IsAllowed(req)) proxy = helperNewProxy(t, true) req.Header.Set("Cf-access-jwt-assertion", "xxx") assert.False(t, proxy.IsAllowed(req)) } func TestProxyStart(t *testing.T) { proxy := helperNewProxy(t) ctx := context.Background() listenerC := make(chan net.Listener) err := proxy.Start(ctx, "1.1.1.1:", listenerC) assert.Error(t, err) err = proxy.Start(ctx, "127.0.0.1:-1", listenerC) assert.Error(t, err) ctx, cancel := context.WithTimeout(ctx, 0) defer cancel() err = proxy.Start(ctx, "127.0.0.1:", listenerC) assert.IsType(t, http.ErrServerClosed, err) } func TestProxyHTTPRouter(t *testing.T) { proxy := helperNewProxy(t) router := proxy.httpRouter() tests := []struct { path string method string valid bool }{ {"", "GET", false}, {"/", "GET", false}, {"/ping", "GET", true}, {"/ping", "HEAD", true}, {"/ping", "POST", false}, {"/submit", "POST", true}, {"/submit", "GET", false}, {"/submit/extra", "POST", false}, } for _, test := range tests { match := &mux.RouteMatch{} ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match) assert.True(t, ok == test.valid, test.path) } } func TestProxyHTTPPing(t *testing.T) { proxy := helperNewProxy(t) server := httptest.NewServer(proxy.httpPing()) defer server.Close() client := server.Client() res, err := client.Get(server.URL) assert.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, int64(2), res.ContentLength) res, err = client.Head(server.URL) assert.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, int64(-1), res.ContentLength) } func TestProxyHTTPSubmit(t *testing.T) { proxy := helperNewProxy(t) server := httptest.NewServer(proxy.httpSubmit()) defer server.Close() client := server.Client() tests := []struct { input string status int output string }{ {"", http.StatusBadRequest, "request body cannot be empty"}, {"{}", http.StatusBadRequest, "cannot provide an empty statement"}, {"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"}, {"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"}, {"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"}, } for _, test := range tests { res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input)) assert.NoError(t, err) assert.Equal(t, test.status, res.StatusCode) if res.StatusCode > http.StatusOK { assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type")) } else { assert.Equal(t, "application/json", res.Header.Get("Content-type")) } data, err := ioutil.ReadAll(res.Body) defer res.Body.Close() str := string(data) assert.NoError(t, err) assert.Equal(t, test.output, str) } } func TestProxyHTTPSubmitForbidden(t *testing.T) { proxy := helperNewProxy(t, true) server := httptest.NewServer(proxy.httpSubmit()) defer server.Close() client := server.Client() res, err := client.Get(server.URL) assert.NoError(t, err) assert.Equal(t, http.StatusForbidden, res.StatusCode) assert.Zero(t, res.ContentLength) } func TestProxyHTTPRespond(t *testing.T) { proxy := helperNewProxy(t) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxy.httpRespond(w, r, http.StatusAccepted, "Hello") })) defer server.Close() client := server.Client() res, err := client.Get(server.URL) assert.NoError(t, err) assert.Equal(t, http.StatusAccepted, res.StatusCode) assert.Equal(t, int64(5), res.ContentLength) data, err := ioutil.ReadAll(res.Body) defer res.Body.Close() assert.Equal(t, []byte("Hello"), data) } func TestProxyHTTPRespondForbidden(t *testing.T) { proxy := helperNewProxy(t, true) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxy.httpRespond(w, r, http.StatusAccepted, "Hello") })) defer server.Close() client := server.Client() res, err := client.Get(server.URL) assert.NoError(t, err) assert.Equal(t, http.StatusAccepted, res.StatusCode) assert.Equal(t, int64(0), res.ContentLength) } func TestHTTPError(t *testing.T) { _, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0) assert.Error(t, errTimeout) tests := []struct { input error status int output error }{ {nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")}, {io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")}, {context.DeadlineExceeded, http.StatusRequestTimeout, nil}, {context.Canceled, 444, nil}, {errTimeout, http.StatusRequestTimeout, nil}, {fmt.Errorf(""), http.StatusInternalServerError, nil}, } for _, test := range tests { status, err := httpError(http.StatusInternalServerError, test.input) assert.Error(t, err) assert.Equal(t, test.status, status) if test.output == nil { test.output = test.input } assert.Equal(t, test.output, err) } } func helperNewProxy(t *testing.T, secure ...bool) *Proxy { t.Helper() proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "") assert.NoError(t, err) assert.NotNil(t, proxy) if len(secure) == 0 { proxy.accessValidator = nil // Mark as insecure } return proxy }