diff --git a/cmd/cloudflared/configuration.go b/cmd/cloudflared/configuration.go index 386fb900..345eb7ff 100644 --- a/cmd/cloudflared/configuration.go +++ b/cmd/cloudflared/configuration.go @@ -123,14 +123,12 @@ func handleDeprecatedOptions(c *cli.Context) error { return nil } -// validate url. It can be either from --url or argument func validateUrl(c *cli.Context) (string, error) { var url = c.String("url") if c.NArg() > 0 { - if c.IsSet("url") { - return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") + if !c.IsSet("url") { + return "", errors.New("Please specify an origin URL.") } - url = c.Args().Get(0) } validUrl, err := validation.ValidateUrl(url) return validUrl, err diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index 711f5a23..9f78872d 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -13,6 +13,8 @@ import ( "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tunneldns" + rapid "github.com/cloudflare/cloudflared/cmd/rapid" + "github.com/getsentry/raven-go" "github.com/mitchellh/go-homedir" "gopkg.in/urfave/cli.v2" @@ -400,6 +402,54 @@ func main() { }, ArgsUsage: " ", // can't be the empty string or we get the default output }, + { + Name: "rapid", + Action: func(c *cli.Context) error { + tags := make(map[string]string) + tags["hostname"] = c.String("hostname") + raven.SetTagsContext(tags) + + go rapid.StartProxy(c, logger) + + var err error + raven.CapturePanic(func() { err = startServer(c, shutdownC, graceShutdownC) }, nil) + if err != nil { + raven.CaptureError(err, nil) + } + return err + }, + Before: func(c *cli.Context) error { + if c.String("config") == "" { + logger.Warnf("Cannot determine default configuration path. No file %v in %v", defaultConfigFiles, defaultConfigDirs) + } + inputSource, err := findInputSourceContext(c) + if err != nil { + logger.WithError(err).Infof("Cannot load configuration from %s", c.String("config")) + return err + } else if inputSource != nil { + err := altsrc.ApplyInputSourceValues(c, inputSource, app.Flags) + if err != nil { + logger.WithError(err).Infof("Cannot apply configuration from %s", c.String("config")) + return err + } + logger.Infof("Applied configuration from %s", c.String("config")) + } + return nil + }, + Usage: "Rapid is an SQL over HTTP reverse proxy", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "db", + Value: false, + Usage: "Enable the SQL Gateway Proxy", + }, + &cli.StringFlag{ + Name: "address", + Value: "", + Usage: "Database connection string: db://user:pass", + }, + }, + }, } runApp(app, shutdownC, graceShutdownC) } diff --git a/cmd/rapid/rapid.go b/cmd/rapid/rapid.go new file mode 100644 index 00000000..26eafdf4 --- /dev/null +++ b/cmd/rapid/rapid.go @@ -0,0 +1,149 @@ +package rapid + +import ( + "database/sql" + "encoding/json" + "fmt" + "math/rand" + "net/http" + "strings" + "time" + + _ "github.com/lib/pq" + cli "gopkg.in/urfave/cli.v2" + + "github.com/elgs/gosqljson" + + "github.com/gorilla/mux" + "github.com/sirupsen/logrus" +) + +type Message struct { + Connection Connection `json:"connection"` + Command string `json:"command"` + Params []interface{} `json:"params"` +} + +type Connection struct { + SSLMode string `json:"sslmode"` + Token string `json:"token"` +} + +type Response struct { + Columns []string `json:"columns"` + Rows [][]string `json:"rows"` + Error string `json:"error"` +} + +type Proxy struct { + Context *cli.Context + Router *mux.Router + Token string + User string + Password string + Driver string + Database string + Logger *logrus.Logger +} + +func StartProxy(c *cli.Context, logger *logrus.Logger) error { + proxy := NewProxy(c, logger) + + logger.Infof("Starting Rapid SQL Proxy on port %s", c.String("port")) + + err := http.ListenAndServe(":"+c.String("port"), proxy.Router) + if err != nil { + return err + } + + return nil +} + +func randID(n int, c *cli.Context) string { + charBytes := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") + b := make([]byte, n) + for i := range b { + b[i] = charBytes[rand.Intn(len(charBytes))] + } + return fmt.Sprintf("%s&%s", c.String("hostname"), b) +} + +// db://user:pass@dbname +func parseInfo(input string) (string, string, string, string) { + p1 := strings.Split(input, "://") + p2 := strings.Split(p1[1], ":") + p3 := strings.Split(p2[1], "@") + return p1[0], p2[0], p3[0], p3[1] +} + +func NewProxy(c *cli.Context, logger *logrus.Logger) *Proxy { + rand.Seed(time.Now().UnixNano()) + driver, user, pass, dbname := parseInfo(c.String("address")) + proxy := Proxy{ + Context: c, + Router: mux.NewRouter(), + Token: randID(64, c), + Logger: logger, + User: user, + Password: pass, + Database: dbname, + Driver: driver, + } + + logger.Info(fmt.Sprintf(` + + -------------------- + Rapid SQL Proxy + Token: %s + -------------------- + + `, proxy.Token)) + + proxy.Router.HandleFunc("/", proxy.proxyRequest).Methods("POST") + return &proxy +} + +func (proxy *Proxy) proxyRequest(rw http.ResponseWriter, req *http.Request) { + var message Message + response := Response{} + + err := json.NewDecoder(req.Body).Decode(&message) + if err != nil { + proxy.Logger.Error(err) + http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest) + return + } + + if message.Connection.Token != proxy.Token { + proxy.Logger.Error("Invalid token") + http.Error(rw, "400 - Invalid token", http.StatusBadRequest) + return + } + + connStr := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=%s", proxy.User, proxy.Password, proxy.Database, message.Connection.SSLMode) + + db, err := sql.Open(proxy.Driver, connStr) + defer db.Close() + + if err != nil { + proxy.Logger.Error(err) + http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest) + return + + } else { + proxy.Logger.Info("Forwarding SQL: ", message.Command) + rw.Header().Set("Content-Type", "application/json") + + headers, data, err := gosqljson.QueryDbToArray(db, "lower", message.Command, message.Params...) + + if err != nil { + proxy.Logger.Error(err) + http.Error(rw, fmt.Sprintf("400 - %s", err.Error()), http.StatusBadRequest) + return + + } else { + response = Response{headers, data, ""} + } + } + json.NewEncoder(rw).Encode(response) +}