// Source: cloudflared-mirror/dbconnect/proxy_test.go (239 lines, 5.9 KiB, Go)

package dbconnect
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
)
// TestNewInsecureProxy checks that NewInsecureProxy rejects every origin in
// the table below: each call must return an error and an empty proxy value.
func TestNewInsecureProxy(t *testing.T) {
	ctx := context.Background()

	rejected := []string{
		"",
		":/",
		"http://localhost",
		"tcp://localhost:9000?debug=true",
		"mongodb://127.0.0.1",
	}

	for i := range rejected {
		p, err := NewInsecureProxy(ctx, rejected[i])
		assert.Error(t, err)
		assert.Empty(t, p)
	}
}
// TestProxyIsAllowed exercises Proxy.IsAllowed against the same request with
// two proxies: one built by helperNewProxy(t), which must allow it, and one
// built by helperNewProxy(t, true), which must reject it once the request
// carries the placeholder "Cf-access-jwt-assertion" header value "xxx".
func TestProxyIsAllowed(t *testing.T) {
	insecure := helperNewProxy(t)
	request := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
	assert.True(t, insecure.IsAllowed(request))

	secured := helperNewProxy(t, true)
	request.Header.Set("Cf-access-jwt-assertion", "xxx")
	assert.False(t, secured.IsAllowed(request))
}
// TestProxyStart covers three calls to Proxy.Start: two listen addresses that
// must produce an error, followed by a call with a context whose timeout is
// zero, for which the returned error must have the same type as
// http.ErrServerClosed.
func TestProxyStart(t *testing.T) {
	var (
		proxy     = helperNewProxy(t)
		base      = context.Background()
		listeners = make(chan net.Listener)
	)

	for _, addr := range []string{"1.1.1.1:", "127.0.0.1:-1"} {
		assert.Error(t, proxy.Start(base, addr, listeners))
	}

	expired, cancel := context.WithTimeout(base, 0)
	defer cancel()

	err := proxy.Start(expired, "127.0.0.1:", listeners)
	assert.IsType(t, http.ErrServerClosed, err)
}
// TestProxyHTTPRouter matches a table of method/path pairs against the router
// returned by Proxy.httpRouter and checks whether each pair is routable.
// On failure the offending path is reported as the assertion message.
func TestProxyHTTPRouter(t *testing.T) {
	type routeCase struct {
		path   string
		method string
		valid  bool // whether the router is expected to match
	}

	router := helperNewProxy(t).httpRouter()

	cases := []routeCase{
		{path: "", method: "GET", valid: false},
		{path: "/", method: "GET", valid: false},
		{path: "/ping", method: "GET", valid: true},
		{path: "/ping", method: "HEAD", valid: true},
		{path: "/ping", method: "POST", valid: false},
		{path: "/submit", method: "POST", valid: true},
		{path: "/submit", method: "GET", valid: false},
		{path: "/submit/extra", method: "POST", valid: false},
	}

	for _, tc := range cases {
		req := httptest.NewRequest(tc.method, "https://1.1.1.1"+tc.path, nil)

		var match mux.RouteMatch
		matched := router.Match(req, &match)

		assert.True(t, matched == tc.valid, tc.path)
	}
}
// TestProxyHTTPPing serves Proxy.httpPing from a test server and checks the
// responses to GET and HEAD: both must be 200 OK, with a content length of 2
// for GET and -1 for HEAD.
func TestProxyHTTPPing(t *testing.T) {
	proxy := helperNewProxy(t)

	server := httptest.NewServer(proxy.httpPing())
	defer server.Close()

	client := server.Client()

	// Only inspect (and close) the response when the request succeeded; on
	// error the response is nil and dereferencing it would panic.
	getRes, err := client.Get(server.URL)
	if assert.NoError(t, err) {
		defer getRes.Body.Close()
		assert.Equal(t, http.StatusOK, getRes.StatusCode)
		assert.Equal(t, int64(2), getRes.ContentLength)
	}

	headRes, err := client.Head(server.URL)
	if assert.NoError(t, err) {
		defer headRes.Body.Close()
		assert.Equal(t, http.StatusOK, headRes.StatusCode)
		assert.Equal(t, int64(-1), headRes.ContentLength)
	}
}
// TestProxyHTTPSubmit posts a table of JSON bodies to Proxy.httpSubmit served
// from a test server and checks, for each one, the status code, the
// Content-Type header (plain text for statuses above 200, JSON otherwise) and
// the exact response body.
func TestProxyHTTPSubmit(t *testing.T) {
	proxy := helperNewProxy(t)

	server := httptest.NewServer(proxy.httpSubmit())
	defer server.Close()

	client := server.Client()

	tests := []struct {
		input  string
		status int
		output string
	}{
		{"", http.StatusBadRequest, "request body cannot be empty"},
		{"{}", http.StatusBadRequest, "cannot provide an empty statement"},
		{"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
		{"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
		{"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
	}

	for _, test := range tests {
		res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
		if !assert.NoError(t, err) {
			// No response to inspect; res is nil when the request failed.
			continue
		}

		assert.Equal(t, test.status, res.StatusCode)
		if res.StatusCode > http.StatusOK {
			assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
		} else {
			assert.Equal(t, "application/json", res.Header.Get("Content-type"))
		}

		data, err := ioutil.ReadAll(res.Body)
		// Close per iteration: a defer here would keep every body open until
		// the whole test function returns.
		_ = res.Body.Close()

		assert.NoError(t, err)
		assert.Equal(t, test.output, string(data))
	}
}
// TestProxyHTTPSubmitForbidden serves Proxy.httpSubmit from a proxy built
// with helperNewProxy(t, true) and checks that a plain GET is answered with
// 403 Forbidden and a zero content length.
func TestProxyHTTPSubmitForbidden(t *testing.T) {
	proxy := helperNewProxy(t, true)

	server := httptest.NewServer(proxy.httpSubmit())
	defer server.Close()

	client := server.Client()

	res, err := client.Get(server.URL)
	if !assert.NoError(t, err) {
		// res is nil when the request failed; nothing further to check.
		return
	}
	defer res.Body.Close()

	assert.Equal(t, http.StatusForbidden, res.StatusCode)
	assert.Zero(t, res.ContentLength)
}
// TestProxyHTTPRespond wraps Proxy.httpRespond in a handler that always
// responds with status 202 and the text "Hello", then checks that a GET
// through a test server yields that status, a content length of 5 and the
// exact body bytes.
func TestProxyHTTPRespond(t *testing.T) {
	proxy := helperNewProxy(t)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
	}))
	defer server.Close()

	client := server.Client()

	res, err := client.Get(server.URL)
	if !assert.NoError(t, err) {
		// res is nil when the request failed; nothing further to check.
		return
	}
	// Register the close before reading so the body is released even if a
	// later step panics.
	defer res.Body.Close()

	assert.Equal(t, http.StatusAccepted, res.StatusCode)
	assert.Equal(t, int64(5), res.ContentLength)

	data, err := ioutil.ReadAll(res.Body)
	assert.NoError(t, err)
	assert.Equal(t, []byte("Hello"), data)
}
// TestProxyHTTPRespondForbidden runs the same handler as the plain respond
// test, but on a proxy built with helperNewProxy(t, true). The response must
// keep the 202 status while reporting a content length of 0.
func TestProxyHTTPRespondForbidden(t *testing.T) {
	proxy := helperNewProxy(t, true)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
	}))
	defer server.Close()

	client := server.Client()

	res, err := client.Get(server.URL)
	if !assert.NoError(t, err) {
		// res is nil when the request failed; nothing further to check.
		return
	}
	defer res.Body.Close()

	assert.Equal(t, http.StatusAccepted, res.StatusCode)
	assert.Equal(t, int64(0), res.ContentLength)
}
// TestHTTPError feeds a table of errors to httpError, always with
// http.StatusInternalServerError as the default status, and checks the status
// and error that come back. A nil expected error in the table means the input
// error is expected to be returned as-is.
func TestHTTPError(t *testing.T) {
	// Obtain a real error value from the net package to use as a table input.
	_, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
	assert.Error(t, errTimeout)

	type errorCase struct {
		input  error
		status int
		output error
	}

	cases := []errorCase{
		{input: nil, status: http.StatusNotImplemented, output: fmt.Errorf("error expected but found none")},
		{input: io.EOF, status: http.StatusBadRequest, output: fmt.Errorf("request body cannot be empty")},
		{input: context.DeadlineExceeded, status: http.StatusRequestTimeout},
		{input: context.Canceled, status: 444},
		{input: errTimeout, status: http.StatusRequestTimeout},
		{input: fmt.Errorf(""), status: http.StatusInternalServerError},
	}

	for i := range cases {
		tc := cases[i]

		// Fall back to the input when no distinct output error is listed.
		want := tc.output
		if want == nil {
			want = tc.input
		}

		gotStatus, gotErr := httpError(http.StatusInternalServerError, tc.input)
		assert.Error(t, gotErr)
		assert.Equal(t, tc.status, gotStatus)
		assert.Equal(t, want, gotErr)
	}
}
// helperNewProxy builds a Proxy for tests via NewSecureProxy, using the
// "file::memory:?cache=shared" origin and the "test.cloudflareaccess.com"
// auth domain.
//
// The secure argument is presence-based: when no value is passed, the proxy's
// accessValidator is cleared so the proxy behaves as insecure; when any value
// is passed (its boolean content is not inspected), the validator is kept.
//
// The calling test is stopped immediately if the proxy cannot be constructed,
// so callers never receive a nil proxy.
func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
	t.Helper()

	proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
	if !assert.NoError(t, err) || !assert.NotNil(t, proxy) {
		// Continuing would dereference a nil proxy below and in the caller.
		t.FailNow()
	}

	if len(secure) == 0 {
		proxy.accessValidator = nil // Mark as insecure
	}

	return proxy
}