package clickhouse

import (
	"bufio"
	"crypto/tls"
	"database/sql/driver"
	"net"
	"sync/atomic"
	"time"
)

// tick is a package-wide counter that dial increments atomically on every
// call. The resulting value is used as the connection ident and as the
// starting offset into the host list for the connOpenRandom strategy.
var tick int32

// openStrategy selects the order in which dial walks the configured hosts.
type openStrategy int8

// The strategies start at 1, so the zero value of openStrategy is neither of
// them.
const (
	// connOpenRandom starts at an offset derived from the dial counter and
	// wraps around the host list.
	connOpenRandom openStrategy = iota + 1
	// connOpenInOrder tries the hosts in the order they were configured.
	connOpenInOrder
)

// String returns the textual name of the strategy: "in_order" for
// connOpenInOrder and "random" for every other value.
func (s openStrategy) String() string {
	if s == connOpenInOrder {
		return "in_order"
	}
	return "random"
}

// connOptions bundles the settings dial needs to open a connection.
type connOptions struct {
	// secure makes dial use TLS instead of a plain TCP connection.
	secure bool
	// skipVerify is written to tls.Config.InsecureSkipVerify when secure is set.
	skipVerify bool
	// tlsConfig is the TLS configuration used for secure dials; when nil an
	// empty configuration is used.
	tlsConfig *tls.Config
	// hosts lists the addresses dial may connect to.
	hosts []string
	// connTimeout bounds a single connection attempt.
	connTimeout time.Duration
	// readTimeout is handed to the resulting connect as its read timeout.
	readTimeout time.Duration
	// writeTimeout is handed to the resulting connect as its write timeout.
	writeTimeout time.Duration
	// noDelay is passed to SetNoDelay on TCP connections.
	noDelay bool
	// openStrategy decides in which order hosts are tried.
	openStrategy openStrategy
	// logf receives printf-style diagnostics about dial attempts.
	logf func(string, ...interface{})
}

// dial opens a TCP connection (TLS when options.secure is set) to one of
// options.hosts and wraps it in a *connect.
//
// Hosts are tried one after another until a connection succeeds. With
// connOpenInOrder the list is walked from its first entry; with
// connOpenRandom the walk starts at an offset derived from the package-wide
// dial counter and wraps around. Every attempt, successful or not, is
// reported through options.logf.
//
// When no host can be reached the error of the last attempt is returned. An
// empty options.hosts yields a "missing address" *net.OpError.
func dial(options connOptions) (*connect, error) {
	if len(options.hosts) == 0 {
		// Without this guard the loop below never runs and the function would
		// return (nil, nil), handing the caller a nil connection and no error.
		return nil, &net.OpError{
			Op:  "dial",
			Net: "tcp",
			Err: &net.AddrError{Err: "missing address"},
		}
	}

	// Clear the sign bit so ident stays non-negative, and therefore a valid
	// operand for the modulo below, even after the int32 counter wraps.
	ident := int(atomic.AddInt32(&tick, 1) & 0x7fffffff)

	var tlsConfig *tls.Config
	if options.secure {
		if options.tlsConfig != nil {
			// Work on a copy: a tls.Config must not be modified once it has
			// been handed to a TLS function, and the caller's config may
			// already be in use by other connections.
			tlsConfig = options.tlsConfig.Clone()
		} else {
			tlsConfig = &tls.Config{}
		}
		tlsConfig.InsecureSkipVerify = options.skipVerify
	}

	var err error
	for i := range options.hosts {
		// Any strategy other than the two known ones keeps num at 0.
		num := 0
		switch options.openStrategy {
		case connOpenInOrder:
			num = i
		case connOpenRandom:
			num = (ident + i) % len(options.hosts)
		}

		var conn net.Conn
		if options.secure {
			conn, err = tls.DialWithDialer(
				&net.Dialer{Timeout: options.connTimeout},
				"tcp",
				options.hosts[num],
				tlsConfig,
			)
		} else {
			conn, err = net.DialTimeout("tcp", options.hosts[num], options.connTimeout)
		}
		if err != nil {
			options.logf(
				"[dial err] secure=%t, skip_verify=%t, strategy=%s, ident=%d, addr=%s\n%#v",
				options.secure,
				options.skipVerify,
				options.openStrategy,
				ident,
				options.hosts[num],
				err,
			)
			continue
		}

		options.logf(
			"[dial] secure=%t, skip_verify=%t, strategy=%s, ident=%d, server=%d -> %s",
			options.secure,
			options.skipVerify,
			options.openStrategy,
			ident,
			num,
			conn.RemoteAddr(),
		)
		// Only plain TCP connections match this assertion; a TLS dial yields
		// a *tls.Conn, which is left with its default setting.
		if tcp, ok := conn.(*net.TCPConn); ok {
			// Disable or enable the Nagle algorithm for this TCP socket.
			if err := tcp.SetNoDelay(options.noDelay); err != nil {
				// Release the socket instead of leaking it; the SetNoDelay
				// failure is the error worth reporting, so a Close error is
				// deliberately dropped.
				_ = conn.Close()
				return nil, err
			}
		}
		return &connect{
			Conn:         conn,
			logf:         options.logf,
			ident:        ident,
			buffer:       bufio.NewReader(conn),
			readTimeout:  options.readTimeout,
			writeTimeout: options.writeTimeout,
		}, nil
	}
	return nil, err
}

