package clickhouse

import (
	"bufio"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"
	"log"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/kshvakov/clickhouse/lib/leakypool"

	"github.com/kshvakov/clickhouse/lib/binary"
	"github.com/kshvakov/clickhouse/lib/data"
	"github.com/kshvakov/clickhouse/lib/protocol"
)

const (
	// DefaultDatabase when connecting to ClickHouse
	DefaultDatabase = "default"
	// DefaultUsername when connecting to ClickHouse
	DefaultUsername = "default"
	// DefaultConnTimeout when connecting to ClickHouse
	DefaultConnTimeout = 5 * time.Second
	// DefaultReadTimeout when reading query results
	DefaultReadTimeout = time.Minute
	// DefaultWriteTimeout when sending queries
	DefaultWriteTimeout = time.Minute
)

var (
	unixtime    int64
	logOutput   io.Writer = os.Stdout
	hostname, _           = os.Hostname()
	poolInit    sync.Once
)

func init() {
	sql.Register("clickhouse", &bootstrap{})
	go func() {
		for tick := time.Tick(time.Second); ; {
			select {
			case <-tick:
				atomic.AddInt64(&unixtime, int64(time.Second))
			}
		}
	}()
}

func now() time.Time {
	return time.Unix(atomic.LoadInt64(&unixtime), 0)
}

type bootstrap struct{}

func (d *bootstrap) Open(dsn string) (driver.Conn, error) {
	return Open(dsn)
}

// SetLogOutput allows to change output of the default logger
func SetLogOutput(output io.Writer) {
	logOutput = output
}

// Open the connection
func Open(dsn string) (driver.Conn, error) {
	return open(dsn)
}

func open(dsn string) (*clickhouse, error) {
	url, err := url.Parse(dsn)
	if err != nil {
		return nil, err
	}
	var (
		hosts            = []string{url.Host}
		query            = url.Query()
		secure           = false
		skipVerify       = false
		tlsConfigName    = query.Get("tls_config")
		noDelay          = true
		compress         = false
		database         = query.Get("database")
		username         = query.Get("username")
		password         = query.Get("password")
		blockSize        = 1000000
		connTimeout      = DefaultConnTimeout
		readTimeout      = DefaultReadTimeout
		writeTimeout     = DefaultWriteTimeout
		connOpenStrategy = connOpenRandom
		poolSize         = 100
	)
	if len(database) == 0 {
		database = DefaultDatabase
	}
	if len(username) == 0 {
		username = DefaultUsername
	}
	if v, err := strconv.ParseBool(query.Get("no_delay")); err == nil {
		noDelay = v
	}
	tlsConfig := getTLSConfigClone(tlsConfigName)
	if tlsConfigName != "" && tlsConfig == nil {
		return nil, fmt.Errorf("invalid tls_config - no config registered under name %s", tlsConfigName)
	}
	secure = tlsConfig != nil
	if v, err := strconv.ParseBool(query.Get("secure")); err == nil {
		secure = v
	}
	if v, err := strconv.ParseBool(query.Get("skip_verify")); err == nil {
		skipVerify = v
	}
	if duration, err := strconv.ParseFloat(query.Get("timeout"), 64); err == nil {
		connTimeout = time.Duration(duration * float64(time.Second))
	}
	if duration, err := strconv.ParseFloat(query.Get("read_timeout"), 64); err == nil {
		readTimeout = time.Duration(duration * float64(time.Second))
	}
	if duration, err := strconv.ParseFloat(query.Get("write_timeout"), 64); err == nil {
		writeTimeout = time.Duration(duration * float64(time.Second))
	}
	if size, err := strconv.ParseInt(query.Get("block_size"), 10, 64); err == nil {
		blockSize = int(size)
	}
	if size, err := strconv.ParseInt(query.Get("pool_size"), 10, 64); err == nil {
		poolSize = int(size)
	}
	poolInit.Do(func() {
		leakypool.InitBytePool(poolSize)
	})
	if altHosts := strings.Split(query.Get("alt_hosts"), ","); len(altHosts) != 0 {
		for _, host := range altHosts {
			if len(host) != 0 {
				hosts = append(hosts, host)
			}
		}
	}
	switch query.Get("connection_open_strategy") {
	case "random":
		connOpenStrategy = connOpenRandom
	case "in_order":
		connOpenStrategy = connOpenInOrder
	}

	settings, err := makeQuerySettings(query)
	if err != nil {
		return nil, err
	}

	if v, err := strconv.ParseBool(query.Get("compress")); err == nil {
		compress = v
	}

	var (
		ch = clickhouse{
			logf:      func(string, ...interface{}) {},
			settings:  settings,
			compress:  compress,
			blockSize: blockSize,
			ServerInfo: data.ServerInfo{
				Timezone: time.Local,
			},
		}
		logger = log.New(logOutput, "[clickhouse]", 0)
	)
	if debug, err := strconv.ParseBool(url.Query().Get("debug")); err == nil && debug {
		ch.logf = logger.Printf
	}
	ch.logf("host(s)=%s, database=%s, username=%s",
		strings.Join(hosts, ", "),
		database,
		username,
	)
	options := connOptions{
		secure:       secure,
		tlsConfig:    tlsConfig,
		skipVerify:   skipVerify,
		hosts:        hosts,
		connTimeout:  connTimeout,
		readTimeout:  readTimeout,
		writeTimeout: writeTimeout,
		noDelay:      noDelay,
		openStrategy: connOpenStrategy,
		logf:         ch.logf,
	}
	if ch.conn, err = dial(options); err != nil {
		return nil, err
	}
	logger.SetPrefix(fmt.Sprintf("[clickhouse][connect=%d]", ch.conn.ident))
	ch.buffer = bufio.NewWriter(ch.conn)

	ch.decoder = binary.NewDecoder(ch.conn)
	ch.encoder = binary.NewEncoder(ch.buffer)

	if err := ch.hello(database, username, password); err != nil {
		return nil, err
	}
	return &ch, nil
}

func (ch *clickhouse) hello(database, username, password string) error {
	ch.logf("[hello] -> %s", ch.ClientInfo)
	{
		ch.encoder.Uvarint(protocol.ClientHello)
		if err := ch.ClientInfo.Write(ch.encoder); err != nil {
			return err
		}
		{
			ch.encoder.String(database)
			ch.encoder.String(username)
			ch.encoder.String(password)
		}
		if err := ch.encoder.Flush(); err != nil {
			return err
		}

	}
	{
		packet, err := ch.decoder.Uvarint()
		if err != nil {
			return err
		}
		switch packet {
		case protocol.ServerException:
			return ch.exception()
		case protocol.ServerHello:
			if err := ch.ServerInfo.Read(ch.decoder); err != nil {
				return err
			}
		case protocol.ServerEndOfStream:
			ch.logf("[bootstrap] <- end of stream")
			return nil
		default:
			ch.conn.Close()
			return fmt.Errorf("[hello] unexpected packet [%d] from server", packet)
		}
	}
	ch.logf("[hello] <- %s", ch.ServerInfo)
	return nil
}