From 04b1e4f859cf113ede5a7acaf187ffbdb8e5840c Mon Sep 17 00:00:00 2001 From: Igor Postelnik Date: Thu, 14 Jan 2021 16:33:36 -0600 Subject: [PATCH] TUN-3738: Refactor observer to avoid potential of blocking on tunnel notifications --- cmd/cloudflared/tunnel/cmd.go | 25 ++++------ cmd/cloudflared/tunnel/configuration.go | 7 +-- cmd/cloudflared/ui/launch_ui.go | 47 +++++++++--------- connection/connection_test.go | 8 +--- connection/metrics.go | 14 +++++- connection/observer.go | 64 ++++++++++++++++++++----- connection/observer_test.go | 29 +++++++++-- metrics/metrics.go | 13 +++-- metrics/readiness.go | 40 ++++++++-------- metrics/readiness_test.go | 51 ++++++++++++++++++++ origin/tunnel.go | 12 ----- origin/tunnel_test.go | 2 +- 12 files changed, 201 insertions(+), 111 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index e214ab1e..66f576a3 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -328,13 +328,9 @@ func StartServer( transportLog := logger.CreateTransportLoggerFromContext(c, isUIEnabled) - readinessCh := make(chan connection.Event, 16) - uiCh := make(chan connection.Event, 16) - eventChannels := []chan connection.Event{ - readinessCh, - uiCh, - } - tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, transportLog, namedTunnel, isUIEnabled, eventChannels) + observer := connection.NewObserver(log, isUIEnabled) + + tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, observer, namedTunnel) if err != nil { log.Err(err).Msg("Couldn't start tunnel") return err @@ -349,7 +345,9 @@ func StartServer( wg.Add(1) go func() { defer wg.Done() - errC <- metrics.ServeMetrics(metricsListener, shutdownC, readinessCh, log) + readinessServer := metrics.NewReadyServer(log) + observer.RegisterSink(readinessServer) + errC <- metrics.ServeMetrics(metricsListener, shutdownC, readinessServer, log) }() if err := ingressRules.StartOrigins(&wg, log, shutdownC, errC); err != nil { @@ -369,20 +367,15 @@ func StartServer( }() if isUIEnabled { - tunnelInfo := ui.NewUIModel( + tunnelUI := ui.NewUIModel( version, hostname, metricsListener.Addr().String(), &ingressRules, tunnelConfig.HAConnections, ) - tunnelInfo.LaunchUI(ctx, log, transportLog, uiCh) - } else { - go func() { - for range uiCh { - // Consume UI events into a noop - } - }() + app := tunnelUI.Launch(ctx, log, transportLog) + observer.RegisterSink(app) } return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 43729380..8363814b 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -157,10 +157,8 @@ func prepareTunnelConfig( buildInfo *buildinfo.BuildInfo, version string, log *zerolog.Logger, - transportLogger *zerolog.Logger, + observer *connection.Observer, namedTunnel *connection.NamedTunnelConfig, - isUIEnabled bool, - eventChans []chan connection.Event, ) (*origin.TunnelConfig, ingress.Ingress, error) { isNamedTunnel := namedTunnel != nil @@ -281,7 +279,7 @@ func prepareTunnelConfig( LBPool: c.String("lb-pool"), Tags: tags, Log: log, - Observer: connection.NewObserver(transportLogger, eventChans, isUIEnabled), + Observer: observer, ReportedVersion: version, // Note TUN-3758 , we use Int because UInt is not supported with altsrc Retries: uint(c.Int("retries")), @@ -289,7 +287,6 @@ func prepareTunnelConfig( NamedTunnel: namedTunnel, ClassicTunnel: classicTunnel, MuxerConfig: muxerConfig, - TunnelEventChans: eventChans, ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, }, ingressRules, nil diff --git a/cmd/cloudflared/ui/launch_ui.go b/cmd/cloudflared/ui/launch_ui.go index 8c3bcf4f..5a56b22f 100644 --- a/cmd/cloudflared/ui/launch_ui.go +++ b/cmd/cloudflared/ui/launch_ui.go @@ -48,11 +48,10 @@ func NewUIModel(version, hostname, metricsURL string, ing *ingress.Ingress, haCo } } -func (data *uiModel) LaunchUI( +func (data *uiModel) Launch( ctx context.Context, log, transportLog *zerolog.Logger, - tunnelEventChan <-chan connection.Event, -) { +) connection.EventSink { // Configure the logger to stream logs into the textview // Add TextView as a group to write output to @@ -114,28 +113,9 @@ func (data *uiModel) LaunchUI( grid.AddItem(logFrame, 4, 0, 5, 2, 0, 0, false) go func() { - for { - select { - case <-ctx.Done(): - app.Stop() - return - case event := <-tunnelEventChan: - switch event.EventType { - case connection.Connected: - data.setConnTableCell(event, connTable, palette) - case connection.Disconnected, connection.Reconnecting: - data.changeConnStatus(event, connTable, log, palette) - case connection.SetURL: - tunnelHostText.SetText(event.URL) - data.edgeURL = event.URL - case connection.RegisteringTunnel: - if data.edgeURL == "" { - tunnelHostText.SetText("Registering tunnel...") - } - } - } - app.Draw() - } + <-ctx.Done() + app.Stop() + return }() go func() { @@ -143,6 +123,23 @@ func (data *uiModel) LaunchUI( log.Error().Msgf("Error launching UI: %s", err) } }() + + return connection.EventSinkFunc(func(event connection.Event) { + switch event.EventType { + case connection.Connected: + data.setConnTableCell(event, connTable, palette) + case connection.Disconnected, connection.Reconnecting: + data.changeConnStatus(event, connTable, log, palette) + case connection.SetURL: + tunnelHostText.SetText(event.URL) + data.edgeURL = event.URL + case connection.RegisteringTunnel: + if data.edgeURL == "" { + tunnelHostText.SetText("Registering tunnel...") + } + } + app.Draw() + }) } func NewDynamicColorTextView() *tview.TextView { diff --git a/connection/connection_test.go b/connection/connection_test.go index 33ec7433..f15f8c6a 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -27,13 +27,7 @@ var ( Scheme: "https", Host: "connectiontest.argotunnel.com", } - testTunnelEventChan = make(chan Event) - testObserver = &Observer{ - &log, - m, - []chan Event{testTunnelEventChan}, - false, - } + testObserver = NewObserver(&log, false) testLargeResp = make([]byte, largeFileSize) ) diff --git a/connection/metrics.go b/connection/metrics.go index 119f98b9..309da6ea 100644 --- a/connection/metrics.go +++ b/connection/metrics.go @@ -299,7 +299,7 @@ func convertRTTMilliSec(t time.Duration) float64 { } // Metrics that can be collected without asking the edge -func newTunnelMetrics() *tunnelMetrics { +func initTunnelMetrics() *tunnelMetrics { maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: MetricsNamespace, @@ -403,3 +403,15 @@ func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) { t.serverLocations.WithLabelValues(connectionID, loc).Inc() t.oldServerLocations[connectionID] = loc } + +var tunnelMetricsInternal struct { + sync.Once + metrics *tunnelMetrics +} + +func newTunnelMetrics() *tunnelMetrics { + tunnelMetricsInternal.Do(func() { + tunnelMetricsInternal.metrics = initTunnelMetrics() + }) + return tunnelMetricsInternal.metrics +} \ No newline at end of file diff --git a/connection/observer.go b/connection/observer.go index 8e4f8fdb..1b3433b3 100644 --- a/connection/observer.go +++ b/connection/observer.go @@ -10,22 +10,37 @@ import ( "github.com/rs/zerolog" ) -const LogFieldLocation = "location" +const ( + LogFieldLocation = "location" + observerChannelBufferSize = 16 +) type Observer struct { - log *zerolog.Logger - metrics *tunnelMetrics - tunnelEventChans []chan Event - uiEnabled bool + log *zerolog.Logger + metrics *tunnelMetrics + tunnelEventChan chan Event + uiEnabled bool + addSinkChan chan EventSink } -func NewObserver(log *zerolog.Logger, tunnelEventChans []chan Event, uiEnabled bool) *Observer { - return &Observer{ - log, - newTunnelMetrics(), - tunnelEventChans, - uiEnabled, +type EventSink interface { + OnTunnelEvent(event Event) +} + +func NewObserver(log *zerolog.Logger, uiEnabled bool) *Observer { + o := &Observer{ + log: log, + metrics: newTunnelMetrics(), + uiEnabled: uiEnabled, + tunnelEventChan: make(chan Event, observerChannelBufferSize), + addSinkChan: make(chan EventSink, observerChannelBufferSize), } + go o.dispatchEvents() + return o +} + +func (o *Observer) RegisterSink(sink EventSink) { + o.addSinkChan <- sink } func (o *Observer) logServerInfo(connIndex uint8, location, msg string) { @@ -105,7 +120,30 @@ func (o *Observer) SendDisconnect(connIndex uint8) { } func (o *Observer) sendEvent(e Event) { - for _, ch := range o.tunnelEventChans { - ch <- e + select { + case o.tunnelEventChan <- e: + break + default: + o.log.Warn().Msg("observer channel buffer is full") } } + +func (o *Observer) dispatchEvents() { + var sinks []EventSink + for { + select { + case sink := <-o.addSinkChan: + sinks = append(sinks, sink) + case evt := <-o.tunnelEventChan: + for _, sink := range sinks { + sink.OnTunnelEvent(evt) + } + } + } +} + +type EventSinkFunc func(event Event) + +func (f EventSinkFunc) OnTunnelEvent(event Event) { + f(event) +} diff --git a/connection/observer_test.go b/connection/observer_test.go index 6116cded..be20b068 100644 --- a/connection/observer_test.go +++ b/connection/observer_test.go @@ -4,14 +4,13 @@ import ( "strconv" "sync" "testing" + "time" "github.com/stretchr/testify/assert" ) -// can only be called once -var m = newTunnelMetrics() - func TestRegisterServerLocation(t *testing.T) { + m := newTunnelMetrics() tunnels := 20 var wg sync.WaitGroup wg.Add(tunnels) @@ -43,3 +42,27 @@ func TestRegisterServerLocation(t *testing.T) { } } + +func TestObserverEventsDontBlock(t *testing.T) { + observer := NewObserver(&log, false) + var mu sync.Mutex + observer.RegisterSink(EventSinkFunc(func(_ Event) { + // callback will block if lock is already held + mu.Lock() + mu.Unlock() + })) + + timeout := time.AfterFunc(5*time.Second, func() { + mu.Unlock() // release the callback on timer expiration + t.Fatal("observer is blocked") + }) + + mu.Lock() // block the callback + for i := 0; i < 2 * observerChannelBufferSize; i++ { + observer.sendRegisteringEvent() + } + if pending := timeout.Stop(); pending { + // release the callback if timer hasn't expired yet + mu.Unlock() + } +} diff --git a/metrics/metrics.go b/metrics/metrics.go index faae3792..e915de25 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -10,8 +10,6 @@ import ( "sync" "time" - "github.com/cloudflare/cloudflared/connection" - "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" @@ -23,21 +21,22 @@ const ( startupTime = time.Millisecond * 500 ) -func newMetricsHandler(connectionEvents <-chan connection.Event, log *zerolog.Logger) *http.ServeMux { - readyServer := NewReadyServer(connectionEvents, log) +func newMetricsHandler(readyServer *ReadyServer) *http.ServeMux { mux := http.NewServeMux() mux.Handle("/metrics", promhttp.Handler()) mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprintf(w, "OK\n") }) - mux.Handle("/ready", readyServer) + if readyServer != nil { + mux.Handle("/ready", readyServer) + } return mux } func ServeMetrics( l net.Listener, shutdownC <-chan struct{}, - connectionEvents <-chan connection.Event, + readyServer *ReadyServer, log *zerolog.Logger, ) (err error) { var wg sync.WaitGroup @@ -45,7 +44,7 @@ func ServeMetrics( trace.AuthRequest = func(*http.Request) (bool, bool) { return true, true } // TODO: parameterize ReadTimeout and WriteTimeout. The maximum time we can // profile CPU usage depends on WriteTimeout - h := newMetricsHandler(connectionEvents, log) + h := newMetricsHandler(readyServer) server := &http.Server{ ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, diff --git a/metrics/readiness.go b/metrics/readiness.go index 52826772..4cae7592 100644 --- a/metrics/readiness.go +++ b/metrics/readiness.go @@ -19,30 +19,28 @@ type ReadyServer struct { } // NewReadyServer initializes a ReadyServer and starts listening for dis/connection events. -func NewReadyServer(connectionEvents <-chan conn.Event, log *zerolog.Logger) *ReadyServer { - rs := ReadyServer{ +func NewReadyServer(log *zerolog.Logger) *ReadyServer { + return &ReadyServer{ isConnected: make(map[int]bool, 0), log: log, } - go func() { - for c := range connectionEvents { - switch c.EventType { - case conn.Connected: - rs.Lock() - rs.isConnected[int(c.Index)] = true - rs.Unlock() - case conn.Disconnected, conn.Reconnecting, conn.RegisteringTunnel: - rs.Lock() - rs.isConnected[int(c.Index)] = false - rs.Unlock() - case conn.SetURL: - continue - default: - rs.log.Error().Msgf("Unknown connection event case %v", c) - } - } - }() - return &rs +} + +func (rs *ReadyServer) OnTunnelEvent(c conn.Event) { + switch c.EventType { + case conn.Connected: + rs.Lock() + rs.isConnected[int(c.Index)] = true + rs.Unlock() + case conn.Disconnected, conn.Reconnecting, conn.RegisteringTunnel: + rs.Lock() + rs.isConnected[int(c.Index)] = false + rs.Unlock() + case conn.SetURL: + break + default: + rs.log.Error().Msgf("Unknown connection event case %v", c) + } } type body struct { diff --git a/metrics/readiness_test.go b/metrics/readiness_test.go index 7a00abb7..213df36c 100644 --- a/metrics/readiness_test.go +++ b/metrics/readiness_test.go @@ -3,6 +3,11 @@ package metrics import ( "net/http" "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + + "github.com/cloudflare/cloudflared/connection" ) func TestReadyServer_makeResponse(t *testing.T) { @@ -56,3 +61,49 @@ func TestReadyServer_makeResponse(t *testing.T) { }) } } + +func TestReadinessEventHandling(t *testing.T) { + nopLogger := zerolog.Nop() + rs := NewReadyServer(&nopLogger) + + // start not ok + code, ready := rs.makeResponse() + assert.NotEqualValues(t, http.StatusOK, code) + assert.Zero(t, ready) + + // one connected => ok + rs.OnTunnelEvent(connection.Event{ + Index: 1, + EventType: connection.Connected, + }) + code, ready = rs.makeResponse() + assert.EqualValues(t, http.StatusOK, code) + assert.EqualValues(t, 1, ready) + + // another connected => still ok + rs.OnTunnelEvent(connection.Event{ + Index: 2, + EventType: connection.Connected, + }) + code, ready = rs.makeResponse() + assert.EqualValues(t, http.StatusOK, code) + assert.EqualValues(t, 2, ready) + + // one reconnecting => still ok + rs.OnTunnelEvent(connection.Event{ + Index: 2, + EventType: connection.Reconnecting, + }) + code, ready = rs.makeResponse() + assert.EqualValues(t, http.StatusOK, code) + assert.EqualValues(t, 1, ready) + + // other disconnected => not ok + rs.OnTunnelEvent(connection.Event{ + Index: 1, + EventType: connection.Disconnected, + }) + code, ready = rs.makeResponse() + assert.NotEqualValues(t, http.StatusOK, code) + assert.Zero(t, ready) +} diff --git a/origin/tunnel.go b/origin/tunnel.go index ed5c604f..eb299112 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -12,7 +12,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" @@ -27,9 +26,7 @@ import ( const ( dialTimeout = 15 * time.Second - muxerTimeout = 5 * time.Second lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" - DuplicateConnectionError = "EDUPCONN" FeatureSerializedHeaders = "serialized_headers" FeatureQuickReconnects = "quick_reconnects" ) @@ -37,9 +34,7 @@ const ( type rpcName string const ( - register rpcName = "register" reconnect rpcName = "reconnect" - unregister rpcName = "unregister" authenticate rpcName = " authenticate" ) @@ -64,7 +59,6 @@ type TunnelConfig struct { NamedTunnel *connection.NamedTunnelConfig ClassicTunnel *connection.ClassicTunnelConfig MuxerConfig *connection.MuxerConfig - TunnelEventChans []chan connection.Event ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config } @@ -90,11 +84,6 @@ type clientRegisterTunnelError struct { cause error } -func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError { - counter.WithLabelValues(cause.Error(), string(name)).Inc() - return clientRegisterTunnelError{cause: cause} -} - func (e clientRegisterTunnelError) Error() string { return e.cause.Error() } @@ -466,5 +455,4 @@ func activeIncidentsMsg(incidents []Incident) string { incidentStrings = append(incidentStrings, incidentString) } return preamble + " " + strings.Join(incidentStrings, "; ") - } diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index f7090c8f..8f4c973c 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -53,7 +53,7 @@ func TestWaitForBackoffFallback(t *testing.T) { config := &TunnelConfig{ Log: &log, ProtocolSelector: protocolSelector, - Observer: connection.NewObserver(nil, nil, false), + Observer: connection.NewObserver(nil, false), } connIndex := uint8(1)