diff --git a/origin/tunnel.go b/origin/tunnel.go index 070f38ef..86a50a16 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -19,6 +19,7 @@ import ( "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/tunnelrpc" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/websocket" @@ -335,23 +336,21 @@ func RegisterTunnel( serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { return nil }) - registration, err := ts.RegisterTunnel( + LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) + registration := ts.RegisterTunnel( ctx, config.OriginCert, config.Hostname, config.RegistrationOptions(connectionID, originLocalIP, uuid), ) - LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) - if err != nil { + + if registrationErr := registration.DeserializeError(); registrationErr != nil { // RegisterTunnel RPC failure - return newClientRegisterTunnelError(err, config.Metrics.regFail) - } - for _, logLine := range registration.LogLines { - logger.Info(logLine) + return processRegisterTunnelError(registrationErr, config.Metrics) } - if regErr := processRegisterTunnelError(registration.Err, registration.PermanentFailure, config.Metrics); regErr != nil { - return regErr + for _, logLine := range registration.LogLines { + logger.Info(logLine) } if registration.TunnelID != "" { @@ -374,22 +373,19 @@ func RegisterTunnel( config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc() logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional") + config.Metrics.regSuccess.Inc() return nil } -func processRegisterTunnelError(err string, permanentFailure bool, metrics *TunnelMetrics) error { - if err == "" { - metrics.regSuccess.Inc() - return nil - } - - metrics.regFail.WithLabelValues(err).Inc() - if err == DuplicateConnectionError { +func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *TunnelMetrics) error { + if err.Error() == DuplicateConnectionError { + metrics.regFail.WithLabelValues("dup_edge_conn").Inc() return dupConnRegisterTunnelError{} } + metrics.regFail.WithLabelValues("server_error").Inc() return serverRegisterTunnelError{ - cause: fmt.Errorf("Server error: %s", err), - permanent: permanentFailure, + cause: fmt.Errorf("Server error: %s", err.Error()), + permanent: err.IsPermanent(), } } diff --git a/tunnelrpc/pogs/tunnelrpc.go b/tunnelrpc/pogs/tunnelrpc.go index b675f0f2..1c169554 100644 --- a/tunnelrpc/pogs/tunnelrpc.go +++ b/tunnelrpc/pogs/tunnelrpc.go @@ -9,13 +9,16 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - log "github.com/sirupsen/logrus" capnp "zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2/pogs" "zombiezen.com/go/capnproto2/rpc" "zombiezen.com/go/capnproto2/server" ) +const ( + defaultRetryAfterSeconds = 15 +) + type Authentication struct { Key string Email string @@ -33,14 +36,111 @@ func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error } type TunnelRegistration struct { + SuccessfulTunnelRegistration Err string - Url string - LogLines []string PermanentFailure bool - TunnelID string `capnp:"tunnelID"` RetryAfterSeconds uint16 } +type SuccessfulTunnelRegistration struct { + Url string + LogLines []string + TunnelID string `capnp:"tunnelID"` +} + +func NewSuccessfulTunnelRegistration( + url string, + logLines []string, + tunnelID string, +) *TunnelRegistration { + // Marshal nil will result in an error + if logLines == nil { + logLines = []string{} + } + return &TunnelRegistration{ + SuccessfulTunnelRegistration: SuccessfulTunnelRegistration{ + Url: url, + LogLines: logLines, + TunnelID: tunnelID, + }, + } +} + +// Not calling this function Error() to avoid confusion with implementing error interface +func (tr TunnelRegistration) DeserializeError() TunnelRegistrationError { + if tr.Err != "" { + err := fmt.Errorf(tr.Err) + if tr.PermanentFailure { + return NewPermanentRegistrationError(err) + } + retryAfterSeconds := tr.RetryAfterSeconds + if retryAfterSeconds < defaultRetryAfterSeconds { + retryAfterSeconds = defaultRetryAfterSeconds + } + return NewRetryableRegistrationError(err, retryAfterSeconds) + } + return nil +} + +type TunnelRegistrationError interface { + error + Serialize() *TunnelRegistration + IsPermanent() bool +} + +type PermanentRegistrationError struct { + err string +} + +func NewPermanentRegistrationError(err error) TunnelRegistrationError { + return &PermanentRegistrationError{ + err: err.Error(), + } +} + +func (pre *PermanentRegistrationError) Error() string { + return pre.err +} + +func (pre *PermanentRegistrationError) Serialize() *TunnelRegistration { + return &TunnelRegistration{ + Err: pre.err, + PermanentFailure: true, + } +} + +func (*PermanentRegistrationError) IsPermanent() bool { + return true +} + +type RetryableRegistrationError struct { + err string + retryAfterSeconds uint16 +} + +func NewRetryableRegistrationError(err error, retryAfterSeconds uint16) TunnelRegistrationError { + return &RetryableRegistrationError{ + err: err.Error(), + retryAfterSeconds: retryAfterSeconds, + } +} + +func (rre *RetryableRegistrationError) Error() string { + return rre.err +} + +func (rre *RetryableRegistrationError) Serialize() *TunnelRegistration { + return &TunnelRegistration{ + Err: rre.err, + PermanentFailure: false, + RetryAfterSeconds: rre.retryAfterSeconds, + } +} + +func (*RetryableRegistrationError) IsPermanent() bool { + return false +} + func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error { return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p) } @@ -325,7 +425,7 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar } type TunnelServer interface { - RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) + RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration GetServerInfo(ctx context.Context) (*ServerInfo, error) UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error) @@ -359,15 +459,12 @@ func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerT return err } server.Ack(p.Options) - registration, err := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions) - if err != nil { - return err - } + registration := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions) + result, err := p.Results.NewResult() if err != nil { return err } - log.Info(registration.TunnelID) return MarshalTunnelRegistration(result, registration) } @@ -420,7 +517,7 @@ func (c TunnelServer_PogsClient) Close() error { return c.Conn.Close() } -func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) { +func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error { err := p.SetOriginCert(originCert) @@ -443,9 +540,13 @@ func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert }) retval, err := promise.Result().Struct() if err != nil { - return nil, err + return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize() } - return UnmarshalTunnelRegistration(retval) + registration, err := UnmarshalTunnelRegistration(retval) + if err != nil { + return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize() + } + return registration } func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) { diff --git a/tunnelrpc/pogs/tunnelrpc_test.go b/tunnelrpc/pogs/tunnelrpc_test.go index 98f061e0..4ae8f2fb 100644 --- a/tunnelrpc/pogs/tunnelrpc_test.go +++ b/tunnelrpc/pogs/tunnelrpc_test.go @@ -1,6 +1,7 @@ package pogs import ( + "fmt" "reflect" "testing" "time" @@ -11,16 +12,29 @@ import ( capnp "zombiezen.com/go/capnproto2" ) +const ( + testURL = "tunnel.example.com" + testTunnelID = "asdfghjkl;" + testRetryAfterSeconds = 19 +) + +var ( + testErr = fmt.Errorf("Invalid credential") + testLogLines = []string{"all", "working"} +) + +// *PermanentRegistrationError implements TunnelRegistrationError +var _ TunnelRegistrationError = (*PermanentRegistrationError)(nil) + +// *RetryableRegistrationError implements TunnelRegistrationError +var _ TunnelRegistrationError = (*RetryableRegistrationError)(nil) + func TestTunnelRegistration(t *testing.T) { testCases := []*TunnelRegistration{ - &TunnelRegistration{ - Err: "it broke", - Url: "asdf.cftunnel.com", - LogLines: []string{"it", "was", "broken"}, - PermanentFailure: true, - TunnelID: "asdfghjkl;", - RetryAfterSeconds: 19, - }, + NewSuccessfulTunnelRegistration(testURL, testLogLines, testTunnelID), + NewSuccessfulTunnelRegistration(testURL, nil, testTunnelID), + NewPermanentRegistrationError(testErr).Serialize(), + NewRetryableRegistrationError(testErr, testRetryAfterSeconds).Serialize(), } for i, testCase := range testCases { _, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))