// Source path: cloudflared-mirror/vendor/github.com/kshvakov/clickhouse/stmt.go

package clickhouse
import (
"bytes"
"context"
"database/sql/driver"
"unicode"
"github.com/kshvakov/clickhouse/lib/data"
)
// stmt is a prepared statement executed over a single clickhouse connection.
type stmt struct {
	// ch is the connection the statement runs on; the insert path also
	// appends rows to ch.block.
	ch *clickhouse

	// query is the SQL text as prepared; bind substitutes arguments into it.
	query string

	// counter counts rows appended by execContext on the insert path; the
	// block is flushed whenever counter is a multiple of ch.blockSize.
	// numInput is the argument count NumInput reports while ch.block is nil
	// (presumably computed at prepare time elsewhere; confirm).
	counter, numInput int

	// isInsert selects the insert path in execContext: rows are appended to
	// ch.block instead of being sent as a query.
	isInsert bool
}

// emptyResult is the shared driver.Result returned by every successful exec.
var emptyResult = &result{}
// NumInput implements driver.Stmt.
//
// When the connection holds a data block (the block the insert path appends
// rows to), the expected argument count is that block's column count.
// Otherwise it is the stored numInput, with a negative value reported as 0.
func (stmt *stmt) NumInput() int {
	if block := stmt.ch.block; block != nil {
		return len(block.Columns)
	}
	if stmt.numInput < 0 {
		return 0
	}
	return stmt.numInput
}
// Exec implements driver.Stmt. It runs the statement with positional
// arguments and a background context.
func (stmt *stmt) Exec(args []driver.Value) (driver.Result, error) {
	return stmt.execContext(context.Background(), args)
}

// ExecContext implements driver.StmtExecContext.
//
// For a non-insert statement the arguments reach bind with their names
// intact, so "@name" placeholders are resolved exactly as in QueryContext.
// For an insert statement only the values are used, in order, as one row.
func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	if stmt.isInsert {
		values := make([]driver.Value, len(args))
		for i, nv := range args {
			values[i] = nv.Value
		}
		return stmt.appendInsertRow(values)
	}
	return stmt.execBoundQuery(args)
}

// execContext executes the statement with positional (unnamed) arguments.
//
// NOTE(review): ctx is not watched for cancellation on either path (unlike
// queryContext, which calls watchCancel); confirm whether callers already
// watch it before adding a watcher here.
func (stmt *stmt) execContext(ctx context.Context, args []driver.Value) (driver.Result, error) {
	if stmt.isInsert {
		return stmt.appendInsertRow(args)
	}
	return stmt.execBoundQuery(convertOldArgs(args))
}

// appendInsertRow appends one row to the connection's current block and,
// every ch.blockSize rows, writes the block and flushes the encoder.
func (stmt *stmt) appendInsertRow(row []driver.Value) (driver.Result, error) {
	stmt.counter++
	if err := stmt.ch.block.AppendRow(row); err != nil {
		return nil, err
	}
	if stmt.counter%stmt.ch.blockSize == 0 {
		stmt.ch.logf("[exec] flush block")
		if err := stmt.ch.writeBlock(stmt.ch.block); err != nil {
			return nil, err
		}
		if err := stmt.ch.encoder.Flush(); err != nil {
			return nil, err
		}
	}
	return emptyResult, nil
}

