package origin

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"runtime/debug"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/rs/zerolog"
	"golang.org/x/sync/errgroup"

	"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
	"github.com/cloudflare/cloudflared/connection"
	"github.com/cloudflare/cloudflared/edgediscovery"
	"github.com/cloudflare/cloudflared/h2mux"
	"github.com/cloudflare/cloudflared/signal"
	"github.com/cloudflare/cloudflared/tunnelrpc"
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

const (
	// dialTimeout is the timeout handed to edgediscovery.DialEdge in ServeTunnel.
	dialTimeout = 15 * time.Second

	// muxerTimeout is not referenced in this file.
	// NOTE(review): presumably used elsewhere in the package - confirm.
	muxerTimeout = 5 * time.Second

	// lbProbeUserAgentPrefix is not referenced in this file.
	// NOTE(review): presumably used elsewhere in the package - confirm.
	lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"

	// DuplicateConnectionError is not referenced in this file.
	// NOTE(review): looks like an error code string for duplicate connections - confirm.
	DuplicateConnectionError = "EDUPCONN"

	// Feature names returned by TunnelConfig.SupportedFeatures.
	FeatureSerializedHeaders = "serialized_headers"
	FeatureQuickReconnects   = "quick_reconnects"
)

// rpcName names an RPC; newRPCError uses it as a counter label value.
type rpcName string

const (
	register   rpcName = "register"
	reconnect  rpcName = "reconnect"
	unregister rpcName = "unregister"
	// The value previously carried a stray leading space (" authenticate"),
	// which made this label inconsistent with its siblings.
	authenticate rpcName = "authenticate"
)

// TunnelConfig bundles the settings used by this package to establish and
// serve tunnel connections.
type TunnelConfig struct {
	ConnectionConfig *connection.Config
	BuildInfo        *buildinfo.BuildInfo
	ClientID         string
	CloseConnOnce    *sync.Once // Used to close connectedSignal no more than once
	EdgeAddrs        []string
	HAConnections    int
	IncidentLookup   IncidentLookup
	IsAutoupdated    bool
	IsFreeTunnel     bool
	LBPool           string
	Tags             []tunnelpogs.Tag
	Log              *zerolog.Logger
	Observer         *connection.Observer
	ReportedVersion  string
	Retries          uint
	RunFromTerminal  bool

	NamedTunnel      *connection.NamedTunnelConfig
	ClassicTunnel    *connection.ClassicTunnelConfig
	MuxerConfig      *connection.MuxerConfig
	TunnelEventChans []chan connection.Event
	ProtocolSelector connection.ProtocolSelector
	EdgeTLSConfigs   map[connection.Protocol]*tls.Config
}

// muxerShutdownError is the error value reported as "muxer shutdown".
type muxerShutdownError struct{}

// Error implements the error interface.
func (e muxerShutdownError) Error() string {
	return "muxer shutdown"
}

// serverRegisterTunnelError is a RegisterTunnel error from the server.
// ServeH2mux treats the error as non-recoverable when permanent is set.
type serverRegisterTunnelError struct {
	cause     error
	permanent bool
}

// Error implements the error interface by returning the cause's message.
func (e serverRegisterTunnelError) Error() string {
	return e.cause.Error()
}

// clientRegisterTunnelError is a RegisterTunnel error from the client.
type clientRegisterTunnelError struct {
	cause error
}

// newRPCError increments counter with the labels (cause.Error(), name) and
// returns cause wrapped in a clientRegisterTunnelError.
func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
	counter.WithLabelValues(cause.Error(), string(name)).Inc()
	return clientRegisterTunnelError{cause: cause}
}

// Error implements the error interface by returning the cause's message.
func (e clientRegisterTunnelError) Error() string {
	return e.cause.Error()
}

// RegistrationOptions builds the registration options for the connection
// identified by connectionID from the tunnel configuration.
//
// The existing-tunnel policy is "balance" unless the config has at most one HA
// connection and no load-balancer pool, in which case it is "disconnect".
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
	policy := tunnelrpc.ExistingTunnelPolicy_balance
	if c.HAConnections <= 1 && c.LBPool == "" {
		policy = tunnelrpc.ExistingTunnelPolicy_disconnect
	}
	return &tunnelpogs.RegistrationOptions{
		ClientID:             c.ClientID,
		Version:              c.ReportedVersion,
		OS:                   fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
		ExistingTunnelPolicy: policy,
		PoolName:             c.LBPool,
		Tags:                 c.Tags,
		ConnectionID:         connectionID,
		OriginLocalIP:        OriginLocalIP,
		IsAutoupdated:        c.IsAutoupdated,
		RunFromTerminal:      c.RunFromTerminal,
		CompressionQuality:   uint64(c.MuxerConfig.CompressionSetting),
		UUID:                 uuid.String(),
		Features:             c.SupportedFeatures(),
	}
}

