Release 2017.12.1
This commit is contained in:
parent
d40eb85da6
commit
e0ae598112
|
@ -7,6 +7,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -203,6 +204,35 @@ WARNING:
|
||||||
Value: 4,
|
Value: 4,
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: "proxy-connect-timeout",
|
||||||
|
Usage: "HTTP proxy timeout for establishing a new connection",
|
||||||
|
Value: time.Second * 30,
|
||||||
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: "proxy-tls-timeout",
|
||||||
|
Usage: "HTTP proxy timeout for completing a TLS handshake",
|
||||||
|
Value: time.Second * 10,
|
||||||
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: "proxy-tcp-keepalive",
|
||||||
|
Usage: "HTTP proxy TCP keepalive duration",
|
||||||
|
Value: time.Second * 30,
|
||||||
|
}),
|
||||||
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
|
Name: "proxy-no-happy-eyeballs",
|
||||||
|
Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
|
||||||
|
}),
|
||||||
|
altsrc.NewIntFlag(&cli.IntFlag{
|
||||||
|
Name: "proxy-keepalive-connections",
|
||||||
|
Usage: "HTTP proxy maximum keepalive connection pool size",
|
||||||
|
Value: 100,
|
||||||
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: "proxy-keepalive-timeout",
|
||||||
|
Usage: "HTTP proxy timeout for closing an idle connection",
|
||||||
|
Value: time.Second * 90,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
app.Action = func(c *cli.Context) error {
|
app.Action = func(c *cli.Context) error {
|
||||||
raven.CapturePanic(func() { startServer(c) }, nil)
|
raven.CapturePanic(func() { startServer(c) }, nil)
|
||||||
|
@ -348,6 +378,18 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
|
log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
|
||||||
}
|
}
|
||||||
tunnelMetrics := origin.NewTunnelMetrics()
|
tunnelMetrics := origin.NewTunnelMetrics()
|
||||||
|
httpTransport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: c.Duration("proxy-connect-timeout"),
|
||||||
|
KeepAlive: c.Duration("proxy-tcp-keepalive"),
|
||||||
|
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
|
||||||
|
}).DialContext,
|
||||||
|
MaxIdleConns: c.Int("proxy-keepalive-connections"),
|
||||||
|
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
|
||||||
|
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
tunnelConfig := &origin.TunnelConfig{
|
tunnelConfig := &origin.TunnelConfig{
|
||||||
EdgeAddrs: c.StringSlice("edge"),
|
EdgeAddrs: c.StringSlice("edge"),
|
||||||
OriginUrl: url,
|
OriginUrl: url,
|
||||||
|
@ -362,6 +404,7 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
LBPool: c.String("lb-pool"),
|
LBPool: c.String("lb-pool"),
|
||||||
Tags: tags,
|
Tags: tags,
|
||||||
HAConnections: c.Int("ha-connections"),
|
HAConnections: c.Int("ha-connections"),
|
||||||
|
HTTPTransport: httpTransport,
|
||||||
Metrics: tunnelMetrics,
|
Metrics: tunnelMetrics,
|
||||||
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
||||||
ProtocolLogger: protoLogger,
|
ProtocolLogger: protoLogger,
|
||||||
|
|
|
@ -5,27 +5,51 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/net/trace"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ServeMetrics(l net.Listener, shutdownC <-chan struct{}) error {
|
const (
|
||||||
|
shutdownTimeout = time.Second * 15
|
||||||
|
startupTime = time.Millisecond * 500
|
||||||
|
)
|
||||||
|
|
||||||
|
func ServeMetrics(l net.Listener, shutdownC <-chan struct{}) (err error) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
// Metrics port is privileged, so no need for further access control
|
||||||
|
trace.AuthRequest = func(*http.Request) (bool, bool) { return true, true }
|
||||||
|
// TODO: parameterize ReadTimeout and WriteTimeout. The maximum time we can
|
||||||
|
// profile CPU usage depends on WriteTimeout
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
ReadTimeout: 5 * time.Second,
|
ReadTimeout: 10 * time.Second,
|
||||||
WriteTimeout: 5 * time.Second,
|
WriteTimeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
<-shutdownC
|
|
||||||
server.Shutdown(context.Background())
|
|
||||||
}()
|
|
||||||
http.Handle("/metrics", promhttp.Handler())
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err = server.Serve(l)
|
||||||
|
}()
|
||||||
log.WithField("addr", l.Addr()).Info("Starting metrics server")
|
log.WithField("addr", l.Addr()).Info("Starting metrics server")
|
||||||
err := server.Serve(l)
|
// server.Serve will hang if server.Shutdown is called before the server is
|
||||||
|
// fully started up. So add artificial delay.
|
||||||
|
time.Sleep(startupTime)
|
||||||
|
|
||||||
|
<-shutdownC
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||||
|
server.Shutdown(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
if err == http.ErrServerClosed {
|
if err == http.ErrServerClosed {
|
||||||
log.Info("Metrics server stopped")
|
log.Info("Metrics server stopped")
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelMetrics struct {
|
type TunnelMetrics struct {
|
||||||
|
haConnections prometheus.Gauge
|
||||||
totalRequests prometheus.Counter
|
totalRequests prometheus.Counter
|
||||||
requestsPerTunnel *prometheus.CounterVec
|
requestsPerTunnel *prometheus.CounterVec
|
||||||
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
|
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
|
||||||
|
@ -41,6 +42,13 @@ type TunnelMetrics struct {
|
||||||
|
|
||||||
// Metrics that can be collected without asking the edge
|
// Metrics that can be collected without asking the edge
|
||||||
func NewTunnelMetrics() *TunnelMetrics {
|
func NewTunnelMetrics() *TunnelMetrics {
|
||||||
|
haConnections := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "ha_connections",
|
||||||
|
Help: "Number of active ha connections",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(haConnections)
|
||||||
|
|
||||||
totalRequests := prometheus.NewCounter(
|
totalRequests := prometheus.NewCounter(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "total_requests",
|
Name: "total_requests",
|
||||||
|
@ -173,6 +181,7 @@ func NewTunnelMetrics() *TunnelMetrics {
|
||||||
prometheus.MustRegister(serverLocations)
|
prometheus.MustRegister(serverLocations)
|
||||||
|
|
||||||
return &TunnelMetrics{
|
return &TunnelMetrics{
|
||||||
|
haConnections: haConnections,
|
||||||
totalRequests: totalRequests,
|
totalRequests: totalRequests,
|
||||||
requestsPerTunnel: requestsPerTunnel,
|
requestsPerTunnel: requestsPerTunnel,
|
||||||
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,
|
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,
|
||||||
|
@ -196,6 +205,14 @@ func NewTunnelMetrics() *TunnelMetrics {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TunnelMetrics) incrementHaConnections() {
|
||||||
|
t.haConnections.Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TunnelMetrics) decrementHaConnections() {
|
||||||
|
t.haConnections.Dec()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
|
func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
|
||||||
t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
|
t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
|
||||||
t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
|
t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Waiting time before retrying a failed tunnel connection
|
// Waiting time before retrying a failed tunnel connection
|
||||||
tunnelRetryDuration = time.Minute
|
tunnelRetryDuration = time.Second * 10
|
||||||
// Limit on the exponential backoff time period. (2^5 = 32 minutes)
|
// Limit on the exponential backoff time period. (2^5 = 32 minutes)
|
||||||
tunnelRetryLimit = 5
|
tunnelRetryLimit = 5
|
||||||
// SRV record resolution TTL
|
// SRV record resolution TTL
|
||||||
|
@ -19,12 +19,16 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Supervisor struct {
|
type Supervisor struct {
|
||||||
config *TunnelConfig
|
config *TunnelConfig
|
||||||
edgeIPs []*net.TCPAddr
|
edgeIPs []*net.TCPAddr
|
||||||
lastResolve time.Time
|
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
|
||||||
resolverC chan resolveResult
|
nextUnusedEdgeIP int
|
||||||
tunnelErrors chan tunnelError
|
lastResolve time.Time
|
||||||
tunnelsConnecting map[int]chan struct{}
|
resolverC chan resolveResult
|
||||||
|
tunnelErrors chan tunnelError
|
||||||
|
tunnelsConnecting map[int]chan struct{}
|
||||||
|
// nextConnectedIndex and nextConnectedSignal are used to wait for all
|
||||||
|
// currently-connecting tunnels to finish connecting so we can reset backoff timer
|
||||||
nextConnectedIndex int
|
nextConnectedIndex int
|
||||||
nextConnectedSignal chan struct{}
|
nextConnectedSignal chan struct{}
|
||||||
}
|
}
|
||||||
|
@ -53,7 +57,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
||||||
}
|
}
|
||||||
tunnelsActive := s.config.HAConnections
|
tunnelsActive := s.config.HAConnections
|
||||||
tunnelsWaiting := []int{}
|
tunnelsWaiting := []int{}
|
||||||
backoff := BackoffHandler{MaxRetries: tunnelRetryLimit, BaseTime: tunnelRetryDuration, RetryForever: true}
|
backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
|
||||||
var backoffTimer <-chan time.Time
|
var backoffTimer <-chan time.Time
|
||||||
for tunnelsActive > 0 {
|
for tunnelsActive > 0 {
|
||||||
select {
|
select {
|
||||||
|
@ -72,10 +76,17 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
||||||
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
||||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||||
s.waitForNextTunnel(tunnelError.index)
|
s.waitForNextTunnel(tunnelError.index)
|
||||||
if backoffTimer != nil {
|
if backoffTimer == nil {
|
||||||
backoffTimer = backoff.BackoffTimer()
|
backoffTimer = backoff.BackoffTimer()
|
||||||
}
|
}
|
||||||
s.refreshEdgeIPs()
|
// If the error is a dial error, the problem is likely to be network related
|
||||||
|
// try another addr before refreshing since we are likely to get back the
|
||||||
|
// same IPs in the same order. Same problem with duplicate connection error.
|
||||||
|
if s.unusedIPs() {
|
||||||
|
s.replaceEdgeIP(tunnelError.index)
|
||||||
|
} else {
|
||||||
|
s.refreshEdgeIPs()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Backoff was set and its timer expired
|
// Backoff was set and its timer expired
|
||||||
case <-backoffTimer:
|
case <-backoffTimer:
|
||||||
|
@ -109,15 +120,23 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
||||||
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Infof("ResolveEdgeIPs err")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.edgeIPs = edgeIPs
|
s.edgeIPs = edgeIPs
|
||||||
|
if s.config.HAConnections > len(edgeIPs) {
|
||||||
|
log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
|
||||||
|
s.config.HAConnections = len(edgeIPs)
|
||||||
|
}
|
||||||
s.lastResolve = time.Now()
|
s.lastResolve = time.Now()
|
||||||
go s.startTunnel(ctx, 0, connectedSignal)
|
// check entitlement and version too old error before attempting to register more tunnels
|
||||||
|
s.nextUnusedEdgeIP = s.config.HAConnections
|
||||||
|
go s.startFirstTunnel(ctx, connectedSignal)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
<-s.tunnelErrors
|
<-s.tunnelErrors
|
||||||
return nil
|
// Error can't be nil. A nil error signals that initialization succeed
|
||||||
|
return fmt.Errorf("Context was canceled")
|
||||||
case tunnelError := <-s.tunnelErrors:
|
case tunnelError := <-s.tunnelErrors:
|
||||||
return tunnelError.err
|
return tunnelError.err
|
||||||
case <-connectedSignal:
|
case <-connectedSignal:
|
||||||
|
@ -125,10 +144,41 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct
|
||||||
// At least one successful connection, so start the rest
|
// At least one successful connection, so start the rest
|
||||||
for i := 1; i < s.config.HAConnections; i++ {
|
for i := 1; i < s.config.HAConnections; i++ {
|
||||||
go s.startTunnel(ctx, i, make(chan struct{}))
|
go s.startTunnel(ctx, i, make(chan struct{}))
|
||||||
|
// TODO: Add artificial delay between HA connections to make sure all origins
|
||||||
|
// are registered in LB pool. Temporary fix until we fix LB
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||||
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||||
|
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}) {
|
||||||
|
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
|
||||||
|
defer func() {
|
||||||
|
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for s.unusedIPs() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
switch err.(type) {
|
||||||
|
case nil:
|
||||||
|
return
|
||||||
|
// try the next address if it was a dialError(network problem) or
|
||||||
|
// dupConnRegisterTunnelError
|
||||||
|
case dialError, dupConnRegisterTunnelError:
|
||||||
|
s.replaceEdgeIP(0)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors.
|
// s.tunnelErrors.
|
||||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
|
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
|
||||||
|
@ -172,3 +222,12 @@ func (s *Supervisor) refreshEdgeIPs() {
|
||||||
s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
|
s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) unusedIPs() bool {
|
||||||
|
return s.nextUnusedEdgeIP < len(s.edgeIPs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
|
||||||
|
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
||||||
|
s.nextUnusedEdgeIP++
|
||||||
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
|
@ -47,6 +48,7 @@ type TunnelConfig struct {
|
||||||
LBPool string
|
LBPool string
|
||||||
Tags []tunnelpogs.Tag
|
Tags []tunnelpogs.Tag
|
||||||
HAConnections int
|
HAConnections int
|
||||||
|
HTTPTransport http.RoundTripper
|
||||||
Metrics *TunnelMetrics
|
Metrics *TunnelMetrics
|
||||||
MetricsUpdateFreq time.Duration
|
MetricsUpdateFreq time.Duration
|
||||||
ProtocolLogger *log.Logger
|
ProtocolLogger *log.Logger
|
||||||
|
@ -98,6 +100,7 @@ func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connecte
|
||||||
<-shutdownC
|
<-shutdownC
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
// If a user specified negative HAConnections, we will treat it as requesting 1 connection
|
||||||
if config.HAConnections > 1 {
|
if config.HAConnections > 1 {
|
||||||
return NewSupervisor(config).Run(ctx, connectedSignal)
|
return NewSupervisor(config).Run(ctx, connectedSignal)
|
||||||
} else {
|
} else {
|
||||||
|
@ -110,6 +113,8 @@ func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connecte
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAddr, connectionID uint8, connectedSignal chan struct{}) error {
|
func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAddr, connectionID uint8, connectedSignal chan struct{}) error {
|
||||||
|
config.Metrics.incrementHaConnections()
|
||||||
|
defer config.Metrics.decrementHaConnections()
|
||||||
backoff := BackoffHandler{MaxRetries: config.Retries}
|
backoff := BackoffHandler{MaxRetries: config.Retries}
|
||||||
// Used to close connectedSignal no more than once
|
// Used to close connectedSignal no more than once
|
||||||
connectedFuse := h2mux.NewBooleanFuse()
|
connectedFuse := h2mux.NewBooleanFuse()
|
||||||
|
@ -141,6 +146,8 @@ func ServeTunnel(
|
||||||
connectedFuse *h2mux.BooleanFuse,
|
connectedFuse *h2mux.BooleanFuse,
|
||||||
backoff *BackoffHandler,
|
backoff *BackoffHandler,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
|
@ -159,6 +166,7 @@ func ServeTunnel(
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case dialError:
|
case dialError:
|
||||||
errLog.Error("Unable to dial edge")
|
errLog.Error("Unable to dial edge")
|
||||||
|
return err, false
|
||||||
case h2mux.MuxerHandshakeError:
|
case h2mux.MuxerHandshakeError:
|
||||||
errLog.Error("Handshake failed with edge server")
|
errLog.Error("Handshake failed with edge server")
|
||||||
default:
|
default:
|
||||||
|
@ -178,14 +186,16 @@ func ServeTunnel(
|
||||||
serveCancel()
|
serveCancel()
|
||||||
}
|
}
|
||||||
registerErrC <- err
|
registerErrC <- err
|
||||||
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
|
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-serveCtx.Done():
|
case <-serveCtx.Done():
|
||||||
handler.muxer.Shutdown()
|
handler.muxer.Shutdown()
|
||||||
break
|
return
|
||||||
case <-updateMetricsTickC:
|
case <-updateMetricsTickC:
|
||||||
handler.UpdateMetrics()
|
handler.UpdateMetrics()
|
||||||
}
|
}
|
||||||
|
@ -195,6 +205,7 @@ func ServeTunnel(
|
||||||
err = handler.muxer.Serve()
|
err = handler.muxer.Serve()
|
||||||
serveCancel()
|
serveCancel()
|
||||||
registerErr := <-registerErrC
|
registerErr := <-registerErrC
|
||||||
|
wg.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Tunnel error")
|
log.WithError(err).Error("Tunnel error")
|
||||||
return err, true
|
return err, true
|
||||||
|
@ -204,7 +215,7 @@ func ServeTunnel(
|
||||||
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
||||||
log.Error(e)
|
log.Error(e)
|
||||||
if e.permanent {
|
if e.permanent {
|
||||||
return nil, false
|
return e, false
|
||||||
}
|
}
|
||||||
return e.cause, true
|
return e.cause, true
|
||||||
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
|
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
|
||||||
|
@ -282,10 +293,6 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Registered at %s", registration.Url)
|
log.Infof("Registered at %s", registration.Url)
|
||||||
|
|
||||||
for _, logLine := range registration.LogLines {
|
|
||||||
log.Infof(logLine)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,7 +355,7 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
|
||||||
type TunnelHandler struct {
|
type TunnelHandler struct {
|
||||||
originUrl string
|
originUrl string
|
||||||
muxer *h2mux.Muxer
|
muxer *h2mux.Muxer
|
||||||
httpClient *http.Client
|
httpClient http.RoundTripper
|
||||||
tags []tunnelpogs.Tag
|
tags []tunnelpogs.Tag
|
||||||
metrics *TunnelMetrics
|
metrics *TunnelMetrics
|
||||||
// connectionID is only used by metrics, and prometheus requires labels to be string
|
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||||
|
@ -365,24 +372,27 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co
|
||||||
}
|
}
|
||||||
h := &TunnelHandler{
|
h := &TunnelHandler{
|
||||||
originUrl: url,
|
originUrl: url,
|
||||||
httpClient: &http.Client{Timeout: time.Minute},
|
httpClient: config.HTTPTransport,
|
||||||
tags: config.Tags,
|
tags: config.Tags,
|
||||||
metrics: config.Metrics,
|
metrics: config.Metrics,
|
||||||
connectionID: uint8ToString(connectionID),
|
connectionID: uint8ToString(connectionID),
|
||||||
}
|
}
|
||||||
|
if h.httpClient == nil {
|
||||||
|
h.httpClient = http.DefaultTransport
|
||||||
|
}
|
||||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||||
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
||||||
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
|
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
|
||||||
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
|
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
|
||||||
dialCancel()
|
dialCancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", dialError{cause: err}
|
return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")}
|
||||||
}
|
}
|
||||||
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
|
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
|
||||||
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
||||||
err = edgeConn.Handshake()
|
err = edgeConn.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", dialError{cause: err}
|
return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
||||||
}
|
}
|
||||||
// clear the deadline on the conn; h2mux has its own timeouts
|
// clear the deadline on the conn; h2mux has its own timeouts
|
||||||
edgeConn.SetDeadline(time.Time{})
|
edgeConn.SetDeadline(time.Time{})
|
||||||
|
@ -419,7 +429,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
log.WithError(err).Error("invalid request received")
|
log.WithError(err).Error("invalid request received")
|
||||||
}
|
}
|
||||||
h.AppendTagHeaders(req)
|
h.AppendTagHeaders(req)
|
||||||
response, err := h.httpClient.Do(req)
|
response, err := h.httpClient.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("HTTP request error")
|
log.WithError(err).Error("HTTP request error")
|
||||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||||
|
|
Loading…
Reference in New Issue