// execBoundQuery binds args into the statement's query, sends it, and
// processes the server's response.
func (stmt *stmt) execBoundQuery(args []driver.NamedValue) (driver.Result, error) {
	if err := stmt.ch.sendQuery(stmt.bind(args)); err != nil {
		return nil, err
	}
	if err := stmt.ch.process(); err != nil {
		return nil, err
	}
	return emptyResult, nil
}
// Query implements driver.Stmt. The positional arguments are converted to
// unnamed driver.NamedValue entries and run with a background context.
func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
	named := convertOldArgs(args)
	return stmt.queryContext(context.Background(), named)
}
// QueryContext implements driver.StmtQueryContext. Arguments are forwarded
// with their names, so bind can resolve "@name" placeholders.
func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	r, err := stmt.queryContext(ctx, args)
	return r, err
}
// queryContext sends the bound query, reads the result metadata, and
// returns rows whose receiveData method is started in a new goroutine.
//
// The finish func obtained from watchCancel is called on every error path.
// On success, ownership of finish passes to the returned rows.
func (stmt *stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	finish := stmt.ch.watchCancel(ctx)
	abort := func(err error) (driver.Rows, error) {
		finish()
		return nil, err
	}

	query := stmt.bind(args)
	if err := stmt.ch.sendQuery(query); err != nil {
		return abort(err)
	}
	meta, err := stmt.ch.readMeta()
	if err != nil {
		return abort(err)
	}

	r := &rows{
		ch:           stmt.ch,
		finish:       finish,
		stream:       make(chan *data.Block, 50),
		columns:      meta.ColumnNames(),
		blockColumns: meta.Columns,
	}
	go r.receiveData()
	return r, nil
}
// Close implements driver.Stmt. It releases nothing itself; it only logs
// the call and always reports success.
func (stmt *stmt) Close() error {
	stmt.ch.logf("[stmt] close")
	return nil
}
// bind returns stmt.query with args substituted for its placeholders.
//
// When NumInput reports zero inputs, the query is returned verbatim.
// Otherwise two placeholder forms are recognised:
//
//   - "@name" is replaced by the quoted value of the first argument whose
//     Name equals the name read by paramParser. If "@" is not followed by a
//     name, or no argument carries that name, the text is written back
//     unchanged instead of being silently dropped. This mirrors how an
//     unbound "?" is kept.
//   - "?" is replaced by the next positional argument, but only when that
//     argument is unnamed and a value may appear at this point. A value may
//     appear after one of = < > ( , + - * / [, or after the like/limit/between
//     matchers (or the and matcher inside a BETWEEN) report a match, with
//     only whitespace in between. Otherwise the "?" is written unchanged.
//
// NOTE(review): argument values are interpolated into the SQL text via
// quote; injection safety depends on quote escaping them. Confirm.
func (stmt *stmt) bind(args []driver.NamedValue) string {
	if stmt.NumInput() == 0 {
		return stmt.query
	}
	var (
		buf       bytes.Buffer
		index     int
		keyword   bool
		inBetween bool
		like      = newMatcher("like")
		limit     = newMatcher("limit")
		between   = newMatcher("between")
		and       = newMatcher("and")
	)
	reader := bytes.NewReader([]byte(stmt.query))
	for {
		char, _, err := reader.ReadRune()
		if err != nil {
			// End of the query text.
			break
		}
		switch char {
		case '@':
			param := paramParser(reader)
			bound := false
			if len(param) != 0 {
				for _, v := range args {
					if v.Name == param {
						buf.WriteString(quote(v.Value))
						bound = true
						break
					}
				}
			}
			if !bound {
				// Not a usable named placeholder: restore the text. This
				// assumes paramParser consumed exactly the returned name.
				buf.WriteRune(char)
				buf.WriteString(param)
			}
		case '?':
			if keyword && index < len(args) && len(args[index].Name) == 0 {
				buf.WriteString(quote(args[index].Value))
				index++
			} else {
				buf.WriteRune(char)
			}
		default:
			switch char {
			case '=', '<', '>', '(', ',', '+', '-', '*', '/', '[':
				keyword = true
			default:
				// The evaluation order (and short-circuiting) of these matcher
				// calls is kept as-is, since each call advances matcher state.
				if limit.matchRune(char) || like.matchRune(char) {
					keyword = true
				} else if between.matchRune(char) {
					keyword = true
					inBetween = true
				} else if inBetween && and.matchRune(char) {
					keyword = true
					inBetween = false
				} else {
					// Whitespace preserves a pending keyword; anything else clears it.
					keyword = keyword && unicode.IsSpace(char)
				}
			}
			buf.WriteRune(char)
		}
	}
	return buf.String()
}
// convertOldArgs adapts positional driver.Value arguments to the
// driver.NamedValue form consumed by bind and queryContext. Every result has
// an empty Name and a 1-based Ordinal equal to its position. The returned
// slice is never nil.
func convertOldArgs(args []driver.Value) []driver.NamedValue {
	named := make([]driver.NamedValue, 0, len(args))
	for pos := range args {
		named = append(named, driver.NamedValue{Ordinal: pos + 1, Value: args[pos]})
	}
	return named
}