// connect wraps a net.Conn with buffered reads, optional read/write
// deadlines and a guard against closing the connection twice.
type connect struct {
	net.Conn

	// logf receives printf-style diagnostics about read and write errors.
	logf func(string, ...interface{})
	// ident is the identifier dial assigned to this connection.
	ident int
	// buffer is the buffered reader that Read pulls bytes from.
	buffer *bufio.Reader
	// closed is set by Close so the underlying connection is closed only once.
	closed bool
	// readTimeout and writeTimeout are the deadlines applied to reads and
	// writes; zero disables the respective deadline.
	readTimeout, writeTimeout time.Duration
	// lastReadDeadlineTime and lastWriteDeadlineTime record when the
	// respective deadline was last pushed forward.
	lastReadDeadlineTime, lastWriteDeadlineTime time.Time
}

// Read fills b completely from the buffered connection, looping over short
// reads until len(b) bytes have been received.
//
// When a read timeout is configured, the read deadline is pushed forward
// before reading, but only if more than a quarter of the timeout has elapsed
// since it was last set, so the deadline is not reset on every call.
//
// On any read error the error is logged, the connection is closed and
// driver.ErrBadConn is returned together with the number of bytes that were
// read into b before the failure.
func (conn *connect) Read(b []byte) (int, error) {
	if currentTime := now(); conn.readTimeout != 0 && currentTime.Sub(conn.lastReadDeadlineTime) > (conn.readTimeout>>2) {
		conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
		conn.lastReadDeadlineTime = currentTime
	}
	total := 0
	for total < len(b) {
		n, err := conn.buffer.Read(b[total:])
		// Account for the chunk before checking the error so the returned
		// count covers everything placed in b, not just the last chunk.
		total += n
		if err != nil {
			conn.logf("[connect] read error: %v", err)
			conn.Close()
			return total, driver.ErrBadConn
		}
	}
	return total, nil
}

// Write sends all of b over the underlying connection, looping over partial
// writes until len(b) bytes have been written.
//
// When a write timeout is configured, the write deadline is pushed forward
// before writing, but only if more than a quarter of the timeout has elapsed
// since it was last set, so the deadline is not reset on every call.
//
// On success it returns len(b). On any write error the error is logged, the
// connection is closed and driver.ErrBadConn is returned together with the
// number of bytes written before the failure.
func (conn *connect) Write(b []byte) (int, error) {
	if currentTime := now(); conn.writeTimeout != 0 && currentTime.Sub(conn.lastWriteDeadlineTime) > (conn.writeTimeout>>2) {
		conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
		conn.lastWriteDeadlineTime = currentTime
	}
	total := 0
	for total < len(b) {
		n, err := conn.Conn.Write(b[total:])
		// Accumulate first so both return paths report the overall count
		// rather than the size of the last chunk only.
		total += n
		if err != nil {
			conn.logf("[connect] write error: %v", err)
			conn.Close()
			return total, driver.ErrBadConn
		}
	}
	return total, nil
}

// Close closes the underlying connection the first time it is called and
// returns that close error. Later calls do nothing and return nil.
func (conn *connect) Close() error {
	if conn.closed {
		return nil
	}
	conn.closed = true
	return conn.Conn.Close()
}