Merge branch 'master' into master

Kyle Harding 2019-12-17 08:47:06 -05:00 committed by GitHub
commit da8be8e873
7 changed files with 171 additions and 22 deletions


@@ -1,3 +1,4 @@
FROM golang:1.12-alpine as builder
WORKDIR /go/src/github.com/cloudflare/cloudflared/


@@ -197,6 +197,7 @@ func prepareTunnelConfig(
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: c.Int("proxy-keepalive-connections"),
MaxIdleConnsPerHost: c.Int("proxy-keepalive-connections"),
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second,
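The one-line addition in this hunk is MaxIdleConnsPerHost. Because cloudflared typically proxies everything to a single origin host, net/http's DefaultMaxIdleConnsPerHost of 2 would otherwise cap connection reuse regardless of how high proxy-keepalive-connections is set. A minimal sketch of the resulting transport, with plain parameters standing in for the CLI flags (illustrative only, not code from the commit):

package example

import (
	"net/http"
	"time"
)

// newProxyTransport mirrors the transport built in prepareTunnelConfig,
// with the keepalive CLI flags replaced by plain parameters.
func newProxyTransport(keepaliveConns int, keepaliveTimeout, tlsTimeout time.Duration) *http.Transport {
	return &http.Transport{
		Proxy:        http.ProxyFromEnvironment,
		MaxIdleConns: keepaliveConns,
		// Without this field the per-host default (2) applies, so most idle
		// connections to the single origin would be closed instead of reused.
		MaxIdleConnsPerHost:   keepaliveConns,
		IdleConnTimeout:       keepaliveTimeout,
		TLSHandshakeTimeout:   tlsTimeout,
		ExpectContinueTimeout: 1 * time.Second,
	}
}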


@@ -43,6 +43,7 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga
return m
}
// This function should be called while `m` is locked.
func (m *activeStreamMap) notifyStreamsEmpty() {
m.closeOnce.Do(func() {
close(m.streamsEmptyChan)
@@ -87,7 +88,9 @@ func (m *activeStreamMap) Delete(streamID uint32) {
delete(m.streams, streamID)
m.activeStreams.Dec()
}
if len(m.streams) == 0 {
// shutting down, and now the map is empty
if m.ignoreNewStreams && len(m.streams) == 0 {
m.notifyStreamsEmpty()
}
}
@@ -104,7 +107,7 @@ func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bo
}
m.ignoreNewStreams = true
if len(m.streams) == 0 {
// nothing to shut down
// there are no streams to wait for
m.notifyStreamsEmpty()
}
return m.streamsEmptyChan, false


@@ -60,6 +60,67 @@ func TestShutdown(t *testing.T) {
}
}
func TestEmptyBeforeShutdown(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{streamID: uint32(streamID)}
ok := m.Set(stream)
assert.True(t, ok)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
// Delete all the streams, bringing m to size 0
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
m.Delete(uint32(streamID))
}(i)
}
wg.Wait()
}
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
// Add one stream back
const soloStreamID = uint32(0)
ok := m.Set(&MuxedStream{streamID: soloStreamID})
assert.True(t, ok)
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
shutdownChan2, alreadyInProgress2 := m.Shutdown()
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
// Remove the remaining stream
m.Delete(soloStreamID)
select {
case <-shutdownChan:
default:
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
}
}
type noopBuffer struct {
isClosed bool
}


@@ -31,6 +31,8 @@ const (
refreshAuthMaxBackoff = 10
// Waiting time before retrying a failed 'Authenticate' connection
refreshAuthRetryDuration = time.Second * 10
// Maximum time to make an Authenticate RPC
authTokenTimeout = time.Second * 30
)
var (
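Only the constant is added in this hunk; its call site is not part of the diff shown here. Presumably it bounds the Authenticate RPC with a context deadline, roughly along these lines (the helper and the authenticate signature are assumptions for illustration, not code from the commit; context is assumed to be imported and authTokenTimeout in scope):

// callWithAuthTimeout is an illustrative helper, not from the commit: it
// bounds a single Authenticate attempt so a hung RPC cannot stall the
// supervisor for more than authTokenTimeout.
func callWithAuthTimeout(ctx context.Context, authenticate func(context.Context) error) error {
	ctx, cancel := context.WithTimeout(ctx, authTokenTimeout)
	defer cancel()
	return authenticate(ctx)
}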
@@ -90,14 +92,21 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
return err
}
var tunnelsWaiting []int
tunnelsActive := s.config.HAConnections
backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
var backoffTimer <-chan time.Time
tunnelsActive := s.config.HAConnections
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
var refreshAuthBackoffTimer <-chan time.Time
if s.config.UseReconnectToken {
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
if timer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer
} else {
logger.WithError(err).Errorf("initial refreshAuth failed, retrying in %v", refreshAuthRetryDuration)
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
}
}
for {
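The select loop that consumes these timer channels sits below this hunk and is not shown, but the declarations lean on a standard Go idiom: a nil <-chan time.Time blocks forever in a select, so a retry case stays disabled until the timer is actually armed (for instance, when UseReconnectToken is false, the refreshAuth case simply never fires). A small self-contained sketch of that idiom, with illustrative names:

package main

import (
	"fmt"
	"time"
)

func main() {
	var retryTimer <-chan time.Time // nil: this select case is disabled
	events := time.NewTicker(100 * time.Millisecond)
	defer events.Stop()

	for i := 0; i < 6; i++ {
		select {
		case <-retryTimer:
			fmt.Println("retry fired")
			retryTimer = nil // disarm until the next failure schedules it
		case <-events.C:
			fmt.Println("event")
			if retryTimer == nil {
				retryTimer = time.After(250 * time.Millisecond) // arm the retry case
			}
		}
	}
}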
@@ -169,11 +178,11 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
}
}
// Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.logger
edgeIPs, err := s.resolveEdgeIPs()
if err != nil {
logger.Infof("ResolveEdgeIPs err")
return err
@@ -190,7 +199,6 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
select {
case <-ctx.Done():
<-s.tunnelErrors
// Error can't be nil. A nil error signals that initialization succeeded
return ctx.Err()
case tunnelError := <-s.tunnelErrors:
return tunnelError.err
@@ -208,7 +216,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
defer func() {
s.tunnelErrors <- tunnelError{index: 0, err: err}
}()
@@ -229,14 +237,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
default:
return
}
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
err = ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
}
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
s.tunnelErrors <- tunnelError{index: index, err: err}
}
@@ -347,7 +355,9 @@ func (s *Supervisor) refreshAuth(
s.SetReconnectToken(outcome.JWT())
return timeAfter(outcome.RefreshAfter()), nil
case tunnelpogs.AuthUnknown:
return timeAfter(outcome.RefreshAfter()), nil
duration := outcome.RefreshAfter()
logger.WithError(outcome).Warnf("Retrying in %v", duration)
return timeAfter(duration), nil
case tunnelpogs.AuthFail:
return nil, outcome
default:


@@ -78,6 +78,14 @@ type TunnelConfig struct {
UseReconnectToken bool
}
// ReconnectTunnelCredentialManager is invoked by functions in this file to
// get/set parameters for ReconnectTunnel RPC calls.
type ReconnectTunnelCredentialManager interface {
ReconnectToken() ([]byte, error)
EventDigest() ([]byte, error)
SetEventDigest(eventDigest []byte)
}
type dupConnRegisterTunnelError struct{}
func (e dupConnRegisterTunnelError) Error() string {
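The new ServeTunnelLoop(ctx, s, s.config, ...) call sites in supervisor.go suggest the Supervisor itself is the concrete ReconnectTunnelCredentialManager. Purely to illustrate the contract, a hypothetical in-memory implementation (not part of the commit) might look like:

package example

import (
	"errors"
	"sync"
)

// memCredentialManager is a hypothetical implementation of
// ReconnectTunnelCredentialManager, shown only to illustrate the interface;
// in the commit the Supervisor plays this role.
type memCredentialManager struct {
	mu          sync.RWMutex
	token       []byte
	eventDigest []byte
}

func (m *memCredentialManager) ReconnectToken() ([]byte, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	if m.token == nil {
		return nil, errors.New("no reconnect token stored yet")
	}
	return m.token, nil
}

func (m *memCredentialManager) EventDigest() ([]byte, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	if m.eventDigest == nil {
		return nil, errors.New("no event digest stored yet")
	}
	return m.eventDigest, nil
}

func (m *memCredentialManager) SetEventDigest(eventDigest []byte) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.eventDigest = eventDigest
}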
@@ -152,6 +160,7 @@ func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSigna
}
func ServeTunnelLoop(ctx context.Context,
credentialManager ReconnectTunnelCredentialManager,
config *TunnelConfig,
addr *net.TCPAddr,
connectionID uint8,
@@ -173,6 +182,7 @@ func ServeTunnelLoop(ctx context.Context,
for {
err, recoverable := ServeTunnel(
ctx,
credentialManager,
config,
connectionLogger,
addr, connectionID,
@@ -193,6 +203,7 @@ func ServeTunnel(
func ServeTunnel(
ctx context.Context,
credentialManager ReconnectTunnelCredentialManager,
config *TunnelConfig,
logger *log.Entry,
addr *net.TCPAddr,
@@ -237,13 +248,30 @@ func ServeTunnel(
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err := RegisterTunnel(serveCtx, handler.muxer, config, logger, connectionID, originLocalIP, u)
if err == nil {
connectedFuse.Fuse(true)
backoff.SetGracePeriod()
errGroup.Go(func() (err error) {
defer func() {
if err == nil {
connectedFuse.Fuse(true)
backoff.SetGracePeriod()
}
}()
if config.UseReconnectToken && connectedFuse.Value() {
token, tokenErr := credentialManager.ReconnectToken()
eventDigest, eventDigestErr := credentialManager.EventDigest()
// if we have both credentials, we can reconnect
if tokenErr == nil && eventDigestErr == nil {
return ReconnectTunnel(ctx, token, eventDigest, handler.muxer, config, logger, connectionID, originLocalIP, u)
}
// log errors and proceed to RegisterTunnel
if tokenErr != nil {
logger.WithError(tokenErr).Error("Couldn't get reconnect token")
}
if eventDigestErr != nil {
logger.WithError(eventDigestErr).Error("Couldn't get event digest")
}
}
return err
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, u)
})
errGroup.Go(func() error {
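The rewritten registration goroutine uses a named return value so the deferred closure sees the error from whichever path ran (ReconnectTunnel or RegisterTunnel) and trips connectedFuse and the backoff grace period only on success. A standalone sketch of that named-return/defer pattern (names are illustrative, not from the commit):

package main

import (
	"errors"
	"fmt"
)

// runRegistration demonstrates the pattern used in ServeTunnel: the deferred
// closure reads the named return value err, so success-only side effects run
// after either registration path without duplicating the check.
func runRegistration(register func() error) (err error) {
	defer func() {
		if err == nil {
			fmt.Println("success: trip fuse, set backoff grace period")
		}
	}()
	return register()
}

func main() {
	_ = runRegistration(func() error { return nil })                      // side effect runs
	_ = runRegistration(func() error { return errors.New("rpc failed") }) // side effect skipped
}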
@@ -304,6 +332,7 @@ func ServeTunnel(
func RegisterTunnel(
ctx context.Context,
credentialManager ReconnectTunnelCredentialManager,
muxer *h2mux.Muxer,
config *TunnelConfig,
logger *log.Entry,
@@ -329,12 +358,52 @@ func RegisterTunnel(
config.Hostname,
config.RegistrationOptions(connectionID, originLocalIP, uuid),
)
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// RegisterTunnel RPC failure
return processRegisterTunnelError(registrationErr, config.Metrics)
}
credentialManager.SetEventDigest(registration.EventDigest)
return processRegistrationSuccess(config, logger, connectionID, registration)
}
func ReconnectTunnel(
ctx context.Context,
token []byte,
eventDigest []byte,
muxer *h2mux.Muxer,
config *TunnelConfig,
logger *log.Entry,
connectionID uint8,
originLocalIP string,
uuid uuid.UUID,
) error {
config.TransportLogger.Debug("initiating RPC stream to reconnect")
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-reconnect"), openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
defer tunnelServer.Close()
// Request server info without blocking tunnel registration; must use capnp library directly.
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
registration := tunnelServer.ReconnectTunnel(
ctx,
token,
eventDigest,
config.Hostname,
config.RegistrationOptions(connectionID, originLocalIP, uuid),
)
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// ReconnectTunnel RPC failure
return processRegisterTunnelError(registrationErr, config.Metrics)
}
return processRegistrationSuccess(config, logger, connectionID, registration)
}
func processRegistrationSuccess(config *TunnelConfig, logger *log.Entry, connectionID uint8, registration *tunnelpogs.TunnelRegistration) error {
for _, logLine := range registration.LogLines {
logger.Info(logLine)
}
@@ -378,13 +447,13 @@ func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
logger.Debug("initiating RPC stream to unregister")
ctx := context.Background()
ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
tunnelServer, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
if err != nil {
// RPC stream open error
return err
}
// gracePeriod is encoded in int64 using capnproto
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
return tunnelServer.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
}
func LogServerInfo(


@@ -46,7 +46,7 @@ func (c TunnelServer_PogsClient) ReconnectTunnel(
eventDigest []byte,
hostname string,
options *RegistrationOptions,
) (*TunnelRegistration, error) {
) *TunnelRegistration {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.ReconnectTunnel(ctx, func(p tunnelrpc.TunnelServer_reconnectTunnel_Params) error {
err := p.SetJwt(jwt)
@@ -73,7 +73,11 @@ func (c TunnelServer_PogsClient) ReconnectTunnel(
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
return UnmarshalTunnelRegistration(retval)
registration, err := UnmarshalTunnelRegistration(retval)
if err != nil {
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
return registration
}