// ConnectionOptions builds the connection options sent for a named tunnel.
// It dereferences c.NamedTunnel, so callers must only use it when a named
// tunnel is configured.
func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions {
	// attempt to parse out origin IP, but don't fail since it's informational field
	host, _, _ := net.SplitHostPort(originLocalAddr)
	originIP := net.ParseIP(host)

	return &tunnelpogs.ConnectionOptions{
		Client:              c.NamedTunnel.Client,
		OriginLocalIP:       originIP,
		ReplaceExisting:     c.ConnectionConfig.ReplaceExisting,
		CompressionQuality:  uint8(c.MuxerConfig.CompressionSetting),
		NumPreviousAttempts: numPreviousAttempts,
	}
}

// SupportedFeatures returns the feature names for this configuration:
// always FeatureSerializedHeaders, plus FeatureQuickReconnects when no named
// tunnel is configured.
func (c *TunnelConfig) SupportedFeatures() []string {
	features := []string{FeatureSerializedHeaders}
	if c.NamedTunnel == nil {
		features = append(features, FeatureQuickReconnects)
	}
	return features
}

// StartTunnelDaemon creates a supervisor for config and runs it until it
// returns. An error from creating the supervisor is returned directly.
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
	s, err := NewSupervisor(config, cloudflaredID)
	if err != nil {
		return err
	}
	return s.Run(ctx, connectedSignal, reconnectCh)
}

// ServeTunnelLoop repeatedly calls ServeTunnel for one connection index,
// waiting via waitForBackoff between attempts. It returns when ServeTunnel
// reports a non-recoverable error or when waitForBackoff gives up.
//
// connectedSignal is notified once the connected fuse is set to true.
func ServeTunnelLoop(
	ctx context.Context,
	credentialManager *reconnectCredentialManager,
	config *TunnelConfig,
	addr *net.TCPAddr,
	connIndex uint8,
	connectedSignal *signal.Signal,
	cloudflaredUUID uuid.UUID,
	reconnectCh chan ReconnectSignal,
) error {
	haConnections.Inc()
	defer haConnections.Dec()

	connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()

	// Each connection to keep its own copy of protocol, because individual connections might fallback
	// to another protocol when a particular metal doesn't support new protocol
	protocolFallback := &protocallFallback{
		BackoffHandler: BackoffHandler{MaxRetries: config.Retries},
		protocol:       config.ProtocolSelector.Current(),
		inFallback:     false,
	}

	connectedFuse := h2mux.NewBooleanFuse()
	go func() {
		if connectedFuse.Await() {
			connectedSignal.Notify()
		}
	}()
	// Ensure the above goroutine will terminate if we return without connecting
	defer connectedFuse.Fuse(false)

	for {
		err, recoverable := ServeTunnel(
			ctx,
			&connLog,
			credentialManager,
			config,
			addr,
			connIndex,
			connectedFuse,
			protocolFallback,
			cloudflaredUUID,
			reconnectCh,
			protocolFallback.protocol,
		)
		if !recoverable {
			return err
		}

		err = waitForBackoff(ctx, &connLog, protocolFallback, config, connIndex, err)
		if err != nil {
			return err
		}
	}
}

// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
// max retries
type protocallFallback struct {
	BackoffHandler
	protocol   connection.Protocol
	inFallback bool
}

// reset resets the embedded backoff handler and clears the fallback flag.
func (pf *protocallFallback) reset() {
	pf.resetNow()
	pf.inFallback = false
}

// fallback resets the embedded backoff handler, switches to the given
// protocol, and marks the connection as being in fallback.
func (pf *protocallFallback) fallback(fallback connection.Protocol) {
	pf.resetNow()
	pf.protocol = fallback
	pf.inFallback = true
}

// waitForBackoff waits before the next connection attempt and decides which
// protocol that attempt should use. It returns nil when the caller should
// retry, or err when it should give up.
//
// Expect err to always be non nil
func waitForBackoff(
	ctx context.Context,
	log *zerolog.Logger,
	protobackoff *protocallFallback,
	config *TunnelConfig,
	connIndex uint8,
	err error,
) error {
	duration, ok := protobackoff.GetBackoffDuration(ctx)
	if !ok {
		return err
	}

	config.Observer.SendReconnect(connIndex)
	// A time.Duration formatted with %s already carries its unit ("1s"), so
	// the message must not append "seconds".
	log.Info().
		Err(err).
		Msgf("Retrying connection in %s", duration)
	// NOTE(review): any result of Backoff is discarded here, as in the
	// original code - confirm that is intentional.
	protobackoff.Backoff(ctx)

	if protobackoff.ReachedMaxRetries() {
		fallback, hasFallback := config.ProtocolSelector.Fallback()
		if !hasFallback {
			return err
		}
		// Already using fallback protocol, no point to retry
		if protobackoff.protocol == fallback {
			return err
		}
		log.Info().Msgf("Fallback to use %s", fallback)
		protobackoff.fallback(fallback)
	} else if !protobackoff.inFallback {
		current := config.ProtocolSelector.Current()
		if protobackoff.protocol != current {
			protobackoff.protocol = current
			// Use the logger passed in rather than config.Log, consistent
			// with the other messages in this function.
			log.Info().Msgf("Change protocol to %s", current)
		}
	}
	return nil
}

// ServeTunnel dials the edge at addr and serves one connection over the given
// protocol: ServeHTTP2 for connection.HTTP2, ServeH2mux otherwise.
//
// It returns the resulting error and whether the caller may retry. Dial
// failures and recovered panics are reported as recoverable.
func ServeTunnel(
	ctx context.Context,
	log *zerolog.Logger,
	credentialManager *reconnectCredentialManager,
	config *TunnelConfig,
	addr *net.TCPAddr,
	connIndex uint8,
	fuse *h2mux.BooleanFuse,
	backoff *protocallFallback,
	cloudflaredUUID uuid.UUID,
	reconnectCh chan ReconnectSignal,
	protocol connection.Protocol,
) (err error, recoverable bool) {
	// Treat panics as recoverable errors
	defer func() {
		if r := recover(); r != nil {
			var ok bool
			err, ok = r.(error)
			if !ok {
				err = fmt.Errorf("ServeTunnel: %v", r)
			}
			err = errors.Wrapf(err, "stack trace: %s", string(debug.Stack()))
			recoverable = true
		}
	}()

	defer config.Observer.SendDisconnect(connIndex)

	edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr)
	if err != nil {
		return err, true
	}

	// Pair the fuse with the backoff so that a successful connection also
	// resets the backoff state (see connectedFuse.Connected).
	connFuse := &connectedFuse{
		fuse:    fuse,
		backoff: backoff,
	}

	if protocol == connection.HTTP2 {
		connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
		return ServeHTTP2(ctx, log, config, edgeConn, connOptions, connIndex, connFuse, reconnectCh)
	}
	return ServeH2mux(
		ctx,
		log,
		credentialManager,
		config,
		edgeConn,
		connIndex,
		connFuse,
		cloudflaredUUID,
		reconnectCh,
	)
}

// ServeH2mux serves a tunnel over an h2mux connection on edgeConn. It serves a
// named tunnel when config.NamedTunnel is set and a classic tunnel otherwise,
// while also listening for a reconnect signal on reconnectCh.
//
// It returns the resulting error and whether the caller may retry.
func ServeH2mux(
	ctx context.Context,
	log *zerolog.Logger,
	credentialManager *reconnectCredentialManager,
	config *TunnelConfig,
	edgeConn net.Conn,
	connIndex uint8,
	connectedFuse *connectedFuse,
	cloudflaredUUID uuid.UUID,
	reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
	// Use the logger passed in rather than config.Log, consistent with
	// ServeHTTP2.
	log.Debug().Msgf("Connecting via h2mux")
	// Returns error from parsing the origin URL or handshake errors
	handler, err, recoverable := connection.NewH2muxConnection(
		config.ConnectionConfig,
		config.MuxerConfig,
		edgeConn,
		connIndex,
		config.Observer,
	)
	if err != nil {
		return err, recoverable
	}

	errGroup, serveCtx := errgroup.WithContext(ctx)

	errGroup.Go(func() error {
		if config.NamedTunnel != nil {
			connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
			return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
		}
		registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
		return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
	})

	errGroup.Go(listenReconnect(serveCtx, reconnectCh))

	err = errGroup.Wait()
	if err != nil {
		// NOTE(review): the cases below match pointer types only, while
		// newRPCError in this file returns a clientRegisterTunnelError value.
		// A value returned as an error would fall through to default -
		// confirm how these errors are constructed.
		switch err := err.(type) {
		case *connection.DupConnRegisterTunnelError:
			// don't retry this connection anymore, let supervisor pick a new address
			return err, false
		case *serverRegisterTunnelError:
			log.Err(err).Msg("Register tunnel error from server side")
			// Don't send registration error return from server to Sentry. They are
			// logged on server side
			if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
				log.Error().Msg(activeIncidentsMsg(incidents))
			}
			return err.cause, !err.permanent
		case *clientRegisterTunnelError:
			log.Err(err).Msg("Register tunnel error on client side")
			return err, true
		case *muxerShutdownError:
			log.Info().Msg("Muxer shutdown")
			return err, true
		case *ReconnectSignal:
			log.Info().
				Uint8(connection.LogFieldConnIndex, connIndex).
				Msgf("Restarting connection due to reconnect signal in %s", err.Delay)
			err.DelayBeforeReconnect()
			return err, true
		default:
			if err == context.Canceled {
				log.Debug().Err(err).Msgf("Serve tunnel error")
				return err, false
			}
			log.Err(err).Msgf("Serve tunnel error")
			return err, true
		}
	}
	return nil, true
}

