package clickhouse

import (
	"bytes"
	"context"
	"database/sql/driver"
	"unicode"

	"github.com/kshvakov/clickhouse/lib/data"
)

// stmt implements driver.Stmt for a query prepared on a single clickhouse
// connection. Depending on isInsert, Exec either runs the query or buffers
// rows into the connection's current block.
type stmt struct {
	ch       *clickhouse // connection every call on this statement uses
	query    string      // query text; placeholders in it are substituted by bind
	counter  int         // insert-mode row counter; a flush happens at each multiple of ch.blockSize
	numInput int         // placeholder count; NumInput reports negative values as 0
	isInsert bool        // when true, Exec appends rows to ch.block instead of running query
}

// emptyResult is the single shared driver.Result returned by every successful
// Exec on a stmt.
var emptyResult = &result{}

// NumInput implements driver.Stmt.
//
// When the connection currently holds a data block (stmt.ch.block != nil) the
// expected argument count is the number of columns in that block. Otherwise
// it is the placeholder count recorded on the statement, clamped so that a
// negative value is reported as 0.
//
// NOTE(review): database/sql treats -1 as "unknown" and skips its argument
// count check; reporting 0 instead makes it expect zero arguments. bind also
// relies on NumInput() == 0 to skip substitution — confirm this is intended.
func (stmt *stmt) NumInput() int {
	if block := stmt.ch.block; block != nil {
		return len(block.Columns)
	}
	if stmt.numInput < 0 {
		return 0
	}
	return stmt.numInput
}

// Exec implements driver.Stmt using a background context. The values are
// treated as positional arguments.
func (stmt *stmt) Exec(args []driver.Value) (driver.Result, error) {
	return stmt.execContext(context.Background(), args)
}

// ExecContext implements driver.StmtExecContext.
//
// For an insert statement only the argument values matter: they are appended
// as one row to the connection's block. For any other statement the arguments
// are bound with their names intact, so both "?" and "@name" placeholders can
// be filled, and ctx is watched for cancellation while the query runs.
func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	if stmt.isInsert {
		values := make([]driver.Value, len(args))
		for i, nv := range args {
			values[i] = nv.Value
		}
		return stmt.appendInsertRow(values)
	}
	return stmt.execQuery(ctx, args)
}

// execContext executes the statement with positional values: insert
// statements buffer them as a row, anything else is bound and run.
func (stmt *stmt) execContext(ctx context.Context, args []driver.Value) (driver.Result, error) {
	if stmt.isInsert {
		return stmt.appendInsertRow(args)
	}
	return stmt.execQuery(ctx, convertOldArgs(args))
}

// appendInsertRow appends args as one row to stmt.ch.block. Each time the
// number of appended rows reaches a multiple of ch.blockSize, the block is
// written and the encoder flushed.
//
// No context is consulted here; only execQuery installs a cancellation
// watcher.
func (stmt *stmt) appendInsertRow(args []driver.Value) (driver.Result, error) {
	if err := stmt.ch.block.AppendRow(args); err != nil {
		return nil, err
	}
	// Count the row only once it has actually been appended, so a failed
	// append does not shift the flush boundary.
	stmt.counter++
	if stmt.counter%stmt.ch.blockSize == 0 {
		stmt.ch.logf("[exec] flush block")
		if err := stmt.ch.writeBlock(stmt.ch.block); err != nil {
			return nil, err
		}
		if err := stmt.ch.encoder.Flush(); err != nil {
			return nil, err
		}
	}
	return emptyResult, nil
}

// execQuery binds args into the query, sends it, then runs ch.process and
// reports its error. As in queryContext, a watcher for ctx is installed via
// watchCancel; it is released when this call returns.
func (stmt *stmt) execQuery(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	finish := stmt.ch.watchCancel(ctx)
	defer finish()
	if err := stmt.ch.sendQuery(stmt.bind(args)); err != nil {
		return nil, err
	}
	if err := stmt.ch.process(); err != nil {
		return nil, err
	}
	return emptyResult, nil
}

// Query implements driver.Stmt. The plain values become positional
// arguments, and the query is issued with a background context.
func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
	ctx := context.Background()
	return stmt.queryContext(ctx, convertOldArgs(args))
}

// QueryContext implements driver.StmtQueryContext. Both ctx and the named
// arguments are forwarded unchanged to queryContext.
func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	r, err := stmt.queryContext(ctx, args)
	return r, err
}

