181 lines
3.9 KiB
Go
181 lines
3.9 KiB
Go
|
package clickhouse
|
||
|
|
||
|
import (
	"bufio"
	"crypto/tls"
	"database/sql/driver"
	"errors"
	"net"
	"sync/atomic"
	"time"
)
|
||
|
|
||
|
// tick is a package-level counter. dial bumps it atomically on every call and
// derives the connection's ident from the new value.
var tick int32
|
||
|
|
||
|
// openStrategy selects the order in which dial walks the configured hosts.
type openStrategy int8

// Known strategies. Numbering starts at 1, so the zero value of openStrategy
// matches neither constant.
const (
	connOpenRandom  openStrategy = 1
	connOpenInOrder openStrategy = 2
)

// String returns "in_order" for connOpenInOrder and "random" for every other
// value, including values that match no declared constant.
func (s openStrategy) String() string {
	if s == connOpenInOrder {
		return "in_order"
	}
	return "random"
}
|
||
|
|
||
|
type connOptions struct {
|
||
|
secure, skipVerify bool
|
||
|
tlsConfig *tls.Config
|
||
|
hosts []string
|
||
|
connTimeout, readTimeout, writeTimeout time.Duration
|
||
|
noDelay bool
|
||
|
openStrategy openStrategy
|
||
|
logf func(string, ...interface{})
|
||
|
}
|
||
|
|
||
|
func dial(options connOptions) (*connect, error) {
|
||
|
var (
|
||
|
err error
|
||
|
abs = func(v int) int {
|
||
|
if v < 0 {
|
||
|
return -1 * v
|
||
|
}
|
||
|
return v
|
||
|
}
|
||
|
conn net.Conn
|
||
|
ident = abs(int(atomic.AddInt32(&tick, 1)))
|
||
|
)
|
||
|
tlsConfig := options.tlsConfig
|
||
|
if options.secure {
|
||
|
if tlsConfig == nil {
|
||
|
tlsConfig = &tls.Config{}
|
||
|
}
|
||
|
tlsConfig.InsecureSkipVerify = options.skipVerify
|
||
|
}
|
||
|
for i := range options.hosts {
|
||
|
var num int
|
||
|
switch options.openStrategy {
|
||
|
case connOpenInOrder:
|
||
|
num = i
|
||
|
case connOpenRandom:
|
||
|
num = (ident + i) % len(options.hosts)
|
||
|
}
|
||
|
switch {
|
||
|
case options.secure:
|
||
|
conn, err = tls.DialWithDialer(
|
||
|
&net.Dialer{
|
||
|
Timeout: options.connTimeout,
|
||
|
},
|
||
|
"tcp",
|
||
|
options.hosts[num],
|
||
|
tlsConfig,
|
||
|
)
|
||
|
default:
|
||
|
conn, err = net.DialTimeout("tcp", options.hosts[num], options.connTimeout)
|
||
|
}
|
||
|
if err == nil {
|
||
|
options.logf(
|
||
|
"[dial] secure=%t, skip_verify=%t, strategy=%s, ident=%d, server=%d -> %s",
|
||
|
options.secure,
|
||
|
options.skipVerify,
|
||
|
options.openStrategy,
|
||
|
ident,
|
||
|
num,
|
||
|
conn.RemoteAddr(),
|
||
|
)
|
||
|
if tcp, ok := conn.(*net.TCPConn); ok {
|
||
|
err = tcp.SetNoDelay(options.noDelay) // Disable or enable the Nagle Algorithm for this tcp socket
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
return &connect{
|
||
|
Conn: conn,
|
||
|
logf: options.logf,
|
||
|
ident: ident,
|
||
|
buffer: bufio.NewReader(conn),
|
||
|
readTimeout: options.readTimeout,
|
||
|
writeTimeout: options.writeTimeout,
|
||
|
}, nil
|
||
|
} else {
|
||
|
options.logf(
|
||
|
"[dial err] secure=%t, skip_verify=%t, strategy=%s, ident=%d, addr=%s\n%#v",
|
||
|
options.secure,
|
||
|
options.skipVerify,
|
||
|
options.openStrategy,
|
||
|
ident,
|
||
|
options.hosts[num],
|
||
|
err,
|
||
|
)
|
||
|
}
|
||
|
}
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// connect is a net.Conn wrapper that adds buffered reads, per-operation
// deadlines and a closed flag on top of the embedded connection.
type connect struct {
	// Conn is the underlying network connection.
	net.Conn
	// logf receives read and write error messages.
	logf func(string, ...interface{})
	// ident is the identifier assigned to this connection when it was dialed.
	ident int
	// buffer is the buffered reader that Read pulls from.
	buffer *bufio.Reader
	// closed is set by Close so the underlying connection is closed only once.
	closed bool
	// readTimeout and writeTimeout are the deadlines applied by Read and
	// Write; a zero value disables the corresponding deadline handling.
	readTimeout, writeTimeout time.Duration
	// lastReadDeadlineTime and lastWriteDeadlineTime record when the read and
	// write deadlines were last refreshed.
	lastReadDeadlineTime, lastWriteDeadlineTime time.Time
}
|
||
|
|
||
|
func (conn *connect) Read(b []byte) (int, error) {
|
||
|
var (
|
||
|
n int
|
||
|
err error
|
||
|
total int
|
||
|
dstLen = len(b)
|
||
|
)
|
||
|
if currentTime := now(); conn.readTimeout != 0 && currentTime.Sub(conn.lastReadDeadlineTime) > (conn.readTimeout>>2) {
|
||
|
conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
|
||
|
conn.lastReadDeadlineTime = currentTime
|
||
|
}
|
||
|
for total < dstLen {
|
||
|
if n, err = conn.buffer.Read(b[total:]); err != nil {
|
||
|
conn.logf("[connect] read error: %v", err)
|
||
|
conn.Close()
|
||
|
return n, driver.ErrBadConn
|
||
|
}
|
||
|
total += n
|
||
|
}
|
||
|
return total, nil
|
||
|
}
|
||
|
|
||
|
func (conn *connect) Write(b []byte) (int, error) {
|
||
|
var (
|
||
|
n int
|
||
|
err error
|
||
|
total int
|
||
|
srcLen = len(b)
|
||
|
)
|
||
|
if currentTime := now(); conn.writeTimeout != 0 && currentTime.Sub(conn.lastWriteDeadlineTime) > (conn.writeTimeout>>2) {
|
||
|
conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
|
||
|
conn.lastWriteDeadlineTime = currentTime
|
||
|
}
|
||
|
for total < srcLen {
|
||
|
if n, err = conn.Conn.Write(b[total:]); err != nil {
|
||
|
conn.logf("[connect] write error: %v", err)
|
||
|
conn.Close()
|
||
|
return n, driver.ErrBadConn
|
||
|
}
|
||
|
total += n
|
||
|
}
|
||
|
return n, nil
|
||
|
}
|
||
|
|
||
|
func (conn *connect) Close() error {
|
||
|
if !conn.closed {
|
||
|
conn.closed = true
|
||
|
return conn.Conn.Close()
|
||
|
}
|
||
|
return nil
|
||
|
}
|