// ServeHTTP2 serves a tunnel over an HTTP2 connection on tlsServerConn while
// also listening for a reconnect signal on reconnectCh.
//
// It returns the resulting error and whether the caller may retry. Every
// error is reported as recoverable.
func ServeHTTP2(
	ctx context.Context,
	log *zerolog.Logger,
	config *TunnelConfig,
	tlsServerConn net.Conn,
	connOptions *tunnelpogs.ConnectionOptions,
	connIndex uint8,
	connectedFuse connection.ConnectedFuse,
	reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
	log.Debug().Msgf("Connecting via http2")
	server := connection.NewHTTP2Connection(
		tlsServerConn,
		config.ConnectionConfig,
		config.NamedTunnel,
		connOptions,
		config.Observer,
		connIndex,
		connectedFuse,
	)

	errGroup, serveCtx := errgroup.WithContext(ctx)
	errGroup.Go(func() error {
		server.Serve(serveCtx)
		// Serve returning is always reported as an error, so Wait below
		// yields a non-nil error.
		return fmt.Errorf("connection with edge closed")
	})

	errGroup.Go(listenReconnect(serveCtx, reconnectCh))

	err = errGroup.Wait()
	if err != nil {
		// Honor the delay carried by a reconnect signal, as ServeH2mux does.
		// Previously the signal's delay was ignored on this path.
		if reconnectErr, ok := err.(*ReconnectSignal); ok {
			log.Info().
				Msgf("Restarting connection due to reconnect signal in %s", reconnectErr.Delay)
			reconnectErr.DelayBeforeReconnect()
		}
		return err, true
	}
	return nil, false
}

// listenReconnect returns a function, suitable for errgroup.Group.Go, that
// blocks until a reconnect signal arrives on reconnectCh or ctx is done. It
// returns the signal as an error in the first case and nil in the second.
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
	return func() error {
		select {
		case reconnect := <-reconnectCh:
			return &reconnect
		case <-ctx.Done():
			return nil
		}
	}
}

// connectedFuse couples a boolean fuse with the connection's backoff state so
// that marking the connection as connected also resets the backoff.
type connectedFuse struct {
	fuse    *h2mux.BooleanFuse
	backoff *protocallFallback
}

// Connected sets the fuse to true and resets the backoff state.
func (cf *connectedFuse) Connected() {
	cf.fuse.Fuse(true)
	cf.backoff.reset()
}

// IsConnected reports the current value of the fuse.
func (cf *connectedFuse) IsConnected() bool {
	return cf.fuse.Value()
}

// activeIncidentsMsg formats incidents into one log message: a preamble
// (singular or plural depending on the count) followed by "name (url)"
// entries joined with "; ".
func activeIncidentsMsg(incidents []Incident) string {
	preamble := "There is an active Cloudflare incident that may be related:"
	if len(incidents) > 1 {
		preamble = "There are active Cloudflare incidents that may be related:"
	}
	incidentStrings := make([]string, 0, len(incidents))
	for _, incident := range incidents {
		incidentString := fmt.Sprintf("%s (%s)", incident.Name, incident.URL())
		incidentStrings = append(incidentStrings, incidentString)
	}
	return preamble + " " + strings.Join(incidentStrings, "; ")
}