TUN-2573: Refactor TunnelRegistration into PermanentRegistrationError, RetryableRegistrationError and SuccessfulTunnelRegistration
This commit is contained in:
		
							parent
							
								
									23e12cf5a3
								
							
						
					
					
						commit
						b0d31a0ef3
					
				| 
						 | 
					@ -19,6 +19,7 @@ import (
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/signal"
 | 
						"github.com/cloudflare/cloudflared/signal"
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/streamhandler"
 | 
						"github.com/cloudflare/cloudflared/streamhandler"
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
						"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
						tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/validation"
 | 
						"github.com/cloudflare/cloudflared/validation"
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/websocket"
 | 
						"github.com/cloudflare/cloudflared/websocket"
 | 
				
			||||||
| 
						 | 
					@ -335,23 +336,21 @@ func RegisterTunnel(
 | 
				
			||||||
	serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
 | 
						serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	registration, err := ts.RegisterTunnel(
 | 
						LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
 | 
				
			||||||
 | 
						registration := ts.RegisterTunnel(
 | 
				
			||||||
		ctx,
 | 
							ctx,
 | 
				
			||||||
		config.OriginCert,
 | 
							config.OriginCert,
 | 
				
			||||||
		config.Hostname,
 | 
							config.Hostname,
 | 
				
			||||||
		config.RegistrationOptions(connectionID, originLocalIP, uuid),
 | 
							config.RegistrationOptions(connectionID, originLocalIP, uuid),
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if registrationErr := registration.DeserializeError(); registrationErr != nil {
 | 
				
			||||||
		// RegisterTunnel RPC failure
 | 
							// RegisterTunnel RPC failure
 | 
				
			||||||
		return newClientRegisterTunnelError(err, config.Metrics.regFail)
 | 
							return processRegisterTunnelError(registrationErr, config.Metrics)
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, logLine := range registration.LogLines {
 | 
					 | 
				
			||||||
		logger.Info(logLine)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if regErr := processRegisterTunnelError(registration.Err, registration.PermanentFailure, config.Metrics); regErr != nil {
 | 
						for _, logLine := range registration.LogLines {
 | 
				
			||||||
		return regErr
 | 
							logger.Info(logLine)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if registration.TunnelID != "" {
 | 
						if registration.TunnelID != "" {
 | 
				
			||||||
| 
						 | 
					@ -374,22 +373,19 @@ func RegisterTunnel(
 | 
				
			||||||
	config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
 | 
						config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
 | 
						logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
 | 
				
			||||||
 | 
						config.Metrics.regSuccess.Inc()
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func processRegisterTunnelError(err string, permanentFailure bool, metrics *TunnelMetrics) error {
 | 
					func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
 | 
				
			||||||
	if err == "" {
 | 
						if err.Error() == DuplicateConnectionError {
 | 
				
			||||||
		metrics.regSuccess.Inc()
 | 
							metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
 | 
				
			||||||
		return nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	metrics.regFail.WithLabelValues(err).Inc()
 | 
					 | 
				
			||||||
	if err == DuplicateConnectionError {
 | 
					 | 
				
			||||||
		return dupConnRegisterTunnelError{}
 | 
							return dupConnRegisterTunnelError{}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						metrics.regFail.WithLabelValues("server_error").Inc()
 | 
				
			||||||
	return serverRegisterTunnelError{
 | 
						return serverRegisterTunnelError{
 | 
				
			||||||
		cause:     fmt.Errorf("Server error: %s", err),
 | 
							cause:     fmt.Errorf("Server error: %s", err.Error()),
 | 
				
			||||||
		permanent: permanentFailure,
 | 
							permanent: err.IsPermanent(),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,13 +9,16 @@ import (
 | 
				
			||||||
	"github.com/google/uuid"
 | 
						"github.com/google/uuid"
 | 
				
			||||||
	"github.com/pkg/errors"
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log "github.com/sirupsen/logrus"
 | 
					 | 
				
			||||||
	capnp "zombiezen.com/go/capnproto2"
 | 
						capnp "zombiezen.com/go/capnproto2"
 | 
				
			||||||
	"zombiezen.com/go/capnproto2/pogs"
 | 
						"zombiezen.com/go/capnproto2/pogs"
 | 
				
			||||||
	"zombiezen.com/go/capnproto2/rpc"
 | 
						"zombiezen.com/go/capnproto2/rpc"
 | 
				
			||||||
	"zombiezen.com/go/capnproto2/server"
 | 
						"zombiezen.com/go/capnproto2/server"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						defaultRetryAfterSeconds = 15
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Authentication struct {
 | 
					type Authentication struct {
 | 
				
			||||||
	Key         string
 | 
						Key         string
 | 
				
			||||||
	Email       string
 | 
						Email       string
 | 
				
			||||||
| 
						 | 
					@ -33,14 +36,111 @@ func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TunnelRegistration struct {
 | 
					type TunnelRegistration struct {
 | 
				
			||||||
 | 
						SuccessfulTunnelRegistration
 | 
				
			||||||
	Err               string
 | 
						Err               string
 | 
				
			||||||
	Url               string
 | 
					 | 
				
			||||||
	LogLines          []string
 | 
					 | 
				
			||||||
	PermanentFailure  bool
 | 
						PermanentFailure  bool
 | 
				
			||||||
	TunnelID          string `capnp:"tunnelID"`
 | 
					 | 
				
			||||||
	RetryAfterSeconds uint16
 | 
						RetryAfterSeconds uint16
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SuccessfulTunnelRegistration struct {
 | 
				
			||||||
 | 
						Url      string
 | 
				
			||||||
 | 
						LogLines []string
 | 
				
			||||||
 | 
						TunnelID string `capnp:"tunnelID"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewSuccessfulTunnelRegistration(
 | 
				
			||||||
 | 
						url string,
 | 
				
			||||||
 | 
						logLines []string,
 | 
				
			||||||
 | 
						tunnelID string,
 | 
				
			||||||
 | 
					) *TunnelRegistration {
 | 
				
			||||||
 | 
						// Marshal nil will result in an error
 | 
				
			||||||
 | 
						if logLines == nil {
 | 
				
			||||||
 | 
							logLines = []string{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &TunnelRegistration{
 | 
				
			||||||
 | 
							SuccessfulTunnelRegistration: SuccessfulTunnelRegistration{
 | 
				
			||||||
 | 
								Url:      url,
 | 
				
			||||||
 | 
								LogLines: logLines,
 | 
				
			||||||
 | 
								TunnelID: tunnelID,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Not calling this function Error() to avoid confusion with implementing error interface
 | 
				
			||||||
 | 
					func (tr TunnelRegistration) DeserializeError() TunnelRegistrationError {
 | 
				
			||||||
 | 
						if tr.Err != "" {
 | 
				
			||||||
 | 
							err := fmt.Errorf(tr.Err)
 | 
				
			||||||
 | 
							if tr.PermanentFailure {
 | 
				
			||||||
 | 
								return NewPermanentRegistrationError(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							retryAfterSeconds := tr.RetryAfterSeconds
 | 
				
			||||||
 | 
							if retryAfterSeconds < defaultRetryAfterSeconds {
 | 
				
			||||||
 | 
								retryAfterSeconds = defaultRetryAfterSeconds
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return NewRetryableRegistrationError(err, retryAfterSeconds)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TunnelRegistrationError interface {
 | 
				
			||||||
 | 
						error
 | 
				
			||||||
 | 
						Serialize() *TunnelRegistration
 | 
				
			||||||
 | 
						IsPermanent() bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type PermanentRegistrationError struct {
 | 
				
			||||||
 | 
						err string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewPermanentRegistrationError(err error) TunnelRegistrationError {
 | 
				
			||||||
 | 
						return &PermanentRegistrationError{
 | 
				
			||||||
 | 
							err: err.Error(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pre *PermanentRegistrationError) Error() string {
 | 
				
			||||||
 | 
						return pre.err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pre *PermanentRegistrationError) Serialize() *TunnelRegistration {
 | 
				
			||||||
 | 
						return &TunnelRegistration{
 | 
				
			||||||
 | 
							Err:              pre.err,
 | 
				
			||||||
 | 
							PermanentFailure: true,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (*PermanentRegistrationError) IsPermanent() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RetryableRegistrationError struct {
 | 
				
			||||||
 | 
						err               string
 | 
				
			||||||
 | 
						retryAfterSeconds uint16
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewRetryableRegistrationError(err error, retryAfterSeconds uint16) TunnelRegistrationError {
 | 
				
			||||||
 | 
						return &RetryableRegistrationError{
 | 
				
			||||||
 | 
							err:               err.Error(),
 | 
				
			||||||
 | 
							retryAfterSeconds: retryAfterSeconds,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (rre *RetryableRegistrationError) Error() string {
 | 
				
			||||||
 | 
						return rre.err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (rre *RetryableRegistrationError) Serialize() *TunnelRegistration {
 | 
				
			||||||
 | 
						return &TunnelRegistration{
 | 
				
			||||||
 | 
							Err:               rre.err,
 | 
				
			||||||
 | 
							PermanentFailure:  false,
 | 
				
			||||||
 | 
							RetryAfterSeconds: rre.retryAfterSeconds,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (*RetryableRegistrationError) IsPermanent() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
 | 
					func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
 | 
				
			||||||
	return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p)
 | 
						return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -325,7 +425,7 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TunnelServer interface {
 | 
					type TunnelServer interface {
 | 
				
			||||||
	RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
 | 
						RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration
 | 
				
			||||||
	GetServerInfo(ctx context.Context) (*ServerInfo, error)
 | 
						GetServerInfo(ctx context.Context) (*ServerInfo, error)
 | 
				
			||||||
	UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
 | 
						UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
 | 
				
			||||||
	Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
 | 
						Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
 | 
				
			||||||
| 
						 | 
					@ -359,15 +459,12 @@ func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerT
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	server.Ack(p.Options)
 | 
						server.Ack(p.Options)
 | 
				
			||||||
	registration, err := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
 | 
						registration := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
 | 
				
			||||||
	if err != nil {
 | 
					
 | 
				
			||||||
		return err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	result, err := p.Results.NewResult()
 | 
						result, err := p.Results.NewResult()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	log.Info(registration.TunnelID)
 | 
					 | 
				
			||||||
	return MarshalTunnelRegistration(result, registration)
 | 
						return MarshalTunnelRegistration(result, registration)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -420,7 +517,7 @@ func (c TunnelServer_PogsClient) Close() error {
 | 
				
			||||||
	return c.Conn.Close()
 | 
						return c.Conn.Close()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
 | 
					func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
 | 
				
			||||||
	client := tunnelrpc.TunnelServer{Client: c.Client}
 | 
						client := tunnelrpc.TunnelServer{Client: c.Client}
 | 
				
			||||||
	promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
 | 
						promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
 | 
				
			||||||
		err := p.SetOriginCert(originCert)
 | 
							err := p.SetOriginCert(originCert)
 | 
				
			||||||
| 
						 | 
					@ -443,9 +540,13 @@ func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	retval, err := promise.Result().Struct()
 | 
						retval, err := promise.Result().Struct()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return UnmarshalTunnelRegistration(retval)
 | 
						registration, err := UnmarshalTunnelRegistration(retval)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return registration
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
 | 
					func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
package pogs
 | 
					package pogs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
| 
						 | 
					@ -11,16 +12,29 @@ import (
 | 
				
			||||||
	capnp "zombiezen.com/go/capnproto2"
 | 
						capnp "zombiezen.com/go/capnproto2"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						testURL               = "tunnel.example.com"
 | 
				
			||||||
 | 
						testTunnelID          = "asdfghjkl;"
 | 
				
			||||||
 | 
						testRetryAfterSeconds = 19
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						testErr      = fmt.Errorf("Invalid credential")
 | 
				
			||||||
 | 
						testLogLines = []string{"all", "working"}
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// *PermanentRegistrationError implements TunnelRegistrationError
 | 
				
			||||||
 | 
					var _ TunnelRegistrationError = (*PermanentRegistrationError)(nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// *RetryableRegistrationError implements TunnelRegistrationError
 | 
				
			||||||
 | 
					var _ TunnelRegistrationError = (*RetryableRegistrationError)(nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestTunnelRegistration(t *testing.T) {
 | 
					func TestTunnelRegistration(t *testing.T) {
 | 
				
			||||||
	testCases := []*TunnelRegistration{
 | 
						testCases := []*TunnelRegistration{
 | 
				
			||||||
		&TunnelRegistration{
 | 
							NewSuccessfulTunnelRegistration(testURL, testLogLines, testTunnelID),
 | 
				
			||||||
			Err:               "it broke",
 | 
							NewSuccessfulTunnelRegistration(testURL, nil, testTunnelID),
 | 
				
			||||||
			Url:               "asdf.cftunnel.com",
 | 
							NewPermanentRegistrationError(testErr).Serialize(),
 | 
				
			||||||
			LogLines:          []string{"it", "was", "broken"},
 | 
							NewRetryableRegistrationError(testErr, testRetryAfterSeconds).Serialize(),
 | 
				
			||||||
			PermanentFailure:  true,
 | 
					 | 
				
			||||||
			TunnelID:          "asdfghjkl;",
 | 
					 | 
				
			||||||
			RetryAfterSeconds: 19,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for i, testCase := range testCases {
 | 
						for i, testCase := range testCases {
 | 
				
			||||||
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
 | 
							_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue