Add db-connect, a SQL over HTTPS server

pull/149/head
Ashcon Partovi 3 years ago
parent 13bf65ce4e
commit 759cd019be

@ -85,3 +85,14 @@ stretch: &stretch
- make test
jessie: *stretch
# cfsetup compose
default-stack: test_dbconnect
test_dbconnect:
compose:
up-args:
- --renew-anon-volumes
- --abort-on-container-exit
- --exit-code-from=cloudflared
files:
- dbconnect_tests/dbconnect.yaml

@ -26,15 +26,15 @@ var logger = log.CreateLogger()
type lock struct {
lockFilePath string
backoff *origin.BackoffHandler
sigHandler *signalHandler
sigHandler *signalHandler
}
type signalHandler struct {
sigChannel chan os.Signal
signals []os.Signal
sigChannel chan os.Signal
signals []os.Signal
}
func (s *signalHandler) register(handler func()){
func (s *signalHandler) register(handler func()) {
s.sigChannel = make(chan os.Signal, 1)
signal.Notify(s.sigChannel, s.signals...)
go func(s *signalHandler) {
@ -59,8 +59,8 @@ func newLock(path string) *lock {
return &lock{
lockFilePath: lockPath,
backoff: &origin.BackoffHandler{MaxRetries: 7},
sigHandler: &signalHandler{
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
sigHandler: &signalHandler{
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
},
}
}
@ -68,8 +68,8 @@ func newLock(path string) *lock {
func (l *lock) Acquire() error {
// Intercept SIGINT and SIGTERM to release lock before exiting
l.sigHandler.register(func() {
l.deleteLockFile()
os.Exit(0)
l.deleteLockFile()
os.Exit(0)
})
// Check for a path.lock file

@ -7,18 +7,18 @@ import (
"net"
"net/url"
"os"
"reflect"
"runtime"
"runtime/trace"
"sync"
"syscall"
"time"
"github.com/cloudflare/cloudflared/awsuploader"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/cmd/sqlgateway"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/dbconnect"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
@ -37,7 +37,6 @@ import (
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh/terminal"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
)
@ -138,43 +137,7 @@ func Commands() []*cli.Command {
ArgsUsage: " ", // can't be the empty string or we get the default output
Hidden: false,
},
{
Name: "db",
Action: func(c *cli.Context) error {
tags := make(map[string]string)
tags["hostname"] = c.String("hostname")
raven.SetTagsContext(tags)
fmt.Printf("\nSQL Database Password: ")
pass, err := terminal.ReadPassword(int(syscall.Stdin))
if err != nil {
logger.Error(err)
}
go sqlgateway.StartProxy(c, logger, string(pass))
raven.CapturePanic(func() { err = tunnel(c) }, nil)
if err != nil {
raven.CaptureError(err, nil)
}
return err
},
Before: Before,
Usage: "SQL Gateway is an SQL over HTTP reverse proxy",
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "db",
Value: true,
Usage: "Enable the SQL Gateway Proxy",
},
&cli.StringFlag{
Name: "address",
Value: "",
Usage: "Database connection string: db://user:pass",
},
},
Hidden: true,
},
dbConnectCmd(),
}
var subcommands []*cli.Command
@ -644,6 +607,60 @@ func addPortIfMissing(uri *url.URL, port int) string {
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
func dbConnectCmd() *cli.Command {
cmd := dbconnect.Cmd()
// Append the tunnel commands so users can customize the daemon settings.
cmd.Flags = appendFlags(Flags(), cmd.Flags...)
// Override before to run tunnel validation before dbconnect validation.
cmd.Before = func(c *cli.Context) error {
err := Before(c)
if err == nil {
err = dbconnect.CmdBefore(c)
}
return err
}
// Override action to setup the Proxy, then if successful, start the tunnel daemon.
cmd.Action = func(c *cli.Context) error {
err := dbconnect.CmdAction(c)
if err == nil {
err = tunnel(c)
}
return err
}
return cmd
}
// appendFlags will append extra flags to a slice of flags.
//
// The cli package will panic if two flags exist with the same name,
// so if extraFlags contains a flag that was already defined, modify the
// original flags to use the extra version.
func appendFlags(flags []cli.Flag, extraFlags ...cli.Flag) []cli.Flag {
for _, extra := range extraFlags {
var found bool
// Check if an extra flag overrides an existing flag.
for i, flag := range flags {
if reflect.DeepEqual(extra.Names(), flag.Names()) {
flags[i] = extra
found = true
break
}
}
// Append the extra flag if it has nothing to override.
if !found {
flags = append(flags, extra)
}
}
return flags
}
func tunnelFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{
&cli.StringFlag{

@ -1,148 +0,0 @@
package sqlgateway
import (
"database/sql"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"strings"
"time"
_ "github.com/lib/pq"
cli "gopkg.in/urfave/cli.v2"
"github.com/elgs/gosqljson"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
)
type Message struct {
Connection Connection `json:"connection"`
Command string `json:"command"`
Params []interface{} `json:"params"`
}
type Connection struct {
SSLMode string `json:"sslmode"`
Token string `json:"token"`
}
type Response struct {
Columns []string `json:"columns"`
Rows [][]string `json:"rows"`
Error string `json:"error"`
}
type Proxy struct {
Context *cli.Context
Router *mux.Router
Token string
User string
Password string
Driver string
Database string
Logger *logrus.Logger
}
func StartProxy(c *cli.Context, logger *logrus.Logger, password string) error {
proxy := NewProxy(c, logger, password)
logger.Infof("Starting SQL Gateway Proxy on port %s", strings.Split(c.String("url"), ":")[1])
err := http.ListenAndServe(":"+strings.Split(c.String("url"), ":")[1], proxy.Router)
if err != nil {
return err
}
return nil
}
func randID(n int, c *cli.Context) string {
charBytes := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
b := make([]byte, n)
for i := range b {
b[i] = charBytes[rand.Intn(len(charBytes))]
}
return fmt.Sprintf("%s&%s", c.String("hostname"), b)
}
// db://user@dbname
func parseInfo(input string) (string, string, string) {
p1 := strings.Split(input, "://")
p2 := strings.Split(p1[1], "@")
return p1[0], p2[0], p2[1]
}
func NewProxy(c *cli.Context, logger *logrus.Logger, pass string) *Proxy {
rand.Seed(time.Now().UnixNano())
driver, user, dbname := parseInfo(c.String("address"))
proxy := Proxy{
Context: c,
Router: mux.NewRouter(),
Token: randID(64, c),
Logger: logger,
User: user,
Password: pass,
Database: dbname,
Driver: driver,
}
logger.Info(fmt.Sprintf(`
--------------------
SQL Gateway Proxy
Token: %s
--------------------
`, proxy.Token))
proxy.Router.HandleFunc("/", proxy.proxyRequest).Methods("POST")
return &proxy
}
func (proxy *Proxy) proxyRequest(rw http.ResponseWriter, req *http.Request) {
var message Message
response := Response{}
err := json.NewDecoder(req.Body).Decode(&message)
if err != nil {
proxy.Logger.Error(err)
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
return
}
if message.Connection.Token != proxy.Token {
proxy.Logger.Error("Invalid token")
http.Error(rw, "400 - Invalid token", http.StatusBadRequest)
return
}
connStr := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=%s", proxy.User, proxy.Password, proxy.Database, message.Connection.SSLMode)
db, err := sql.Open(proxy.Driver, connStr)
defer db.Close()
if err != nil {
proxy.Logger.Error(err)
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
return
} else {
proxy.Logger.Info("Forwarding SQL: ", message.Command)
rw.Header().Set("Content-Type", "application/json")
headers, data, err := gosqljson.QueryDbToArray(db, "lower", message.Command, message.Params...)
if err != nil {
proxy.Logger.Error(err)
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
return
} else {
response = Response{headers, data, ""}
}
}
json.NewEncoder(rw).Encode(response)
}

@ -0,0 +1,145 @@
package dbconnect
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strings"
"time"
"unicode"
"unicode/utf8"
)
// Client is an interface to talk to any database.
//
// Currently, the only implementation is SQLClient, but its structure
// should be designed to handle a MongoClient or RedisClient in the future.
type Client interface {
Ping(context.Context) error
Submit(context.Context, *Command) (interface{}, error)
}
// NewClient creates a database client based on its URL scheme.
func NewClient(ctx context.Context, originURL *url.URL) (Client, error) {
return NewSQLClient(ctx, originURL)
}
// Command is a standard, non-vendor format for submitting database commands.
//
// When determining the scope of this struct, refer to the following litmus test:
// Could this (roughly) conform to SQL, Document-based, and Key-value command formats?
type Command struct {
Statement string `json:"statement"`
Arguments Arguments `json:"arguments,omitempty"`
Mode string `json:"mode,omitempty"`
Isolation string `json:"isolation,omitempty"`
Timeout time.Duration `json:"timeout,omitempty"`
}
// Validate enforces the contract of Command: non empty statement (both in length and logic),
// lowercase mode and isolation, non-zero timeout, and valid Arguments.
func (cmd *Command) Validate() error {
if cmd.Statement == "" {
return fmt.Errorf("cannot provide an empty statement")
}
if strings.Map(func(char rune) rune {
if char == ';' || unicode.IsSpace(char) {
return -1
}
return char
}, cmd.Statement) == "" {
return fmt.Errorf("cannot provide a statement with no logic: '%s'", cmd.Statement)
}
cmd.Mode = strings.ToLower(cmd.Mode)
cmd.Isolation = strings.ToLower(cmd.Isolation)
if cmd.Timeout.Nanoseconds() <= 0 {
cmd.Timeout = 24 * time.Hour
}
return cmd.Arguments.Validate()
}
// UnmarshalJSON converts a byte representation of JSON into a Command, which is also validated.
func (cmd *Command) UnmarshalJSON(data []byte) error {
// Alias is required to avoid infinite recursion from the default UnmarshalJSON.
type Alias Command
alias := &struct {
*Alias
}{
Alias: (*Alias)(cmd),
}
err := json.Unmarshal(data, &alias)
if err == nil {
err = cmd.Validate()
}
return err
}
// Arguments is a wrapper for either map-based or array-based Command arguments.
//
// Each field is mutually-exclusive and some Client implementations may not
// support both fields (eg. MySQL does not accept named arguments).
type Arguments struct {
Named map[string]interface{}
Positional []interface{}
}
// Validate enforces the contract of Arguments: non nil, mutually exclusive, and no empty or reserved keys.
func (args *Arguments) Validate() error {
if args.Named == nil {
args.Named = map[string]interface{}{}
}
if args.Positional == nil {
args.Positional = []interface{}{}
}
if len(args.Named) > 0 && len(args.Positional) > 0 {
return fmt.Errorf("both named and positional arguments cannot be specified: %+v and %+v", args.Named, args.Positional)
}
for key := range args.Named {
if key == "" {
return fmt.Errorf("named arguments cannot contain an empty key: %+v", args.Named)
}
if !utf8.ValidString(key) {
return fmt.Errorf("named argument does not conform to UTF-8 encoding: %s", key)
}
if strings.HasPrefix(key, "_") {
return fmt.Errorf("named argument cannot start with a reserved keyword '_': %s", key)
}
if unicode.IsNumber([]rune(key)[0]) {
return fmt.Errorf("named argument cannot start with a number: %s", key)
}
}
return nil
}
// UnmarshalJSON converts a byte representation of JSON into Arguments, which is also validated.
func (args *Arguments) UnmarshalJSON(data []byte) error {
var obj interface{}
err := json.Unmarshal(data, &obj)
if err != nil {
return err
}
named, ok := obj.(map[string]interface{})
if ok {
args.Named = named
} else {
positional, ok := obj.([]interface{})
if ok {
args.Positional = positional
} else {
return fmt.Errorf("arguments must either be an object {\"0\":\"val\"} or an array [\"val\"]: %s", string(data))
}
}
return args.Validate()
}

@ -0,0 +1,183 @@
package dbconnect
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCommandValidateEmpty(t *testing.T) {
stmts := []string{
"",
";",
" \n\t",
";\n;\t;",
}
for _, stmt := range stmts {
cmd := Command{Statement: stmt}
assert.Error(t, cmd.Validate(), stmt)
}
}
func TestCommandValidateMode(t *testing.T) {
modes := []string{
"",
"query",
"ExEc",
"PREPARE",
}
for _, mode := range modes {
cmd := Command{Statement: "Ok", Mode: mode}
assert.NoError(t, cmd.Validate(), mode)
assert.Equal(t, strings.ToLower(mode), cmd.Mode)
}
}
func TestCommandValidateIsolation(t *testing.T) {
isos := []string{
"",
"default",
"read_committed",
"SNAPshot",
}
for _, iso := range isos {
cmd := Command{Statement: "Ok", Isolation: iso}
assert.NoError(t, cmd.Validate(), iso)
assert.Equal(t, strings.ToLower(iso), cmd.Isolation)
}
}
func TestCommandValidateTimeout(t *testing.T) {
cmd := Command{Statement: "Ok", Timeout: 0}
assert.NoError(t, cmd.Validate())
assert.NotZero(t, cmd.Timeout)
cmd = Command{Statement: "Ok", Timeout: 1 * time.Second}
assert.NoError(t, cmd.Validate())
assert.Equal(t, 1*time.Second, cmd.Timeout)
}
func TestCommandValidateArguments(t *testing.T) {
cmd := Command{Statement: "Ok", Arguments: Arguments{
Named: map[string]interface{}{"key": "val"},
Positional: []interface{}{"val"},
}}
assert.Error(t, cmd.Validate())
}
func TestCommandUnmarshalJSON(t *testing.T) {
strs := []string{
"{\"statement\":\"Ok\"}",
"{\"statement\":\"Ok\",\"arguments\":[0, 3.14, \"apple\"],\"mode\":\"query\"}",
"{\"statement\":\"Ok\",\"isolation\":\"read_uncommitted\",\"timeout\":1000}",
}
for _, str := range strs {
var cmd Command
assert.NoError(t, json.Unmarshal([]byte(str), &cmd), str)
}
strs = []string{
"",
"\"",
"{}",
"{\"argument\":{\"key\":\"val\"}}",
"{\"statement\":[\"Ok\"]}",
}
for _, str := range strs {
var cmd Command
assert.Error(t, json.Unmarshal([]byte(str), &cmd), str)
}
}
func TestArgumentsValidateNotNil(t *testing.T) {
args := Arguments{}
assert.NoError(t, args.Validate())
assert.NotNil(t, args.Named)
assert.NotNil(t, args.Positional)
}
func TestArgumentsValidateMutuallyExclusive(t *testing.T) {
args := []Arguments{
Arguments{},
Arguments{Named: map[string]interface{}{"key": "val"}},
Arguments{Positional: []interface{}{"val"}},
}
for _, arg := range args {
assert.NoError(t, arg.Validate())
assert.False(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
}
args = []Arguments{
Arguments{
Named: map[string]interface{}{"key": "val"},
Positional: []interface{}{"val"},
},
}
for _, arg := range args {
assert.Error(t, arg.Validate())
assert.True(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
}
}
func TestArgumentsValidateKeys(t *testing.T) {
keys := []string{
"",
"_",
"_key",
"1",
"1key",
"\xf0\x28\x8c\xbc", // non-utf8
}
for _, key := range keys {
args := Arguments{Named: map[string]interface{}{key: "val"}}
assert.Error(t, args.Validate(), key)
}
}
func TestArgumentsUnmarshalJSON(t *testing.T) {
strs := []string{
"{}",
"{\"key\":\"val\"}",
"{\"key\":[1, 3.14, {\"key\":\"val\"}]}",
"[]",
"[\"key\",\"val\"]",
"[{}]",
}
for _, str := range strs {
var args Arguments
assert.NoError(t, json.Unmarshal([]byte(str), &args), str)
}
strs = []string{
"",
"\"",
"1",
"\"key\"",
"{\"key\",\"val\"}",
}
for _, str := range strs {
var args Arguments
assert.Error(t, json.Unmarshal([]byte(str), &args), str)
}
}

@ -0,0 +1,157 @@
package dbconnect
import (
"context"
"log"
"net"
"strconv"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
)
// Cmd is the entrypoint command for dbconnect.
//
// The tunnel package is responsible for appending this to tunnel.Commands().
func Cmd() *cli.Command {
return &cli.Command{
Category: "Database Connect (ALPHA)",
Name: "db-connect",
Usage: "Access your SQL database from Cloudflare Workers or the browser",
ArgsUsage: " ",
Description: `
Creates a connection between your database and the Cloudflare edge.
Now you can execute SQL commands anywhere you can send HTTPS requests.
Connect your database with any of the following commands, you can also try the "playground" without a database:
cloudflared db-connect --hostname sql.mysite.com --url postgres://user:pass@localhost?sslmode=disable \
--auth-domain mysite.cloudflareaccess.com --application-aud my-access-policy-tag
cloudflared db-connect --hostname sql-dev.mysite.com --url mysql://localhost --insecure
cloudflared db-connect --playground
Requests should be authenticated using Cloudflare Access, learn more about how to enable it here:
https://developers.cloudflare.com/access/service-auth/service-token/
`,
Flags: []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Usage: "URL to the database (eg. postgres://user:pass@localhost?sslmode=disable)",
EnvVars: []string{"TUNNEL_URL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "hostname",
Usage: "Hostname to accept commands over HTTPS (eg. sql.mysite.com)",
EnvVars: []string{"TUNNEL_HOSTNAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "auth-domain",
Usage: "Cloudflare Access authentication domain for your account (eg. mysite.cloudflareaccess.com)",
EnvVars: []string{"TUNNEL_ACCESS_AUTH_DOMAIN"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "application-aud",
Usage: "Cloudflare Access application \"AUD\" to verify JWTs from requests",
EnvVars: []string{"TUNNEL_ACCESS_APPLICATION_AUD"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "insecure",
Usage: "Disable authentication, the database will be open to the Internet",
Value: false,
EnvVars: []string{"TUNNEL_ACCESS_INSECURE"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "playground",
Usage: "Run a temporary, in-memory SQLite3 database for testing",
Value: false,
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "loglevel",
Value: "debug", // Make it more verbose than the tunnel default 'info'.
EnvVars: []string{"TUNNEL_LOGLEVEL"},
Hidden: true,
}),
},
Before: CmdBefore,
Action: CmdAction,
Hidden: true,
}
}
// CmdBefore runs some validation checks before running the command.
func CmdBefore(c *cli.Context) error {
// Show the help text is no flags are specified.
if c.NumFlags() == 0 {
return cli.ShowSubcommandHelp(c)
}
// Hello-world and playground are synonymous with each other,
// unset hello-world to prevent tunnel from initializing the hello package.
if c.IsSet("hello-world") {
c.Set("playground", "true")
c.Set("hello-world", "false")
}
// Unix-socket database urls are supported, but the logic is the same as url.
if c.IsSet("unix-socket") {
c.Set("url", c.String("unix-socket"))
c.Set("unix-socket", "")
}
// When playground mode is enabled, run with an in-memory database.
if c.IsSet("playground") {
c.Set("url", "sqlite3::memory:?cache=shared")
c.Set("insecure", strconv.FormatBool(!c.IsSet("auth-domain") && !c.IsSet("application-aud")))
}
// At this point, insecure configurations are valid.
if c.Bool("insecure") {
return nil
}
// Ensure that secure configurations specify a hostname, domain, and tag for JWT validation.
if !c.IsSet("hostname") || !c.IsSet("auth-domain") || !c.IsSet("application-aud") {
log.Fatal("must specify --hostname, --auth-domain, and --application-aud unless you want to run in --insecure mode")
}
return nil
}
// CmdAction starts the Proxy and sets the url in cli.Context to point to the Proxy address.
func CmdAction(c *cli.Context) error {
// STOR-612: sync with context in tunnel daemon.
ctx := context.Background()
var proxy *Proxy
var err error
if c.Bool("insecure") {
proxy, err = NewInsecureProxy(ctx, c.String("url"))
} else {
proxy, err = NewSecureProxy(ctx, c.String("url"), c.String("auth-domain"), c.String("application-aud"))
}
if err != nil {
log.Fatal(err)
return err
}
listenerC := make(chan net.Listener)
defer close(listenerC)
// Since the Proxy should only talk to the tunnel daemon, find the next available
// localhost port and start to listen to requests.
go func() {
err := proxy.Start(ctx, "127.0.0.1:", listenerC)
if err != nil {
log.Fatal(err)
}
}()
// Block until the the Proxy is online, retreive its address, and change the url to point to it.
// This is effectively "handing over" control to the tunnel package so it can run the tunnel daemon.
c.Set("url", "https://"+(<-listenerC).Addr().String())
return nil
}

@ -0,0 +1,27 @@
package dbconnect
import (
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/urfave/cli.v2"
)
func TestCmd(t *testing.T) {
tests := [][]string{
{"cloudflared", "db-connect", "--playground"},
{"cloudflared", "db-connect", "--playground", "--hostname", "sql.mysite.com"},
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--insecure"},
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--hostname", "sql.mysite.com", "--auth-domain", "mysite.cloudflareaccess.com", "--application-aud", "aud"},
}
app := &cli.App{
Name: "cloudflared",
Commands: []*cli.Command{Cmd()},
}
for _, test := range tests {
assert.NoError(t, app.Run(test))
}
}

@ -0,0 +1,271 @@
package dbconnect
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/validation"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// Proxy is an HTTP server that proxies requests to a Client.
type Proxy struct {
client Client
accessValidator *validation.Access
logger *logrus.Logger
}
// NewInsecureProxy creates a Proxy that talks to a Client at an origin.
//
// In insecure mode, the Proxy will allow all Command requests.
func NewInsecureProxy(ctx context.Context, origin string) (*Proxy, error) {
originURL, err := url.Parse(origin)
if err != nil {
return nil, errors.Wrap(err, "must provide a valid database url")
}
client, err := NewClient(ctx, originURL)
if err != nil {
return nil, err
}
err = client.Ping(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not connect to the database")
}
return &Proxy{client, nil, logrus.New()}, nil
}
// NewSecureProxy creates a Proxy that talks to a Client at an origin.
//
// In secure mode, the Proxy will reject any Command requests that are
// not authenticated by Cloudflare Access with a valid JWT.
func NewSecureProxy(ctx context.Context, origin, authDomain, applicationAUD string) (*Proxy, error) {
proxy, err := NewInsecureProxy(ctx, origin)
if err != nil {
return nil, err
}
validator, err := validation.NewAccessValidator(ctx, authDomain, authDomain, applicationAUD)
if err != nil {
return nil, err
}
proxy.accessValidator = validator
return proxy, err
}
// IsInsecure gets whether the Proxy will accept a Command from any source.
func (proxy *Proxy) IsInsecure() bool {
return proxy.accessValidator == nil
}
// IsAllowed checks whether a http.Request is allowed to receive data.
//
// By default, requests must pass through Cloudflare Access for authentication.
// If the proxy is explcitly set to insecure mode, all requests will be allowed.
func (proxy *Proxy) IsAllowed(r *http.Request, verbose ...bool) bool {
if proxy.IsInsecure() {
return true
}
// Access and Tunnel should prevent bad JWTs from even reaching the origin,
// but validate tokens anyway as an abundance of caution.
err := proxy.accessValidator.ValidateRequest(r.Context(), r)
if err == nil {
return true
}
// Warn administrators that invalid JWTs are being rejected. This is indicative
// of either a misconfiguration of the CLI or a massive failure of upstream systems.
if len(verbose) > 0 {
proxy.httpLog(r, err).Error("Failed JWT authentication")
}
return false
}
// Start the Proxy at a given address and notify the listener channel when the server is online.
func (proxy *Proxy) Start(ctx context.Context, addr string, listenerC chan<- net.Listener) error {
// STOR-611: use a seperate listener and consider web socket support.
httpListener, err := hello.CreateTLSListener(addr)
if err != nil {
return errors.Wrapf(err, "could not create listener at %s", addr)
}
errC := make(chan error)
defer close(errC)
// Starts the HTTP server and begins to serve requests.
go func() {
errC <- proxy.httpListen(ctx, httpListener)
}()
// Continually ping the server until it comes online or 10 attempts fail.
go func() {
var err error
for i := 0; i < 10; i++ {
_, err = http.Get("http://" + httpListener.Addr().String())
// Once no error was detected, notify the listener channel and return.
if err == nil {
listenerC <- httpListener
return
}
// Backoff between requests to ping the server.
<-time.After(1 * time.Second)
}
errC <- errors.Wrap(err, "took too long for the http server to start")
}()
return <-errC
}
// httpListen starts the httpServer and blocks until the context closes.
func (proxy *Proxy) httpListen(ctx context.Context, listener net.Listener) error {
httpServer := &http.Server{
Addr: listener.Addr().String(),
Handler: proxy.httpRouter(),
ReadTimeout: 10 * time.Second,
WriteTimeout: 60 * time.Second,
IdleTimeout: 60 * time.Second,
}
go func() {
<-ctx.Done()
httpServer.Close()
listener.Close()
}()
return httpServer.Serve(listener)
}
// httpRouter creates a mux.Router for the Proxy.
func (proxy *Proxy) httpRouter() *mux.Router {
router := mux.NewRouter()
router.HandleFunc("/ping", proxy.httpPing()).Methods("GET", "HEAD")
router.HandleFunc("/submit", proxy.httpSubmit()).Methods("POST")
return router
}
// httpPing tests the connection to the database.
//
// By default, this endpoint is unauthenticated to allow for health checks.
// To enable authentication, Cloudflare Access must be enabled on this route.
func (proxy *Proxy) httpPing() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
err := proxy.client.Ping(ctx)
if err == nil {
proxy.httpRespond(w, r, http.StatusOK, "")
} else {
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
}
}
}
// httpSubmit sends a command to the database and returns its response.
//
// By default, this endpoint will reject requests that do not pass through Cloudflare Access.
// To disable authentication, the --insecure flag must be specified in the command line.
func (proxy *Proxy) httpSubmit() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !proxy.IsAllowed(r, true) {
proxy.httpRespondErr(w, r, http.StatusForbidden, fmt.Errorf(""))
return
}
var cmd Command
err := json.NewDecoder(r.Body).Decode(&cmd)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusBadRequest, err)
return
}
ctx := r.Context()
data, err := proxy.client.Submit(ctx, &cmd)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusUnprocessableEntity, err)
return
}
w.Header().Set("Content-type", "application/json")
err = json.NewEncoder(w).Encode(data)
if err != nil {
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
}
}
}
// httpRespond writes a status code and string response to the response writer.
func (proxy *Proxy) httpRespond(w http.ResponseWriter, r *http.Request, status int, message string) {
w.WriteHeader(status)
// Only expose the message detail of the reponse if the request is not HEAD
// and the user is authenticated. For example, this prevents an unauthenticated
// failed health check from accidentally leaking sensitive information about the Client.
if r.Method != http.MethodHead && proxy.IsAllowed(r) {
if message == "" {
message = http.StatusText(status)
}
fmt.Fprint(w, message)
}
}
// httpRespondErr is similar to httpRespond, except it formats errors to be more friendly.
func (proxy *Proxy) httpRespondErr(w http.ResponseWriter, r *http.Request, defaultStatus int, err error) {
status, err := httpError(defaultStatus, err)
proxy.httpRespond(w, r, status, err.Error())
if len(err.Error()) > 0 {
proxy.httpLog(r, err).Warn("Database proxy error")
}
}
// httpLog returns a logrus.Entry that is formatted to output a request Cf-ray.
func (proxy *Proxy) httpLog(r *http.Request, err error) *logrus.Entry {
return proxy.logger.WithContext(r.Context()).WithField("CF-RAY", r.Header.Get("Cf-ray")).WithError(err)
}
// httpError extracts common errors and returns an status code and friendly error.
func httpError(defaultStatus int, err error) (int, error) {
if err == nil {
return http.StatusNotImplemented, fmt.Errorf("error expected but found none")
}
if err == io.EOF {
return http.StatusBadRequest, fmt.Errorf("request body cannot be empty")
}
if err == context.DeadlineExceeded {
return http.StatusRequestTimeout, err
}
_, ok := err.(net.Error)
if ok {
return http.StatusRequestTimeout, err
}
if err == context.Canceled {
// Does not exist in Golang, but would be: http.StatusClientClosedWithoutResponse
return 444, err
}
return defaultStatus, err
}

@ -0,0 +1,238 @@
package dbconnect
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
)
func TestNewInsecureProxy(t *testing.T) {
origins := []string{
"",
":/",
"http://localhost",
"tcp://localhost:9000?debug=true",
"mongodb://127.0.0.1",
}
for _, origin := range origins {
proxy, err := NewInsecureProxy(context.Background(), origin)
assert.Error(t, err)
assert.Empty(t, proxy)
}
}
func TestProxyIsAllowed(t *testing.T) {
proxy := helperNewProxy(t)
req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
assert.True(t, proxy.IsAllowed(req))
proxy = helperNewProxy(t, true)
req.Header.Set("Cf-access-jwt-assertion", "xxx")
assert.False(t, proxy.IsAllowed(req))
}
func TestProxyStart(t *testing.T) {
proxy := helperNewProxy(t)
ctx := context.Background()
listenerC := make(chan net.Listener)
err := proxy.Start(ctx, "1.1.1.1:", listenerC)
assert.Error(t, err)
err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
assert.Error(t, err)
ctx, cancel := context.WithTimeout(ctx, 0)
defer cancel()
err = proxy.Start(ctx, "127.0.0.1:", listenerC)
assert.IsType(t, http.ErrServerClosed, err)
}
func TestProxyHTTPRouter(t *testing.T) {
proxy := helperNewProxy(t)
router := proxy.httpRouter()
tests := []struct {
path string
method string
valid bool
}{
{"", "GET", false},
{"/", "GET", false},
{"/ping", "GET", true},
{"/ping", "HEAD", true},
{"/ping", "POST", false},
{"/submit", "POST", true},
{"/submit", "GET", false},
{"/submit/extra", "POST", false},
}
for _, test := range tests {
match := &mux.RouteMatch{}
ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
assert.True(t, ok == test.valid, test.path)
}
}
func TestProxyHTTPPing(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(proxy.httpPing())
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, int64(2), res.ContentLength)
res, err = client.Head(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, int64(-1), res.ContentLength)
}
func TestProxyHTTPSubmit(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(proxy.httpSubmit())
defer server.Close()
client := server.Client()
tests := []struct {
input string
status int
output string
}{
{"", http.StatusBadRequest, "request body cannot be empty"},
{"{}", http.StatusBadRequest, "cannot provide an empty statement"},
{"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
{"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
{"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
}
for _, test := range tests {
res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
assert.NoError(t, err)
assert.Equal(t, test.status, res.StatusCode)
if res.StatusCode > http.StatusOK {
assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
} else {
assert.Equal(t, "application/json", res.Header.Get("Content-type"))
}
data, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
str := string(data)
assert.NoError(t, err)
assert.Equal(t, test.output, str)
}
}
func TestProxyHTTPSubmitForbidden(t *testing.T) {
proxy := helperNewProxy(t, true)
server := httptest.NewServer(proxy.httpSubmit())
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
assert.Zero(t, res.ContentLength)
}
func TestProxyHTTPRespond(t *testing.T) {
proxy := helperNewProxy(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
}))
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, int64(5), res.ContentLength)
data, err := ioutil.ReadAll(res.Body)
defer res.Body.Close()
assert.Equal(t, []byte("Hello"), data)
}
func TestProxyHTTPRespondForbidden(t *testing.T) {
proxy := helperNewProxy(t, true)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
}))
defer server.Close()
client := server.Client()
res, err := client.Get(server.URL)
assert.NoError(t, err)
assert.Equal(t, http.StatusAccepted, res.StatusCode)
assert.Equal(t, int64(0), res.ContentLength)
}
func TestHTTPError(t *testing.T) {
_, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
assert.Error(t, errTimeout)
tests := []struct {
input error
status int
output error
}{
{nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
{io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
{context.DeadlineExceeded, http.StatusRequestTimeout, nil},
{context.Canceled, 444, nil},
{errTimeout, http.StatusRequestTimeout, nil},
{fmt.Errorf(""), http.StatusInternalServerError, nil},
}
for _, test := range tests {
status, err := httpError(http.StatusInternalServerError, test.input)
assert.Error(t, err)
assert.Equal(t, test.status, status)
if test.output == nil {
test.output = test.input
}
assert.Equal(t, test.output, err)
}
}
func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
t.Helper()
proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
assert.NoError(t, err)
assert.NotNil(t, proxy)
if len(secure) == 0 {
proxy.accessValidator = nil // Mark as insecure
}
return proxy
}

@ -0,0 +1,318 @@
package dbconnect
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/url"
"reflect"
"strings"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/xo/dburl"
// SQL drivers self-register with the database/sql package.
// https://github.com/golang/go/wiki/SQLDrivers
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
"github.com/kshvakov/clickhouse"
"github.com/lib/pq"
)
// SQLClient is a Client that talks to a SQL database.
type SQLClient struct {
Dialect string
driver *sqlx.DB
}
// NewSQLClient creates a SQL client based on its URL scheme.
func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
res, err := dburl.Parse(originURL.String())
if err != nil {
helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
}
// Establishes the driver, but does not test the connection.
driver, err := sqlx.Open(res.Driver, res.DSN)
if err != nil {
return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
}
// Closes the driver, will occur when the context finishes.
go func() {
<-ctx.Done()
driver.Close()
}()
return &SQLClient{driver.DriverName(), driver}, nil
}
// Ping verifies a connection to the database is still alive.
func (client *SQLClient) Ping(ctx context.Context) error {
return client.driver.PingContext(ctx)
}
// Submit queries or executes a command to the SQL database.
func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
txx, err := cmd.ValidateSQL(client.Dialect)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
defer cancel()
var res interface{}
// Get the next available sql.Conn and submit the Command.
err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
stmt := cmd.Statement
args := cmd.Arguments.Positional
if cmd.Mode == "query" {
res, err = sqlQuery(ctx, conn, stmt, args)
} else {
res, err = sqlExec(ctx, conn, stmt, args)
}
return err
})
return res, err
}
// ValidateSQL extends the contract of Command for SQL dialects:
// mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
//
// When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
err := cmd.Validate()
if err != nil {
return nil, err
}
mode, err := sqlMode(cmd.Mode)