TUN-2573: Refactor TunnelRegistration into PermanentRegistrationError, RetryableRegistrationError and SuccessfulTunnelRegistration
This commit is contained in:
		
							parent
							
								
									23e12cf5a3
								
							
						
					
					
						commit
						b0d31a0ef3
					
				|  | @ -19,6 +19,7 @@ import ( | ||||||
| 	"github.com/cloudflare/cloudflared/signal" | 	"github.com/cloudflare/cloudflared/signal" | ||||||
| 	"github.com/cloudflare/cloudflared/streamhandler" | 	"github.com/cloudflare/cloudflared/streamhandler" | ||||||
| 	"github.com/cloudflare/cloudflared/tunnelrpc" | 	"github.com/cloudflare/cloudflared/tunnelrpc" | ||||||
|  | 	"github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||||
| 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||||
| 	"github.com/cloudflare/cloudflared/validation" | 	"github.com/cloudflare/cloudflared/validation" | ||||||
| 	"github.com/cloudflare/cloudflared/websocket" | 	"github.com/cloudflare/cloudflared/websocket" | ||||||
|  | @ -335,23 +336,21 @@ func RegisterTunnel( | ||||||
| 	serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { | 	serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { | ||||||
| 		return nil | 		return nil | ||||||
| 	}) | 	}) | ||||||
| 	registration, err := ts.RegisterTunnel( | 	LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) | ||||||
|  | 	registration := ts.RegisterTunnel( | ||||||
| 		ctx, | 		ctx, | ||||||
| 		config.OriginCert, | 		config.OriginCert, | ||||||
| 		config.Hostname, | 		config.Hostname, | ||||||
| 		config.RegistrationOptions(connectionID, originLocalIP, uuid), | 		config.RegistrationOptions(connectionID, originLocalIP, uuid), | ||||||
| 	) | 	) | ||||||
| 	LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) | 
 | ||||||