// queryContext binds args into the query, sends it, and reads the result
// metadata. It then returns rows backed by a buffered block channel, with
// rows.receiveData started in a new goroutine.
//
// The finish function from watchCancel(ctx) is called here if sending or
// reading metadata fails. On success it is handed to the returned rows instead.
func (stmt *stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	finish := stmt.ch.watchCancel(ctx)
	fail := func(err error) (driver.Rows, error) {
		finish()
		return nil, err
	}

	if err := stmt.ch.sendQuery(stmt.bind(args)); err != nil {
		return fail(err)
	}
	meta, err := stmt.ch.readMeta()
	if err != nil {
		return fail(err)
	}

	result := &rows{
		ch:           stmt.ch,
		finish:       finish,
		stream:       make(chan *data.Block, 50),
		columns:      meta.ColumnNames(),
		blockColumns: meta.Columns,
	}
	go result.receiveData()
	return result, nil
}

// Close implements driver.Stmt. It only logs the call and always returns nil;
// nothing is released here.
func (stmt *stmt) Close() error {
	stmt.ch.logf("[stmt] close")
	return nil
}

// bind returns the statement's query with args substituted into its
// placeholders. If NumInput reports no inputs, the query is returned as-is.
//
// Two placeholder forms are recognised:
//
//   - "@name" is replaced by the quoted value of the first argument whose
//     Name equals name. If no argument matches, or "@" is not followed by a
//     name, the original text is written back unchanged. This keeps literals
//     such as 'user@example.com' intact and leaves an unbound parameter
//     visible to the server instead of silently deleting it.
//   - "?" is replaced by the next unnamed argument, but only in "keyword"
//     position. That means after one of = < > ( , + - * / [ or after the
//     limit / like / between matchers (or the "and" closing a between)
//     report a match, with only whitespace in between. Any other "?" is
//     copied verbatim.
//
// NOTE(review): quoted string literals are not tracked, so placeholders inside
// them are treated like any others.
func (stmt *stmt) bind(args []driver.NamedValue) string {
	if stmt.NumInput() == 0 {
		return stmt.query
	}

	// index is the next positional argument for "?"; keyword records whether
	// a "?" at this point may be substituted; inBetween is set after a
	// "between" until its closing "and" is seen.
	var (
		buf       bytes.Buffer
		index     int
		keyword   bool
		inBetween bool
		like      = newMatcher("like")
		limit     = newMatcher("limit")
		between   = newMatcher("between")
		and       = newMatcher("and")
	)
	reader := bytes.NewReader([]byte(stmt.query))
	for {
		char, _, err := reader.ReadRune()
		if err != nil {
			break // end of query text
		}
		switch char {
		case '@':
			param := paramParser(reader)
			bound := false
			if len(param) != 0 {
				for _, arg := range args {
					if arg.Name == param {
						buf.WriteString(quote(arg.Value))
						bound = true
						break
					}
				}
			}
			if !bound {
				// Nothing to substitute: preserve the source text.
				buf.WriteRune('@')
				buf.WriteString(param)
			}
		case '?':
			if keyword && index < len(args) && len(args[index].Name) == 0 {
				buf.WriteString(quote(args[index].Value))
				index++
			} else {
				buf.WriteRune(char)
			}
		case '=', '<', '>', '(', ',', '+', '-', '*', '/', '[':
			keyword = true
			buf.WriteRune(char)
		default:
			// Case order matters: the matchers are stateful, and a later
			// matcher is not fed this rune once an earlier one matches.
			switch {
			case limit.matchRune(char) || like.matchRune(char):
				keyword = true
			case between.matchRune(char):
				keyword = true
				inBetween = true
			case inBetween && and.matchRune(char):
				keyword = true
				inBetween = false
			default:
				// Whitespace keeps a pending keyword position alive; anything
				// else ends it.
				keyword = keyword && unicode.IsSpace(char)
			}
			buf.WriteRune(char)
		}
	}
	return buf.String()
}

// convertOldArgs adapts the legacy []driver.Value argument form to
// []driver.NamedValue. Each value keeps its position as a 1-based Ordinal and
// gets an empty Name, so bind treats it as a positional "?" argument. The
// result is always a non-nil slice of the same length as args.
func convertOldArgs(args []driver.Value) []driver.NamedValue {
	named := make([]driver.NamedValue, 0, len(args))
	for idx := range args {
		named = append(named, driver.NamedValue{Ordinal: idx + 1, Value: args[idx]})
	}
	return named
}