319 lines
8.9 KiB
Go
319 lines
8.9 KiB
Go
package dbconnect
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/pkg/errors"
|
|
"github.com/xo/dburl"
|
|
|
|
// SQL drivers self-register with the database/sql package.
|
|
// https://github.com/golang/go/wiki/SQLDrivers
|
|
_ "github.com/denisenkom/go-mssqldb"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
|
"github.com/kshvakov/clickhouse"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// SQLClient is a Client that talks to a SQL database.
|
|
type SQLClient struct {
|
|
Dialect string
|
|
driver *sqlx.DB
|
|
}
|
|
|
|
// NewSQLClient creates a SQL client based on its URL scheme.
|
|
func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
|
|
res, err := dburl.Parse(originURL.String())
|
|
if err != nil {
|
|
helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
|
|
return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
|
|
}
|
|
|
|
// Establishes the driver, but does not test the connection.
|
|
driver, err := sqlx.Open(res.Driver, res.DSN)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
|
|
}
|
|
|
|
// Closes the driver, will occur when the context finishes.
|
|
go func() {
|
|
<-ctx.Done()
|
|
_ = driver.Close()
|
|
}()
|
|
|
|
return &SQLClient{driver.DriverName(), driver}, nil
|
|
}
|
|
|
|
// Ping verifies a connection to the database is still alive.
|
|
func (client *SQLClient) Ping(ctx context.Context) error {
|
|
return client.driver.PingContext(ctx)
|
|
}
|
|
|
|
// Submit queries or executes a command to the SQL database.
|
|
func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
|
|
txx, err := cmd.ValidateSQL(client.Dialect)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
|
|
defer cancel()
|
|
|
|
var res interface{}
|
|
|
|
// Get the next available sql.Conn and submit the Command.
|
|
err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
|
|
stmt := cmd.Statement
|
|
args := cmd.Arguments.Positional
|
|
|
|
if cmd.Mode == "query" {
|
|
res, err = sqlQuery(ctx, conn, stmt, args)
|
|
} else {
|
|
res, err = sqlExec(ctx, conn, stmt, args)
|
|
}
|
|
|
|
return err
|
|
})
|
|
|
|
return res, err
|
|
}
|
|
|
|
// ValidateSQL extends the contract of Command for SQL dialects:
|
|
// mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
|
|
//
|
|
// When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
|
|
func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
|
|
err := cmd.Validate()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mode, err := sqlMode(cmd.Mode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Mutates Arguments to only use positional arguments with the type sql.NamedArg.
|
|
// This is a required by the sql.Driver before submitting arguments.
|
|
cmd.Arguments.sql(dialect)
|
|
|
|
iso, err := sqlIsolation(cmd.Isolation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// When isolation is out-of-range, this is indicative that no
|
|
// transaction should be executed and sql.TxOptions should be nil.
|
|
if iso < sql.LevelDefault {
|
|
return nil, nil
|
|
}
|
|
|
|
// In query mode, execute the transaction in read-only, unless it's Microsoft SQL
|
|
// which does not support that type of transaction.
|
|
readOnly := mode == "query" && dialect != "mssql"
|
|
|
|
return &sql.TxOptions{Isolation: iso, ReadOnly: readOnly}, nil
|
|
}
|
|
|
|
// sqlConn gets the next available sql.Conn in the connection pool and runs a function to use it.
|
|
//
|
|
// If the transaction options are nil, run the useIt function outside a transaction.
|
|
// This is potentially an unsafe operation if the command does not clean up its state.
|
|
func sqlConn(ctx context.Context, driver *sqlx.DB, txx *sql.TxOptions, useIt func(*sql.Conn) error) error {
|
|
conn, err := driver.Conn(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
|
|
// If transaction options are specified, begin and defer a rollback to catch errors.
|
|
var tx *sql.Tx
|
|
if txx != nil {
|
|
tx, err = conn.BeginTx(ctx, txx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
}
|
|
|
|
err = useIt(conn)
|
|
|
|
// Check if useIt was successful and a transaction exists before committing.
|
|
if err == nil && tx != nil {
|
|
err = tx.Commit()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// sqlQuery queries rows on a sql.Conn and returns an array of result objects.
|
|
func sqlQuery(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) ([]map[string]interface{}, error) {
|
|
rows, err := conn.QueryContext(ctx, stmt, args...)
|
|
if err == nil {
|
|
return sqlRows(rows)
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// sqlExec executes a command on a sql.Conn and returns the result of the operation.
|
|
func sqlExec(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) (sqlResult, error) {
|
|
exec, err := conn.ExecContext(ctx, stmt, args...)
|
|
if err == nil {
|
|
return sqlResultFrom(exec), nil
|
|
}
|
|
return sqlResult{}, err
|
|
}
|
|
|
|
// sql mutates Arguments to contain a positional []sql.NamedArg.
|
|
//
|
|
// The actual return type is []interface{} due to the native Golang
|
|
// function signatures for sql.Exec and sql.Query being generic.
|
|
func (args *Arguments) sql(dialect string) {
|
|
result := args.Positional
|
|
|
|
for i, val := range result {
|
|
result[i] = sqlArg("", val, dialect)
|
|
}
|
|
|
|
for key, val := range args.Named {
|
|
result = append(result, sqlArg(key, val, dialect))
|
|
}
|
|
|
|
args.Positional = result
|
|
args.Named = map[string]interface{}{}
|
|
}
|
|
|
|
// sqlArg creates a sql.NamedArg from a key-value pair and an optional dialect.
|
|
//
|
|
// Certain dialects will need to wrap objects, such as arrays, to conform its driver requirements.
|
|
func sqlArg(key, val interface{}, dialect string) sql.NamedArg {
|
|
switch reflect.ValueOf(val).Kind() {
|
|
|
|
// PostgreSQL and Clickhouse require arrays to be wrapped before
|
|
// being inserted into the driver interface.
|
|
case reflect.Slice, reflect.Array:
|
|
switch dialect {
|
|
case "postgres":
|
|
val = pq.Array(val)
|
|
case "clickhouse":
|
|
val = clickhouse.Array(val)
|
|
}
|
|
}
|
|
|
|
return sql.Named(fmt.Sprint(key), val)
|
|
}
|
|
|
|
// sqlIsolation resolves a string to a sql.IsolationLevel.
//
// "none" yields the out-of-range level -1 with no error, and the empty string
// yields sql.LevelDefault. Any other input is compared case-insensitively,
// with underscores treated as spaces, against the names of the levels from
// sql.LevelDefault through sql.LevelLinearizable. An unknown name yields -1
// and an error.
func sqlIsolation(str string) (sql.IsolationLevel, error) {
	switch str {
	case "none":
		return sql.IsolationLevel(-1), nil
	case "":
		return sql.LevelDefault, nil
	}

	wanted := strings.ReplaceAll(str, "_", " ")
	for level := sql.LevelDefault; level <= sql.LevelLinearizable; level++ {
		if strings.EqualFold(level.String(), wanted) {
			return level, nil
		}
	}

	return -1, fmt.Errorf("cannot provide an invalid sql isolation level: '%s'", str)
}
|
|
|
|
// sqlMode checks that a string is a supported command mode, currently "query"
// or "exec", and returns it unchanged. Any other value yields an empty string
// and an error.
func sqlMode(str string) (string, error) {
	if str == "query" || str == "exec" {
		return str, nil
	}

	return "", fmt.Errorf("cannot provide invalid sql mode: '%s'", str)
}
|
|
|
|
// sqlRows scans through a SQL result set and returns an array of objects.
|
|
func sqlRows(rows *sql.Rows) ([]map[string]interface{}, error) {
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "could not extract columns from result")
|
|
}
|
|
defer rows.Close()
|
|
|
|
types, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
// Some drivers do not support type extraction, so fail silently and continue.
|
|
types = make([]*sql.ColumnType, len(columns))
|
|
}
|
|
|
|
values := make([]interface{}, len(columns))
|
|
pointers := make([]interface{}, len(columns))
|
|
|
|
var results []map[string]interface{}
|
|
for rows.Next() {
|
|
for i := range columns {
|
|
pointers[i] = &values[i]
|
|
}
|
|
_ = rows.Scan(pointers...)
|
|
|
|
// Convert a row, an array of values, into an object where
|
|
// each key is the name of its respective column.
|
|
entry := make(map[string]interface{})
|
|
for i, col := range columns {
|
|
entry[col] = sqlValue(values[i], types[i])
|
|
}
|
|
results = append(results, entry)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// sqlValue converts a scanned column value into a more "human-readable" form.
//
// Values that are not []byte are returned unchanged. A []byte value that is
// valid JSON is returned as its decoded form (object, array, string, number,
// boolean or nil); any other []byte is returned as a string. The col argument
// is currently not consulted.
func sqlValue(val interface{}, col *sql.ColumnType) interface{} {
	raw, isBytes := val.([]byte)
	if !isBytes {
		return val
	}

	// Opportunistically check for embedded JSON and convert it to a first-class object.
	var decoded interface{}
	if err := json.Unmarshal(raw, &decoded); err != nil {
		// STOR-604: investigate a way to coerce PostgreSQL arrays '{a, b, ...}' into JSON.
		// Although easy with strings, it becomes more difficult with special types like INET[].
		return string(raw)
	}

	return decoded
}
|
|
|
|
// sqlResult is a thin, JSON-marshalable wrapper around sql.Result.
type sqlResult struct {
	// LastInsertId mirrors sql.Result.LastInsertId; sqlResultFrom stores -1
	// when the driver reports an error for it.
	LastInsertId int64 `json:"last_insert_id"`

	// RowsAffected mirrors sql.Result.RowsAffected; sqlResultFrom stores -1
	// when the driver reports an error for it.
	RowsAffected int64 `json:"rows_affected"`
}
|
|
|
|
// sqlResultFrom converts sql.Result into a JSON-marshable sqlResult.
|
|
func sqlResultFrom(res sql.Result) sqlResult {
|
|
insertID, errID := res.LastInsertId()
|
|
rowsAffected, errRows := res.RowsAffected()
|
|
|
|
// If an error occurs when extracting the result, it is because the
|
|
// driver does not support that specific field. Instead of passing this
|
|
// to the user, omit the field in the response.
|
|
if errID != nil {
|
|
insertID = -1
|
|
}
|
|
if errRows != nil {
|
|
rowsAffected = -1
|
|
}
|
|
|
|
return sqlResult{insertID, rowsAffected}
|
|
}
|