| 	if err != nil { | 	if registrationErr := registration.DeserializeError(); registrationErr != nil { | ||||||
| 		// RegisterTunnel RPC failure
 | 		// RegisterTunnel RPC failure
 | ||||||
| 		return newClientRegisterTunnelError(err, config.Metrics.regFail) | 		return processRegisterTunnelError(registrationErr, config.Metrics) | ||||||
| 	} |  | ||||||
| 	for _, logLine := range registration.LogLines { |  | ||||||
| 		logger.Info(logLine) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if regErr := processRegisterTunnelError(registration.Err, registration.PermanentFailure, config.Metrics); regErr != nil { | 	for _, logLine := range registration.LogLines { | ||||||
| 		return regErr | 		logger.Info(logLine) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if registration.TunnelID != "" { | 	if registration.TunnelID != "" { | ||||||
|  | @ -374,22 +373,19 @@ func RegisterTunnel( | ||||||
| 	config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc() | 	config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc() | ||||||
| 
 | 
 | ||||||
| 	logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional") | 	logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional") | ||||||
|  | 	config.Metrics.regSuccess.Inc() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func processRegisterTunnelError(err string, permanentFailure bool, metrics *TunnelMetrics) error { | func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *TunnelMetrics) error { | ||||||
| 	if err == "" { | 	if err.Error() == DuplicateConnectionError { | ||||||
| 		metrics.regSuccess.Inc() | 		metrics.regFail.WithLabelValues("dup_edge_conn").Inc() | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	metrics.regFail.WithLabelValues(err).Inc() |  | ||||||
| 	if err == DuplicateConnectionError { |  | ||||||
| 		return dupConnRegisterTunnelError{} | 		return dupConnRegisterTunnelError{} | ||||||
| 	} | 	} | ||||||
|  | 	metrics.regFail.WithLabelValues("server_error").Inc() | ||||||
| 	return serverRegisterTunnelError{ | 	return serverRegisterTunnelError{ | ||||||
| 		cause:     fmt.Errorf("Server error: %s", err), | 		cause:     fmt.Errorf("Server error: %s", err.Error()), | ||||||
| 		permanent: permanentFailure, | 		permanent: err.IsPermanent(), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -9,13 +9,16 @@ import ( | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/pkg/errors" | 	"github.com/pkg/errors" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/sirupsen/logrus" |  | ||||||
| 	capnp "zombiezen.com/go/capnproto2" | 	capnp "zombiezen.com/go/capnproto2" | ||||||
| 	"zombiezen.com/go/capnproto2/pogs" | 	"zombiezen.com/go/capnproto2/pogs" | ||||||
| 	"zombiezen.com/go/capnproto2/rpc" | 	"zombiezen.com/go/capnproto2/rpc" | ||||||
| 	"zombiezen.com/go/capnproto2/server" | 	"zombiezen.com/go/capnproto2/server" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | const ( | ||||||
|  | 	defaultRetryAfterSeconds = 15 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| type Authentication struct { | type Authentication struct { | ||||||
| 	Key         string | 	Key         string | ||||||
| 	Email       string | 	Email       string | ||||||
|  | @ -33,12 +36,109 @@ func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type TunnelRegistration struct { | type TunnelRegistration struct { | ||||||
|  | 	SuccessfulTunnelRegistration | ||||||
| 	Err               string | 	Err               string | ||||||
|  | 	PermanentFailure  bool | ||||||
|  | 	RetryAfterSeconds uint16 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type SuccessfulTunnelRegistration struct { | ||||||
| 	Url      string | 	Url      string | ||||||
| 	LogLines []string | 	LogLines []string | ||||||
| 	PermanentFailure  bool |  | ||||||
| 	TunnelID string `capnp:"tunnelID"` | 	TunnelID string `capnp:"tunnelID"` | ||||||
| 	RetryAfterSeconds uint16 | } | ||||||
|  | 
 | ||||||
|  | func NewSuccessfulTunnelRegistration( | ||||||
|  | 	url string, | ||||||
|  | 	logLines []string, | ||||||
|  | 	tunnelID string, | ||||||
|  | ) *TunnelRegistration { | ||||||
|  | 	// Marshal nil will result in an error
 | ||||||
|  | 	if logLines == nil { | ||||||
|  | 		logLines = []string{} | ||||||
|  | 	} | ||||||
|  | 	return &TunnelRegistration{ | ||||||
|  | 		SuccessfulTunnelRegistration: SuccessfulTunnelRegistration{ | ||||||
|  | 			Url:      url, | ||||||
|  | 			LogLines: logLines, | ||||||
|  | 			TunnelID: tunnelID, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Not calling this function Error() to avoid confusion with implementing error interface
 | ||||||
|  | func (tr TunnelRegistration) DeserializeError() TunnelRegistrationError { | ||||||
|  | 	if tr.Err != "" { | ||||||
|  | 		err := fmt.Errorf(tr.Err) | ||||||
|  | 		if tr.PermanentFailure { | ||||||
|  | 			return NewPermanentRegistrationError(err) | ||||||
|  | 		} | ||||||
|  | 		retryAfterSeconds := tr.RetryAfterSeconds | ||||||
|  | 		if retryAfterSeconds < defaultRetryAfterSeconds { | ||||||
|  | 			retryAfterSeconds = defaultRetryAfterSeconds | ||||||
|  | 		} | ||||||
|  | 		return NewRetryableRegistrationError(err, retryAfterSeconds) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type TunnelRegistrationError interface { | ||||||
|  | 	error | ||||||
|  | 	Serialize() *TunnelRegistration | ||||||
|  | 	IsPermanent() bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type PermanentRegistrationError struct { | ||||||
|  | 	err string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewPermanentRegistrationError(err error) TunnelRegistrationError { | ||||||
|  | 	return &PermanentRegistrationError{ | ||||||
|  | 		err: err.Error(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pre *PermanentRegistrationError) Error() string { | ||||||
|  | 	return pre.err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (pre *PermanentRegistrationError) Serialize() *TunnelRegistration { | ||||||
|  | 	return &TunnelRegistration{ | ||||||
|  | 		Err:              pre.err, | ||||||
|  | 		PermanentFailure: true, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (*PermanentRegistrationError) IsPermanent() bool { | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type RetryableRegistrationError struct { | ||||||
|  | 	err               string | ||||||
|  | 	retryAfterSeconds uint16 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewRetryableRegistrationError(err error, retryAfterSeconds uint16) TunnelRegistrationError { | ||||||
|  | 	return &RetryableRegistrationError{ | ||||||
|  | 		err:               err.Error(), | ||||||
|  | 		retryAfterSeconds: retryAfterSeconds, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rre *RetryableRegistrationError) Error() string { | ||||||
|  | 	return rre.err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rre *RetryableRegistrationError) Serialize() *TunnelRegistration { | ||||||
|  | 	return &TunnelRegistration{ | ||||||
|  | 		Err:               rre.err, | ||||||
|  | 		PermanentFailure:  false, | ||||||
|  | 		RetryAfterSeconds: rre.retryAfterSeconds, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (*RetryableRegistrationError) IsPermanent() bool { | ||||||
|  | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error { | func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error { | ||||||
|  | @ -325,7 +425,7 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type TunnelServer interface { | type TunnelServer interface { | ||||||
| 	RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) | 	RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration | ||||||
| 	GetServerInfo(ctx context.Context) (*ServerInfo, error) | 	GetServerInfo(ctx context.Context) (*ServerInfo, error) | ||||||
| 	UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error | 	UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error | ||||||
| 	Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error) | 	Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error) | ||||||
|  | @ -359,15 +459,12 @@ func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerT | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	server.Ack(p.Options) | 	server.Ack(p.Options) | ||||||
| 	registration, err := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions) | 	registration := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions) | ||||||
| 	if err != nil { | 
 | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	result, err := p.Results.NewResult() | 	result, err := p.Results.NewResult() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	log.Info(registration.TunnelID) |  | ||||||
| 	return MarshalTunnelRegistration(result, registration) | 	return MarshalTunnelRegistration(result, registration) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -420,7 +517,7 @@ func (c TunnelServer_PogsClient) Close() error { | ||||||
| 	return c.Conn.Close() | 	return c.Conn.Close() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) { | func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration { | ||||||
| 	client := tunnelrpc.TunnelServer{Client: c.Client} | 	client := tunnelrpc.TunnelServer{Client: c.Client} | ||||||
| 	promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error { | 	promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error { | ||||||
| 		err := p.SetOriginCert(originCert) | 		err := p.SetOriginCert(originCert) | ||||||
|  | @ -443,9 +540,13 @@ func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert | ||||||
| 	}) | 	}) | ||||||
| 	retval, err := promise.Result().Struct() | 	retval, err := promise.Result().Struct() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize() | ||||||
| 	} | 	} | ||||||
| 	return UnmarshalTunnelRegistration(retval) | 	registration, err := UnmarshalTunnelRegistration(retval) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize() | ||||||
|  | 	} | ||||||
|  | 	return registration | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) { | func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) { | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| package pogs | package pogs | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -11,16 +12,29 @@ import ( | ||||||
| 	capnp "zombiezen.com/go/capnproto2" | 	capnp "zombiezen.com/go/capnproto2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | const ( | ||||||
|  | 	testURL               = "tunnel.example.com" | ||||||
|  | 	testTunnelID          = "asdfghjkl;" | ||||||
|  | 	testRetryAfterSeconds = 19 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	testErr      = fmt.Errorf("Invalid credential") | ||||||
|  | 	testLogLines = []string{"all", "working"} | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // *PermanentRegistrationError implements TunnelRegistrationError
 | ||||||
|  | var _ TunnelRegistrationError = (*PermanentRegistrationError)(nil) | ||||||
|  | 
 | ||||||
|  | // *RetryableRegistrationError implements TunnelRegistrationError
 | ||||||
|  | var _ TunnelRegistrationError = (*RetryableRegistrationError)(nil) | ||||||
|  | 
 | ||||||
| func TestTunnelRegistration(t *testing.T) { | func TestTunnelRegistration(t *testing.T) { | ||||||
| 	testCases := []*TunnelRegistration{ | 	testCases := []*TunnelRegistration{ | ||||||
| 		&TunnelRegistration{ | 		NewSuccessfulTunnelRegistration(testURL, testLogLines, testTunnelID), | ||||||
| 			Err:               "it broke", | 		NewSuccessfulTunnelRegistration(testURL, nil, testTunnelID), | ||||||
| 			Url:               "asdf.cftunnel.com", | 		NewPermanentRegistrationError(testErr).Serialize(), | ||||||
| 			LogLines:          []string{"it", "was", "broken"}, | 		NewRetryableRegistrationError(testErr, testRetryAfterSeconds).Serialize(), | ||||||
| 			PermanentFailure:  true, |  | ||||||
| 			TunnelID:          "asdfghjkl;", |  | ||||||
| 			RetryAfterSeconds: 19, |  | ||||||
| 		}, |  | ||||||
| 	} | 	} | ||||||
| 	for i, testCase := range testCases { | 	for i, testCase := range testCases { | ||||||
| 		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil)) | 		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil)) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue