parent
13bf65ce4e
commit
759cd019be
@ -1,148 +0,0 @@
|
||||
package sqlgateway
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
|
||||
"github.com/elgs/gosqljson"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Connection Connection `json:"connection"`
|
||||
Command string `json:"command"`
|
||||
Params []interface{} `json:"params"`
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
SSLMode string `json:"sslmode"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Columns []string `json:"columns"`
|
||||
Rows [][]string `json:"rows"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
Context *cli.Context
|
||||
Router *mux.Router
|
||||
Token string
|
||||
User string
|
||||
Password string
|
||||
Driver string
|
||||
Database string
|
||||
Logger *logrus.Logger
|
||||
}
|
||||
|
||||
func StartProxy(c *cli.Context, logger *logrus.Logger, password string) error {
|
||||
proxy := NewProxy(c, logger, password)
|
||||
|
||||
logger.Infof("Starting SQL Gateway Proxy on port %s", strings.Split(c.String("url"), ":")[1])
|
||||
|
||||
err := http.ListenAndServe(":"+strings.Split(c.String("url"), ":")[1], proxy.Router)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func randID(n int, c *cli.Context) string {
|
||||
charBytes := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = charBytes[rand.Intn(len(charBytes))]
|
||||
}
|
||||
return fmt.Sprintf("%s&%s", c.String("hostname"), b)
|
||||
}
|
||||
|
||||
// db://user@dbname
|
||||
func parseInfo(input string) (string, string, string) {
|
||||
p1 := strings.Split(input, "://")
|
||||
p2 := strings.Split(p1[1], "@")
|
||||
return p1[0], p2[0], p2[1]
|
||||
}
|
||||
|
||||
func NewProxy(c *cli.Context, logger *logrus.Logger, pass string) *Proxy {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
driver, user, dbname := parseInfo(c.String("address"))
|
||||
proxy := Proxy{
|
||||
Context: c,
|
||||
Router: mux.NewRouter(),
|
||||
Token: randID(64, c),
|
||||
Logger: logger,
|
||||
User: user,
|
||||
Password: pass,
|
||||
Database: dbname,
|
||||
Driver: driver,
|
||||
}
|
||||
|
||||
logger.Info(fmt.Sprintf(`
|
||||
|
||||
--------------------
|
||||
SQL Gateway Proxy
|
||||
Token: %s
|
||||
--------------------
|
||||
|
||||
`, proxy.Token))
|
||||
|
||||
proxy.Router.HandleFunc("/", proxy.proxyRequest).Methods("POST")
|
||||
return &proxy
|
||||
}
|
||||
|
||||
func (proxy *Proxy) proxyRequest(rw http.ResponseWriter, req *http.Request) {
|
||||
var message Message
|
||||
response := Response{}
|
||||
|
||||
err := json.NewDecoder(req.Body).Decode(&message)
|
||||
if err != nil {
|
||||
proxy.Logger.Error(err)
|
||||
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if message.Connection.Token != proxy.Token {
|
||||
proxy.Logger.Error("Invalid token")
|
||||
http.Error(rw, "400 - Invalid token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=%s", proxy.User, proxy.Password, proxy.Database, message.Connection.SSLMode)
|
||||
|
||||
db, err := sql.Open(proxy.Driver, connStr)
|
||||
defer db.Close()
|
||||
|
||||
if err != nil {
|
||||
proxy.Logger.Error(err)
|
||||
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
} else {
|
||||
proxy.Logger.Info("Forwarding SQL: ", message.Command)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
|
||||
headers, data, err := gosqljson.QueryDbToArray(db, "lower", message.Command, message.Params...)
|
||||
|
||||
if err != nil {
|
||||
proxy.Logger.Error(err)
|
||||
http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
} else {
|
||||
response = Response{headers, data, ""}
|
||||
}
|
||||
}
|
||||
json.NewEncoder(rw).Encode(response)
|
||||
}
|
@ -0,0 +1,145 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Client is an interface to talk to any database.
|
||||
//
|
||||
// Currently, the only implementation is SQLClient, but its structure
|
||||
// should be designed to handle a MongoClient or RedisClient in the future.
|
||||
type Client interface {
|
||||
Ping(context.Context) error
|
||||
Submit(context.Context, *Command) (interface{}, error)
|
||||
}
|
||||
|
||||
// NewClient creates a database client based on its URL scheme.
|
||||
func NewClient(ctx context.Context, originURL *url.URL) (Client, error) {
|
||||
return NewSQLClient(ctx, originURL)
|
||||
}
|
||||
|
||||
// Command is a standard, non-vendor format for submitting database commands.
|
||||
//
|
||||
// When determining the scope of this struct, refer to the following litmus test:
|
||||
// Could this (roughly) conform to SQL, Document-based, and Key-value command formats?
|
||||
type Command struct {
|
||||
Statement string `json:"statement"`
|
||||
Arguments Arguments `json:"arguments,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Isolation string `json:"isolation,omitempty"`
|
||||
Timeout time.Duration `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
// Validate enforces the contract of Command: non empty statement (both in length and logic),
|
||||
// lowercase mode and isolation, non-zero timeout, and valid Arguments.
|
||||
func (cmd *Command) Validate() error {
|
||||
if cmd.Statement == "" {
|
||||
return fmt.Errorf("cannot provide an empty statement")
|
||||
}
|
||||
|
||||
if strings.Map(func(char rune) rune {
|
||||
if char == ';' || unicode.IsSpace(char) {
|
||||
return -1
|
||||
}
|
||||
return char
|
||||
}, cmd.Statement) == "" {
|
||||
return fmt.Errorf("cannot provide a statement with no logic: '%s'", cmd.Statement)
|
||||
}
|
||||
|
||||
cmd.Mode = strings.ToLower(cmd.Mode)
|
||||
cmd.Isolation = strings.ToLower(cmd.Isolation)
|
||||
|
||||
if cmd.Timeout.Nanoseconds() <= 0 {
|
||||
cmd.Timeout = 24 * time.Hour
|
||||
}
|
||||
|
||||
return cmd.Arguments.Validate()
|
||||
}
|
||||
|
||||
// UnmarshalJSON converts a byte representation of JSON into a Command, which is also validated.
|
||||
func (cmd *Command) UnmarshalJSON(data []byte) error {
|
||||
// Alias is required to avoid infinite recursion from the default UnmarshalJSON.
|
||||
type Alias Command
|
||||
alias := &struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(cmd),
|
||||
}
|
||||
|
||||
err := json.Unmarshal(data, &alias)
|
||||
if err == nil {
|
||||
err = cmd.Validate()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Arguments is a wrapper for either map-based or array-based Command arguments.
|
||||
//
|
||||
// Each field is mutually-exclusive and some Client implementations may not
|
||||
// support both fields (eg. MySQL does not accept named arguments).
|
||||
type Arguments struct {
|
||||
Named map[string]interface{}
|
||||
Positional []interface{}
|
||||
}
|
||||
|
||||
// Validate enforces the contract of Arguments: non nil, mutually exclusive, and no empty or reserved keys.
|
||||
func (args *Arguments) Validate() error {
|
||||
if args.Named == nil {
|
||||
args.Named = map[string]interface{}{}
|
||||
}
|
||||
if args.Positional == nil {
|
||||
args.Positional = []interface{}{}
|
||||
}
|
||||
|
||||
if len(args.Named) > 0 && len(args.Positional) > 0 {
|
||||
return fmt.Errorf("both named and positional arguments cannot be specified: %+v and %+v", args.Named, args.Positional)
|
||||
}
|
||||
|
||||
for key := range args.Named {
|
||||
if key == "" {
|
||||
return fmt.Errorf("named arguments cannot contain an empty key: %+v", args.Named)
|
||||
}
|
||||
if !utf8.ValidString(key) {
|
||||
return fmt.Errorf("named argument does not conform to UTF-8 encoding: %s", key)
|
||||
}
|
||||
if strings.HasPrefix(key, "_") {
|
||||
return fmt.Errorf("named argument cannot start with a reserved keyword '_': %s", key)
|
||||
}
|
||||
if unicode.IsNumber([]rune(key)[0]) {
|
||||
return fmt.Errorf("named argument cannot start with a number: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON converts a byte representation of JSON into Arguments, which is also validated.
|
||||
func (args *Arguments) UnmarshalJSON(data []byte) error {
|
||||
var obj interface{}
|
||||
err := json.Unmarshal(data, &obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
named, ok := obj.(map[string]interface{})
|
||||
if ok {
|
||||
args.Named = named
|
||||
} else {
|
||||
positional, ok := obj.([]interface{})
|
||||
if ok {
|
||||
args.Positional = positional
|
||||
} else {
|
||||
return fmt.Errorf("arguments must either be an object {\"0\":\"val\"} or an array [\"val\"]: %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
return args.Validate()
|
||||
}
|
@ -0,0 +1,183 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCommandValidateEmpty(t *testing.T) {
|
||||
stmts := []string{
|
||||
"",
|
||||
";",
|
||||
" \n\t",
|
||||
";\n;\t;",
|
||||
}
|
||||
|
||||
for _, stmt := range stmts {
|
||||
cmd := Command{Statement: stmt}
|
||||
|
||||
assert.Error(t, cmd.Validate(), stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandValidateMode(t *testing.T) {
|
||||
modes := []string{
|
||||
"",
|
||||
"query",
|
||||
"ExEc",
|
||||
"PREPARE",
|
||||
}
|
||||
|
||||
for _, mode := range modes {
|
||||
cmd := Command{Statement: "Ok", Mode: mode}
|
||||
|
||||
assert.NoError(t, cmd.Validate(), mode)
|
||||
assert.Equal(t, strings.ToLower(mode), cmd.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandValidateIsolation(t *testing.T) {
|
||||
isos := []string{
|
||||
"",
|
||||
"default",
|
||||
"read_committed",
|
||||
"SNAPshot",
|
||||
}
|
||||
|
||||
for _, iso := range isos {
|
||||
cmd := Command{Statement: "Ok", Isolation: iso}
|
||||
|
||||
assert.NoError(t, cmd.Validate(), iso)
|
||||
assert.Equal(t, strings.ToLower(iso), cmd.Isolation)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandValidateTimeout(t *testing.T) {
|
||||
cmd := Command{Statement: "Ok", Timeout: 0}
|
||||
|
||||
assert.NoError(t, cmd.Validate())
|
||||
assert.NotZero(t, cmd.Timeout)
|
||||
|
||||
cmd = Command{Statement: "Ok", Timeout: 1 * time.Second}
|
||||
|
||||
assert.NoError(t, cmd.Validate())
|
||||
assert.Equal(t, 1*time.Second, cmd.Timeout)
|
||||
}
|
||||
|
||||
func TestCommandValidateArguments(t *testing.T) {
|
||||
cmd := Command{Statement: "Ok", Arguments: Arguments{
|
||||
Named: map[string]interface{}{"key": "val"},
|
||||
Positional: []interface{}{"val"},
|
||||
}}
|
||||
|
||||
assert.Error(t, cmd.Validate())
|
||||
}
|
||||
|
||||
func TestCommandUnmarshalJSON(t *testing.T) {
|
||||
strs := []string{
|
||||
"{\"statement\":\"Ok\"}",
|
||||
"{\"statement\":\"Ok\",\"arguments\":[0, 3.14, \"apple\"],\"mode\":\"query\"}",
|
||||
"{\"statement\":\"Ok\",\"isolation\":\"read_uncommitted\",\"timeout\":1000}",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var cmd Command
|
||||
assert.NoError(t, json.Unmarshal([]byte(str), &cmd), str)
|
||||
}
|
||||
|
||||
strs = []string{
|
||||
"",
|
||||
"\"",
|
||||
"{}",
|
||||
"{\"argument\":{\"key\":\"val\"}}",
|
||||
"{\"statement\":[\"Ok\"]}",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var cmd Command
|
||||
assert.Error(t, json.Unmarshal([]byte(str), &cmd), str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentsValidateNotNil(t *testing.T) {
|
||||
args := Arguments{}
|
||||
|
||||
assert.NoError(t, args.Validate())
|
||||
assert.NotNil(t, args.Named)
|
||||
assert.NotNil(t, args.Positional)
|
||||
}
|
||||
|
||||
func TestArgumentsValidateMutuallyExclusive(t *testing.T) {
|
||||
args := []Arguments{
|
||||
Arguments{},
|
||||
Arguments{Named: map[string]interface{}{"key": "val"}},
|
||||
Arguments{Positional: []interface{}{"val"}},
|
||||
}
|
||||
|
||||
for _, arg := range args {
|
||||
assert.NoError(t, arg.Validate())
|
||||
assert.False(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
|
||||
}
|
||||
|
||||
args = []Arguments{
|
||||
Arguments{
|
||||
Named: map[string]interface{}{"key": "val"},
|
||||
Positional: []interface{}{"val"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, arg := range args {
|
||||
assert.Error(t, arg.Validate())
|
||||
assert.True(t, len(arg.Named) > 0 && len(arg.Positional) > 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentsValidateKeys(t *testing.T) {
|
||||
keys := []string{
|
||||
"",
|
||||
"_",
|
||||
"_key",
|
||||
"1",
|
||||
"1key",
|
||||
"\xf0\x28\x8c\xbc", // non-utf8
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
args := Arguments{Named: map[string]interface{}{key: "val"}}
|
||||
|
||||
assert.Error(t, args.Validate(), key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgumentsUnmarshalJSON(t *testing.T) {
|
||||
strs := []string{
|
||||
"{}",
|
||||
"{\"key\":\"val\"}",
|
||||
"{\"key\":[1, 3.14, {\"key\":\"val\"}]}",
|
||||
"[]",
|
||||
"[\"key\",\"val\"]",
|
||||
"[{}]",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var args Arguments
|
||||
assert.NoError(t, json.Unmarshal([]byte(str), &args), str)
|
||||
}
|
||||
|
||||
strs = []string{
|
||||
"",
|
||||
"\"",
|
||||
"1",
|
||||
"\"key\"",
|
||||
"{\"key\",\"val\"}",
|
||||
}
|
||||
|
||||
for _, str := range strs {
|
||||
var args Arguments
|
||||
assert.Error(t, json.Unmarshal([]byte(str), &args), str)
|
||||
}
|
||||
}
|
@ -0,0 +1,157 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
"gopkg.in/urfave/cli.v2/altsrc"
|
||||
)
|
||||
|
||||
// Cmd is the entrypoint command for dbconnect.
|
||||
//
|
||||
// The tunnel package is responsible for appending this to tunnel.Commands().
|
||||
func Cmd() *cli.Command {
|
||||
return &cli.Command{
|
||||
Category: "Database Connect (ALPHA)",
|
||||
Name: "db-connect",
|
||||
Usage: "Access your SQL database from Cloudflare Workers or the browser",
|
||||
ArgsUsage: " ",
|
||||
Description: `
|
||||
Creates a connection between your database and the Cloudflare edge.
|
||||
Now you can execute SQL commands anywhere you can send HTTPS requests.
|
||||
|
||||
Connect your database with any of the following commands, you can also try the "playground" without a database:
|
||||
|
||||
cloudflared db-connect --hostname sql.mysite.com --url postgres://user:pass@localhost?sslmode=disable \
|
||||
--auth-domain mysite.cloudflareaccess.com --application-aud my-access-policy-tag
|
||||
cloudflared db-connect --hostname sql-dev.mysite.com --url mysql://localhost --insecure
|
||||
cloudflared db-connect --playground
|
||||
|
||||
Requests should be authenticated using Cloudflare Access, learn more about how to enable it here:
|
||||
|
||||
https://developers.cloudflare.com/access/service-auth/service-token/
|
||||
`,
|
||||
Flags: []cli.Flag{
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "url",
|
||||
Usage: "URL to the database (eg. postgres://user:pass@localhost?sslmode=disable)",
|
||||
EnvVars: []string{"TUNNEL_URL"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "hostname",
|
||||
Usage: "Hostname to accept commands over HTTPS (eg. sql.mysite.com)",
|
||||
EnvVars: []string{"TUNNEL_HOSTNAME"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "auth-domain",
|
||||
Usage: "Cloudflare Access authentication domain for your account (eg. mysite.cloudflareaccess.com)",
|
||||
EnvVars: []string{"TUNNEL_ACCESS_AUTH_DOMAIN"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "application-aud",
|
||||
Usage: "Cloudflare Access application \"AUD\" to verify JWTs from requests",
|
||||
EnvVars: []string{"TUNNEL_ACCESS_APPLICATION_AUD"},
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "insecure",
|
||||
Usage: "Disable authentication, the database will be open to the Internet",
|
||||
Value: false,
|
||||
EnvVars: []string{"TUNNEL_ACCESS_INSECURE"},
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "playground",
|
||||
Usage: "Run a temporary, in-memory SQLite3 database for testing",
|
||||
Value: false,
|
||||
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "loglevel",
|
||||
Value: "debug", // Make it more verbose than the tunnel default 'info'.
|
||||
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
||||
Hidden: true,
|
||||
}),
|
||||
},
|
||||
Before: CmdBefore,
|
||||
Action: CmdAction,
|
||||
Hidden: true,
|
||||
}
|
||||
}
|
||||
|
||||
// CmdBefore runs some validation checks before running the command.
|
||||
func CmdBefore(c *cli.Context) error {
|
||||
// Show the help text is no flags are specified.
|
||||
if c.NumFlags() == 0 {
|
||||
return cli.ShowSubcommandHelp(c)
|
||||
}
|
||||
|
||||
// Hello-world and playground are synonymous with each other,
|
||||
// unset hello-world to prevent tunnel from initializing the hello package.
|
||||
if c.IsSet("hello-world") {
|
||||
c.Set("playground", "true")
|
||||
c.Set("hello-world", "false")
|
||||
}
|
||||
|
||||
// Unix-socket database urls are supported, but the logic is the same as url.
|
||||
if c.IsSet("unix-socket") {
|
||||
c.Set("url", c.String("unix-socket"))
|
||||
c.Set("unix-socket", "")
|
||||
}
|
||||
|
||||
// When playground mode is enabled, run with an in-memory database.
|
||||
if c.IsSet("playground") {
|
||||
c.Set("url", "sqlite3::memory:?cache=shared")
|
||||
c.Set("insecure", strconv.FormatBool(!c.IsSet("auth-domain") && !c.IsSet("application-aud")))
|
||||
}
|
||||
|
||||
// At this point, insecure configurations are valid.
|
||||
if c.Bool("insecure") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure that secure configurations specify a hostname, domain, and tag for JWT validation.
|
||||
if !c.IsSet("hostname") || !c.IsSet("auth-domain") || !c.IsSet("application-aud") {
|
||||
log.Fatal("must specify --hostname, --auth-domain, and --application-aud unless you want to run in --insecure mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CmdAction starts the Proxy and sets the url in cli.Context to point to the Proxy address.
|
||||
func CmdAction(c *cli.Context) error {
|
||||
// STOR-612: sync with context in tunnel daemon.
|
||||
ctx := context.Background()
|
||||
|
||||
var proxy *Proxy
|
||||
var err error
|
||||
if c.Bool("insecure") {
|
||||
proxy, err = NewInsecureProxy(ctx, c.String("url"))
|
||||
} else {
|
||||
proxy, err = NewSecureProxy(ctx, c.String("url"), c.String("auth-domain"), c.String("application-aud"))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return err
|
||||
}
|
||||
|
||||
listenerC := make(chan net.Listener)
|
||||
defer close(listenerC)
|
||||
|
||||
// Since the Proxy should only talk to the tunnel daemon, find the next available
|
||||
// localhost port and start to listen to requests.
|
||||
go func() {
|
||||
err := proxy.Start(ctx, "127.0.0.1:", listenerC)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Block until the the Proxy is online, retreive its address, and change the url to point to it.
|
||||
// This is effectively "handing over" control to the tunnel package so it can run the tunnel daemon.
|
||||
c.Set("url", "https://"+(<-listenerC).Addr().String())
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
func TestCmd(t *testing.T) {
|
||||
tests := [][]string{
|
||||
{"cloudflared", "db-connect", "--playground"},
|
||||
{"cloudflared", "db-connect", "--playground", "--hostname", "sql.mysite.com"},
|
||||
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--insecure"},
|
||||
{"cloudflared", "db-connect", "--url", "sqlite3::memory:?cache=shared", "--hostname", "sql.mysite.com", "--auth-domain", "mysite.cloudflareaccess.com", "--application-aud", "aud"},
|
||||
}
|
||||
|
||||
app := &cli.App{
|
||||
Name: "cloudflared",
|
||||
Commands: []*cli.Command{Cmd()},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.NoError(t, app.Run(test))
|
||||
}
|
||||
}
|
@ -0,0 +1,271 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Proxy is an HTTP server that proxies requests to a Client.
|
||||
type Proxy struct {
|
||||
client Client
|
||||
accessValidator *validation.Access
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewInsecureProxy creates a Proxy that talks to a Client at an origin.
|
||||
//
|
||||
// In insecure mode, the Proxy will allow all Command requests.
|
||||
func NewInsecureProxy(ctx context.Context, origin string) (*Proxy, error) {
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "must provide a valid database url")
|
||||
}
|
||||
|
||||
client, err := NewClient(ctx, originURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = client.Ping(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not connect to the database")
|
||||
}
|
||||
|
||||
return &Proxy{client, nil, logrus.New()}, nil
|
||||
}
|
||||
|
||||
// NewSecureProxy creates a Proxy that talks to a Client at an origin.
|
||||
//
|
||||
// In secure mode, the Proxy will reject any Command requests that are
|
||||
// not authenticated by Cloudflare Access with a valid JWT.
|
||||
func NewSecureProxy(ctx context.Context, origin, authDomain, applicationAUD string) (*Proxy, error) {
|
||||
proxy, err := NewInsecureProxy(ctx, origin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validator, err := validation.NewAccessValidator(ctx, authDomain, authDomain, applicationAUD)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxy.accessValidator = validator
|
||||
|
||||
return proxy, err
|
||||
}
|
||||
|
||||
// IsInsecure gets whether the Proxy will accept a Command from any source.
|
||||
func (proxy *Proxy) IsInsecure() bool {
|
||||
return proxy.accessValidator == nil
|
||||
}
|
||||
|
||||
// IsAllowed checks whether a http.Request is allowed to receive data.
|
||||
//
|
||||
// By default, requests must pass through Cloudflare Access for authentication.
|
||||
// If the proxy is explcitly set to insecure mode, all requests will be allowed.
|
||||
func (proxy *Proxy) IsAllowed(r *http.Request, verbose ...bool) bool {
|
||||
if proxy.IsInsecure() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Access and Tunnel should prevent bad JWTs from even reaching the origin,
|
||||
// but validate tokens anyway as an abundance of caution.
|
||||
err := proxy.accessValidator.ValidateRequest(r.Context(), r)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Warn administrators that invalid JWTs are being rejected. This is indicative
|
||||
// of either a misconfiguration of the CLI or a massive failure of upstream systems.
|
||||
if len(verbose) > 0 {
|
||||
proxy.httpLog(r, err).Error("Failed JWT authentication")
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Start the Proxy at a given address and notify the listener channel when the server is online.
|
||||
func (proxy *Proxy) Start(ctx context.Context, addr string, listenerC chan<- net.Listener) error {
|
||||
// STOR-611: use a seperate listener and consider web socket support.
|
||||
httpListener, err := hello.CreateTLSListener(addr)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "could not create listener at %s", addr)
|
||||
}
|
||||
|
||||
errC := make(chan error)
|
||||
defer close(errC)
|
||||
|
||||
// Starts the HTTP server and begins to serve requests.
|
||||
go func() {
|
||||
errC <- proxy.httpListen(ctx, httpListener)
|
||||
}()
|
||||
|
||||
// Continually ping the server until it comes online or 10 attempts fail.
|
||||
go func() {
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = http.Get("http://" + httpListener.Addr().String())
|
||||
|
||||
// Once no error was detected, notify the listener channel and return.
|
||||
if err == nil {
|
||||
listenerC <- httpListener
|
||||
return
|
||||
}
|
||||
|
||||
// Backoff between requests to ping the server.
|
||||
<-time.After(1 * time.Second)
|
||||
}
|
||||
errC <- errors.Wrap(err, "took too long for the http server to start")
|
||||
}()
|
||||
|
||||
return <-errC
|
||||
}
|
||||
|
||||
// httpListen starts the httpServer and blocks until the context closes.
|
||||
func (proxy *Proxy) httpListen(ctx context.Context, listener net.Listener) error {
|
||||
httpServer := &http.Server{
|
||||
Addr: listener.Addr().String(),
|
||||
Handler: proxy.httpRouter(),
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
httpServer.Close()
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
return httpServer.Serve(listener)
|
||||
}
|
||||
|
||||
// httpRouter creates a mux.Router for the Proxy.
|
||||
func (proxy *Proxy) httpRouter() *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/ping", proxy.httpPing()).Methods("GET", "HEAD")
|
||||
router.HandleFunc("/submit", proxy.httpSubmit()).Methods("POST")
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// httpPing tests the connection to the database.
|
||||
//
|
||||
// By default, this endpoint is unauthenticated to allow for health checks.
|
||||
// To enable authentication, Cloudflare Access must be enabled on this route.
|
||||
func (proxy *Proxy) httpPing() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
err := proxy.client.Ping(ctx)
|
||||
|
||||
if err == nil {
|
||||
proxy.httpRespond(w, r, http.StatusOK, "")
|
||||
} else {
|
||||
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// httpSubmit sends a command to the database and returns its response.
|
||||
//
|
||||
// By default, this endpoint will reject requests that do not pass through Cloudflare Access.
|
||||
// To disable authentication, the --insecure flag must be specified in the command line.
|
||||
func (proxy *Proxy) httpSubmit() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !proxy.IsAllowed(r, true) {
|
||||
proxy.httpRespondErr(w, r, http.StatusForbidden, fmt.Errorf(""))
|
||||
return
|
||||
}
|
||||
|
||||
var cmd Command
|
||||
err := json.NewDecoder(r.Body).Decode(&cmd)
|
||||
if err != nil {
|
||||
proxy.httpRespondErr(w, r, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
data, err := proxy.client.Submit(ctx, &cmd)
|
||||
|
||||
if err != nil {
|
||||
proxy.httpRespondErr(w, r, http.StatusUnprocessableEntity, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data)
|
||||
if err != nil {
|
||||
proxy.httpRespondErr(w, r, http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// httpRespond writes a status code and string response to the response writer.
|
||||
func (proxy *Proxy) httpRespond(w http.ResponseWriter, r *http.Request, status int, message string) {
|
||||
w.WriteHeader(status)
|
||||
|
||||
// Only expose the message detail of the reponse if the request is not HEAD
|
||||
// and the user is authenticated. For example, this prevents an unauthenticated
|
||||
// failed health check from accidentally leaking sensitive information about the Client.
|
||||
if r.Method != http.MethodHead && proxy.IsAllowed(r) {
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
fmt.Fprint(w, message)
|
||||
}
|
||||
}
|
||||
|
||||
// httpRespondErr is similar to httpRespond, except it formats errors to be more friendly.
|
||||
func (proxy *Proxy) httpRespondErr(w http.ResponseWriter, r *http.Request, defaultStatus int, err error) {
|
||||
status, err := httpError(defaultStatus, err)
|
||||
|
||||
proxy.httpRespond(w, r, status, err.Error())
|
||||
if len(err.Error()) > 0 {
|
||||
proxy.httpLog(r, err).Warn("Database proxy error")
|
||||
}
|
||||
}
|
||||
|
||||
// httpLog returns a logrus.Entry that is formatted to output a request Cf-ray.
|
||||
func (proxy *Proxy) httpLog(r *http.Request, err error) *logrus.Entry {
|
||||
return proxy.logger.WithContext(r.Context()).WithField("CF-RAY", r.Header.Get("Cf-ray")).WithError(err)
|
||||
}
|
||||
|
||||
// httpError extracts common errors and returns an status code and friendly error.
|
||||
func httpError(defaultStatus int, err error) (int, error) {
|
||||
if err == nil {
|
||||
return http.StatusNotImplemented, fmt.Errorf("error expected but found none")
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
return http.StatusBadRequest, fmt.Errorf("request body cannot be empty")
|
||||
}
|
||||
|
||||
if err == context.DeadlineExceeded {
|
||||
return http.StatusRequestTimeout, err
|
||||
}
|
||||
|
||||
_, ok := err.(net.Error)
|
||||
if ok {
|
||||
return http.StatusRequestTimeout, err
|
||||
}
|
||||
|
||||
if err == context.Canceled {
|
||||
// Does not exist in Golang, but would be: http.StatusClientClosedWithoutResponse
|
||||
return 444, err
|
||||
}
|
||||
|
||||
return defaultStatus, err
|
||||
}
|
@ -0,0 +1,238 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewInsecureProxy(t *testing.T) {
|
||||
origins := []string{
|
||||
"",
|
||||
":/",
|
||||
"http://localhost",
|
||||
"tcp://localhost:9000?debug=true",
|
||||
"mongodb://127.0.0.1",
|
||||
}
|
||||
|
||||
for _, origin := range origins {
|
||||
proxy, err := NewInsecureProxy(context.Background(), origin)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, proxy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyIsAllowed(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
|
||||
assert.True(t, proxy.IsAllowed(req))
|
||||
|
||||
proxy = helperNewProxy(t, true)
|
||||
req.Header.Set("Cf-access-jwt-assertion", "xxx")
|
||||
assert.False(t, proxy.IsAllowed(req))
|
||||
}
|
||||
|
||||
func TestProxyStart(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
ctx := context.Background()
|
||||
listenerC := make(chan net.Listener)
|
||||
|
||||
err := proxy.Start(ctx, "1.1.1.1:", listenerC)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
|
||||
assert.Error(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 0)
|
||||
defer cancel()
|
||||
|
||||
err = proxy.Start(ctx, "127.0.0.1:", listenerC)
|
||||
assert.IsType(t, http.ErrServerClosed, err)
|
||||
}
|
||||
|
||||
func TestProxyHTTPRouter(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
router := proxy.httpRouter()
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
method string
|
||||
valid bool
|
||||
}{
|
||||
{"", "GET", false},
|
||||
{"/", "GET", false},
|
||||
{"/ping", "GET", true},
|
||||
{"/ping", "HEAD", true},
|
||||
{"/ping", "POST", false},
|
||||
{"/submit", "POST", true},
|
||||
{"/submit", "GET", false},
|
||||
{"/submit/extra", "POST", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
match := &mux.RouteMatch{}
|
||||
ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
|
||||
|
||||
assert.True(t, ok == test.valid, test.path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPPing(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
|
||||
server := httptest.NewServer(proxy.httpPing())
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, int64(2), res.ContentLength)
|
||||
|
||||
res, err = client.Head(server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, int64(-1), res.ContentLength)
|
||||
}
|
||||
|
||||
func TestProxyHTTPSubmit(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
|
||||
server := httptest.NewServer(proxy.httpSubmit())
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
status int
|
||||
output string
|
||||
}{
|
||||
{"", http.StatusBadRequest, "request body cannot be empty"},
|
||||
{"{}", http.StatusBadRequest, "cannot provide an empty statement"},
|
||||
{"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
|
||||
{"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
|
||||
{"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.status, res.StatusCode)
|
||||
if res.StatusCode > http.StatusOK {
|
||||
assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
|
||||
} else {
|
||||
assert.Equal(t, "application/json", res.Header.Get("Content-type"))
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
defer res.Body.Close()
|
||||
str := string(data)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.output, str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPSubmitForbidden(t *testing.T) {
|
||||
proxy := helperNewProxy(t, true)
|
||||
|
||||
server := httptest.NewServer(proxy.httpSubmit())
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
assert.Zero(t, res.ContentLength)
|
||||
}
|
||||
|
||||
func TestProxyHTTPRespond(t *testing.T) {
|
||||
proxy := helperNewProxy(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
|
||||
}))
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.Equal(t, int64(5), res.ContentLength)
|
||||
|
||||
data, err := ioutil.ReadAll(res.Body)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, []byte("Hello"), data)
|
||||
}
|
||||
|
||||
func TestProxyHTTPRespondForbidden(t *testing.T) {
|
||||
proxy := helperNewProxy(t, true)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
|
||||
}))
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusAccepted, res.StatusCode)
|
||||
assert.Equal(t, int64(0), res.ContentLength)
|
||||
}
|
||||
|
||||
func TestHTTPError(t *testing.T) {
|
||||
_, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
|
||||
assert.Error(t, errTimeout)
|
||||
|
||||
tests := []struct {
|
||||
input error
|
||||
status int
|
||||
output error
|
||||
}{
|
||||
{nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
|
||||
{io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
|
||||
{context.DeadlineExceeded, http.StatusRequestTimeout, nil},
|
||||
{context.Canceled, 444, nil},
|
||||
{errTimeout, http.StatusRequestTimeout, nil},
|
||||
{fmt.Errorf(""), http.StatusInternalServerError, nil},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
status, err := httpError(http.StatusInternalServerError, test.input)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, test.status, status)
|
||||
if test.output == nil {
|
||||
test.output = test.input
|
||||
}
|
||||
assert.Equal(t, test.output, err)
|
||||
}
|
||||
}
|
||||
|
||||
func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
|
||||
t.Helper()
|
||||
|
||||
proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proxy)
|
||||
|
||||
if len(secure) == 0 {
|
||||
proxy.accessValidator = nil // Mark as insecure
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
@ -0,0 +1,318 @@
|
||||
package dbconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xo/dburl"
|
||||
|
||||
// SQL drivers self-register with the database/sql package.
|
||||
// https://github.com/golang/go/wiki/SQLDrivers
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/kshvakov/clickhouse"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// SQLClient is a Client that talks to a SQL database.
|
||||
type SQLClient struct {
|
||||
Dialect string
|
||||
driver *sqlx.DB
|
||||
}
|
||||
|
||||
// NewSQLClient creates a SQL client based on its URL scheme.
|
||||
func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
|
||||
res, err := dburl.Parse(originURL.String())
|
||||
if err != nil {
|
||||
helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
|
||||
return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
|
||||
}
|
||||
|
||||
// Establishes the driver, but does not test the connection.
|
||||
driver, err := sqlx.Open(res.Driver, res.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
|
||||
}
|
||||
|
||||
// Closes the driver, will occur when the context finishes.
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
driver.Close()
|
||||
}()
|
||||
|
||||
return &SQLClient{driver.DriverName(), driver}, nil
|
||||
}
|
||||
|
||||
// Ping verifies a connection to the database is still alive.
|
||||
func (client *SQLClient) Ping(ctx context.Context) error {
|
||||
return client.driver.PingContext(ctx)
|
||||
}
|
||||
|
||||
// Submit queries or executes a command to the SQL database.
|
||||
func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
|
||||
txx, err := cmd.ValidateSQL(client.Dialect)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
|
||||
defer cancel()
|
||||
|
||||
var res interface{}
|
||||
|
||||
// Get the next available sql.Conn and submit the Command.
|
||||
err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
|
||||
stmt := cmd.Statement
|
||||
args := cmd.Arguments.Positional
|
||||
|
||||
if cmd.Mode == "query" {
|
||||
res, err = sqlQuery(ctx, conn, stmt, args)
|
||||
} else {
|
||||
res, err = sqlExec(ctx, conn, stmt, args)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// ValidateSQL extends the contract of Command for SQL dialects:
|
||||
// mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
|
||||
//
|
||||
// When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
|
||||
func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
|
||||
err := cmd.Validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mode, err := sqlMode(cmd.Mode)
|
||||