From 38fb0b28b6ff9ec662a3b262bae8d7d343022365 Mon Sep 17 00:00:00 2001 From: Adam Chalmers Date: Mon, 30 Nov 2020 14:05:37 -0600 Subject: [PATCH] TUN-3593: /ready endpoint for k8s readiness. Move tunnel events out of UI package, into connection package. --- cmd/cloudflared/tunnel/cmd.go | 34 ++++++----- cmd/cloudflared/tunnel/configuration.go | 13 ++-- cmd/cloudflared/ui/launch_ui.go | 56 +++++++---------- connection/connection_test.go | 6 +- connection/event.go | 25 ++++++++ connection/observer.go | 42 +++++++------ metrics/metrics.go | 26 ++++++-- metrics/readiness.go | 80 +++++++++++++++++++++++++ metrics/readiness_test.go | 58 ++++++++++++++++++ origin/tunnel.go | 15 +---- origin/tunnel_test.go | 1 + tunneldns/tunnel.go | 2 +- 12 files changed, 259 insertions(+), 99 deletions(-) create mode 100644 connection/event.go create mode 100644 metrics/readiness.go create mode 100644 metrics/readiness_test.go diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 58f963a8..aeb80747 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -271,18 +271,6 @@ func StartServer( // Wait for proxy-dns to come up (if used) <-dnsReadySignal - metricsListener, err := listeners.Listen("tcp", c.String("metrics")) - if err != nil { - generalLogger.Errorf("Error opening metrics server listener: %s", err) - return errors.Wrap(err, "Error opening metrics server listener") - } - defer metricsListener.Close() - wg.Add(1) - go func() { - defer wg.Done() - errC <- metrics.ServeMetrics(metricsListener, shutdownC, generalLogger) - }() - go notifySystemd(connectedSignal) if c.IsSet("pidfile") { go writePidFile(connectedSignal, c.String("pidfile"), generalLogger) @@ -331,12 +319,30 @@ func StartServer( return errors.Wrap(err, "error setting up transport logger") } - tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, generalLogger, transportLogger, namedTunnel, isUIEnabled) + readinessCh := make(chan connection.Event, 16) + uiCh := make(chan connection.Event, 16) + eventChannels := []chan connection.Event{ + readinessCh, + uiCh, + } + tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, generalLogger, transportLogger, namedTunnel, isUIEnabled, eventChannels) if err != nil { generalLogger.Errorf("Couldn't start tunnel: %v", err) return err } + metricsListener, err := listeners.Listen("tcp", c.String("metrics")) + if err != nil { + generalLogger.Errorf("Error opening metrics server listener: %s", err) + return errors.Wrap(err, "Error opening metrics server listener") + } + defer metricsListener.Close() + wg.Add(1) + go func() { + defer wg.Done() + errC <- metrics.ServeMetrics(metricsListener, shutdownC, readinessCh, generalLogger) + }() + ingressRules.StartOrigins(&wg, generalLogger, shutdownC, errC) reconnectCh := make(chan origin.ReconnectSignal, 1) @@ -363,7 +369,7 @@ func StartServer( if err != nil { return err } - tunnelInfo.LaunchUI(ctx, generalLogger, transportLogger, logLevels, tunnelConfig.TunnelEventChan) + tunnelInfo.LaunchUI(ctx, generalLogger, transportLogger, logLevels, uiCh) } return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), generalLogger) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index fb57abb7..df81ce85 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -10,7 +10,6 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" - "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" @@ -160,7 +159,8 @@ func prepareTunnelConfig( logger logger.Service, transportLogger logger.Service, namedTunnel *connection.NamedTunnelConfig, - uiIsEnabled bool, + isUIEnabled bool, + eventChans []chan connection.Event, ) (*origin.TunnelConfig, ingress.Ingress, error) { isNamedTunnel := namedTunnel != nil @@ -261,11 +261,6 @@ func prepareTunnelConfig( MetricsUpdateFreq: c.Duration("metrics-update-freq"), } - var tunnelEventChan chan ui.TunnelEvent - if uiIsEnabled { - tunnelEventChan = make(chan ui.TunnelEvent, 16) - } - return &origin.TunnelConfig{ ConnectionConfig: connectionConfig, BuildInfo: buildInfo, @@ -278,14 +273,14 @@ func prepareTunnelConfig( LBPool: c.String("lb-pool"), Tags: tags, Logger: logger, - Observer: connection.NewObserver(transportLogger, tunnelEventChan), + Observer: connection.NewObserver(transportLogger, eventChans, isUIEnabled), ReportedVersion: version, Retries: c.Uint("retries"), RunFromTerminal: isRunningFromTerminal(), NamedTunnel: namedTunnel, ClassicTunnel: classicTunnel, MuxerConfig: muxerConfig, - TunnelEventChan: tunnelEventChan, + TunnelEventChans: eventChans, ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, }, ingressRules, nil diff --git a/cmd/cloudflared/ui/launch_ui.go b/cmd/cloudflared/ui/launch_ui.go index 9f830050..3869559b 100644 --- a/cmd/cloudflared/ui/launch_ui.go +++ b/cmd/cloudflared/ui/launch_ui.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" @@ -15,24 +16,7 @@ import ( type connState struct { location string - state status -} - -type status int - -const ( - Disconnected status = iota - Connected - Reconnecting - SetUrl - RegisteringTunnel -) - -type TunnelEvent struct { - Index uint8 - EventType status - Location string - Url string + state connection.Status } type uiModel struct { @@ -69,7 +53,7 @@ func (data *uiModel) LaunchUI( ctx context.Context, generalLogger, transportLogger logger.Service, logLevels []logger.Level, - tunnelEventChan <-chan TunnelEvent, + tunnelEventChan <-chan connection.Event, ) { // Configure the logger to stream logs into the textview @@ -138,14 +122,14 @@ func (data *uiModel) LaunchUI( return case event := <-tunnelEventChan: switch event.EventType { - case Connected: + case connection.Connected: data.setConnTableCell(event, connTable, palette) - case Disconnected, Reconnecting: + case connection.Disconnected, connection.Reconnecting: data.changeConnStatus(event, connTable, generalLogger, palette) - case SetUrl: - tunnelHostText.SetText(event.Url) - data.edgeURL = event.Url - case RegisteringTunnel: + case connection.SetURL: + tunnelHostText.SetText(event.URL) + data.edgeURL = event.URL + case connection.RegisteringTunnel: if data.edgeURL == "" { tunnelHostText.SetText("Registering tunnel...") } @@ -175,7 +159,7 @@ func handleNewText(app *tview.Application, logTextView *tview.TextView) func() { } } -func (data *uiModel) changeConnStatus(event TunnelEvent, table *tview.Table, logger logger.Service, palette palette) { +func (data *uiModel) changeConnStatus(event connection.Event, table *tview.Table, logger logger.Service, palette palette) { index := int(event.Index) // Get connection location and state connState := data.getConnState(index) @@ -187,10 +171,10 @@ func (data *uiModel) changeConnStatus(event TunnelEvent, table *tview.Table, log locationState := event.Location - if event.EventType == Disconnected { - connState.state = Disconnected - } else if event.EventType == Reconnecting { - connState.state = Reconnecting + if event.EventType == connection.Disconnected { + connState.state = connection.Disconnected + } else if event.EventType == connection.Reconnecting { + connState.state = connection.Reconnecting locationState = "Reconnecting..." } @@ -211,12 +195,12 @@ func (data *uiModel) getConnState(connID int) *connState { return nil } -func (data *uiModel) setConnTableCell(event TunnelEvent, table *tview.Table, palette palette) { +func (data *uiModel) setConnTableCell(event connection.Event, table *tview.Table, palette palette) { index := int(event.Index) connectionNum := index + 1 // Update slice to keep track of connection location and state in UI table - data.connections[index].state = Connected + data.connections[index].state = connection.Connected data.connections[index].location = event.Location // Update text in table cell to show disconnected state @@ -225,18 +209,18 @@ func (data *uiModel) setConnTableCell(event TunnelEvent, table *tview.Table, pal table.SetCell(index, 0, cell) } -func newCellText(palette palette, connectionNum int, location string, connectedStatus status) string { +func newCellText(palette palette, connectionNum int, location string, connectedStatus connection.Status) string { // HA connection indicator formatted as: "• #: ", // where the left middle dot's color depends on the status of the connection const connFmtString = "[%s]\u2022[%s] #%d: %s" var dotColor string switch connectedStatus { - case Connected: + case connection.Connected: dotColor = palette.connected - case Disconnected: + case connection.Disconnected: dotColor = palette.disconnected - case Reconnecting: + case connection.Reconnecting: dotColor = palette.reconnecting } diff --git a/connection/connection_test.go b/connection/connection_test.go index 07ea8a30..6c55ed7e 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/logger" "github.com/gobwas/ws/wsutil" "github.com/stretchr/testify/assert" @@ -28,11 +27,12 @@ var ( Scheme: "https", Host: "connectiontest.argotunnel.com", } - testTunnelEventChan = make(chan ui.TunnelEvent) + testTunnelEventChan = make(chan Event) testObserver = &Observer{ testLogger, m, - testTunnelEventChan, + []chan Event{testTunnelEventChan}, + false, } testLargeResp = make([]byte, largeFileSize) ) diff --git a/connection/event.go b/connection/event.go new file mode 100644 index 00000000..64218f91 --- /dev/null +++ b/connection/event.go @@ -0,0 +1,25 @@ +package connection + +// Event is something that happened to a connection, e.g. disconnection or registration. +type Event struct { + Index uint8 + EventType Status + Location string + URL string +} + +// Status is the status of a connection. +type Status int + +const ( + // Disconnected means the connection to the edge was broken. + Disconnected Status = iota + // Connected means the connection to the edge was successfully established. + Connected + // Reconnecting means the connection to the edge is being re-established. + Reconnecting + // SetURL means this connection's tunnel was given a URL by the edge. Used for free tunnels. + SetURL + // RegisteringTunnel means the non-named tunnel is registering its connection. + RegisteringTunnel +) diff --git a/connection/observer.go b/connection/observer.go index 2c18f284..b6d9aeaa 100644 --- a/connection/observer.go +++ b/connection/observer.go @@ -5,37 +5,35 @@ import ( "net/url" "strings" - "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/logger" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) type Observer struct { logger.Service - metrics *tunnelMetrics - tunnelEventChan chan<- ui.TunnelEvent + metrics *tunnelMetrics + tunnelEventChans []chan Event + uiEnabled bool } -func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent) *Observer { +func NewObserver(logger logger.Service, tunnelEventChans []chan Event, uiEnabled bool) *Observer { return &Observer{ logger, newTunnelMetrics(), - tunnelEventChan, + tunnelEventChans, + uiEnabled, } } func (o *Observer) logServerInfo(connIndex uint8, location, msg string) { - // If launch-ui flag is set, send connect msg - if o.tunnelEventChan != nil { - o.tunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Connected, Location: location} - } + o.sendEvent(Event{Index: connIndex, EventType: Connected, Location: location}) o.Infof(msg) o.metrics.registerServerLocation(uint8ToString(connIndex), location) } func (o *Observer) logTrialHostname(registration *tunnelpogs.TunnelRegistration) error { // Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set) - if o.tunnelEventChan == nil { + if !o.uiEnabled { if registrationURL, err := url.Parse(registration.Url); err == nil { for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) { o.Info(line) @@ -81,19 +79,27 @@ func trialZoneMsg(url string) []string { } func (o *Observer) sendRegisteringEvent() { - if o.tunnelEventChan != nil { - o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel} - } + o.sendEvent(Event{EventType: RegisteringTunnel}) } func (o *Observer) sendConnectedEvent(connIndex uint8, location string) { - if o.tunnelEventChan != nil { - o.tunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Connected, Location: location} - } + o.sendEvent(Event{Index: connIndex, EventType: Connected, Location: location}) } func (o *Observer) sendURL(url string) { - if o.tunnelEventChan != nil { - o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: url} + o.sendEvent(Event{EventType: SetURL, URL: url}) +} + +func (o *Observer) SendReconnect(connIndex uint8) { + o.sendEvent(Event{Index: connIndex, EventType: Reconnecting}) +} + +func (o *Observer) SendDisconnect(connIndex uint8) { + o.sendEvent(Event{Index: connIndex, EventType: Disconnected}) +} + +func (o *Observer) sendEvent(e Event) { + for _, ch := range o.tunnelEventChans { + ch <- e } } diff --git a/metrics/metrics.go b/metrics/metrics.go index c28d6013..9784c66a 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -12,6 +12,7 @@ import ( "golang.org/x/net/trace" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/logger" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -22,22 +23,35 @@ const ( startupTime = time.Millisecond * 500 ) -func ServeMetrics(l net.Listener, shutdownC <-chan struct{}, logger logger.Service) (err error) { +func newMetricsHandler(connectionEvents <-chan connection.Event, log logger.Service) *http.ServeMux { + readyServer := NewReadyServer(connectionEvents, log) + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.Handler()) + mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "OK\n") + }) + mux.Handle("/ready", readyServer) + return mux +} + +func ServeMetrics( + l net.Listener, + shutdownC <-chan struct{}, + connectionEvents <-chan connection.Event, + logger logger.Service, +) (err error) { var wg sync.WaitGroup // Metrics port is privileged, so no need for further access control trace.AuthRequest = func(*http.Request) (bool, bool) { return true, true } // TODO: parameterize ReadTimeout and WriteTimeout. The maximum time we can // profile CPU usage depends on WriteTimeout + h := newMetricsHandler(connectionEvents, logger) server := &http.Server{ ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, + Handler: h, } - http.Handle("/metrics", promhttp.Handler()) - http.Handle("/healthcheck", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "OK\n") - })) - wg.Add(1) go func() { defer wg.Done() diff --git a/metrics/readiness.go b/metrics/readiness.go new file mode 100644 index 00000000..856e3e1d --- /dev/null +++ b/metrics/readiness.go @@ -0,0 +1,80 @@ +package metrics + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" + + conn "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/logger" +) + +// ReadyServer serves HTTP 200 if the tunnel can serve traffic. Intended for k8s readiness checks. +type ReadyServer struct { + sync.RWMutex + isConnected map[int]bool + log logger.Service +} + +// NewReadyServer initializes a ReadyServer and starts listening for dis/connection events. +func NewReadyServer(connectionEvents <-chan conn.Event, log logger.Service) *ReadyServer { + rs := ReadyServer{ + isConnected: make(map[int]bool, 0), + log: log, + } + go func() { + for c := range connectionEvents { + switch c.EventType { + case conn.Connected: + rs.Lock() + rs.isConnected[int(c.Index)] = true + rs.Unlock() + case conn.Disconnected, conn.Reconnecting, conn.RegisteringTunnel: + rs.Lock() + rs.isConnected[int(c.Index)] = false + rs.Unlock() + case conn.SetURL: + continue + default: + rs.log.Errorf("Unknown connection event case %v", c) + } + } + }() + return &rs +} + +type body struct { + Status int `json:"status"` + ReadyConnections int `json:"readyConnections"` +} + +// ServeHTTP responds with HTTP 200 if the tunnel is connected to the edge. +func (rs *ReadyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + statusCode, readyConnections := rs.makeResponse() + w.WriteHeader(statusCode) + body := body{ + Status: statusCode, + ReadyConnections: readyConnections, + } + msg, err := json.Marshal(body) + if err != nil { + fmt.Fprintf(w, `{"error": "%s"}`, err) + } + w.Write(msg) +} + +// This is the bulk of the logic for ServeHTTP, broken into its own pure function +// to make unit testing easy. +func (rs *ReadyServer) makeResponse() (statusCode, readyConnections int) { + statusCode = http.StatusServiceUnavailable + rs.RLock() + defer rs.RUnlock() + for _, connected := range rs.isConnected { + if connected { + statusCode = http.StatusOK + readyConnections++ + } + } + return statusCode, readyConnections +} diff --git a/metrics/readiness_test.go b/metrics/readiness_test.go new file mode 100644 index 00000000..7a00abb7 --- /dev/null +++ b/metrics/readiness_test.go @@ -0,0 +1,58 @@ +package metrics + +import ( + "net/http" + "testing" +) + +func TestReadyServer_makeResponse(t *testing.T) { + type fields struct { + isConnected map[int]bool + } + tests := []struct { + name string + fields fields + wantOK bool + wantReadyConnections int + }{ + { + name: "One connection online => HTTP 200", + fields: fields{ + isConnected: map[int]bool{ + 0: false, + 1: false, + 2: true, + 3: false, + }, + }, + wantOK: true, + wantReadyConnections: 1, + }, + { + name: "No connections online => no HTTP 200", + fields: fields{ + isConnected: map[int]bool{ + 0: false, + 1: false, + 2: false, + 3: false, + }, + }, + wantReadyConnections: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := &ReadyServer{ + isConnected: tt.fields.isConnected, + } + gotStatusCode, gotReadyConnections := rs.makeResponse() + if tt.wantOK && gotStatusCode != http.StatusOK { + t.Errorf("ReadyServer.makeResponse() gotStatusCode = %v, want ok = %v", gotStatusCode, tt.wantOK) + } + if gotReadyConnections != tt.wantReadyConnections { + t.Errorf("ReadyServer.makeResponse() gotReadyConnections = %v, want %v", gotReadyConnections, tt.wantReadyConnections) + } + }) + } +} diff --git a/origin/tunnel.go b/origin/tunnel.go index 211ab8e4..f175d978 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -16,7 +16,6 @@ import ( "golang.org/x/sync/errgroup" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" - "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" @@ -65,7 +64,7 @@ type TunnelConfig struct { NamedTunnel *connection.NamedTunnelConfig ClassicTunnel *connection.ClassicTunnelConfig MuxerConfig *connection.MuxerConfig - TunnelEventChan chan ui.TunnelEvent + TunnelEventChans []chan connection.Event ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config } @@ -235,10 +234,7 @@ func waitForBackoff( return err } - if config.TunnelEventChan != nil { - config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting} - } - + config.Observer.SendReconnect(connIndex) config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err) protobackoff.Backoff(ctx) @@ -288,12 +284,7 @@ func ServeTunnel( } }() - // If launch-ui flag is set, send disconnect msg - if config.TunnelEventChan != nil { - defer func() { - config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected} - }() - } + defer config.Observer.SendDisconnect(connIndex) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr) if err != nil { diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index bdaf4430..390ef1c6 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -47,6 +47,7 @@ func TestWaitForBackoffFallback(t *testing.T) { config := &TunnelConfig{ Logger: logger, ProtocolSelector: protocolSelector, + Observer: connection.NewObserver(nil, nil, false), } connIndex := uint8(1) diff --git a/tunneldns/tunnel.go b/tunneldns/tunnel.go index f7bd6971..d18b96ab 100644 --- a/tunneldns/tunnel.go +++ b/tunneldns/tunnel.go @@ -80,7 +80,7 @@ func Run(c *cli.Context) error { logger.Fatalf("Failed to open the metrics listener: %s", err) } - go metrics.ServeMetrics(metricsListener, nil, logger) + go metrics.ServeMetrics(metricsListener, nil, nil, logger) listener, err := CreateListener(c.String("address"), uint16(c.Uint("port")), c.StringSlice("upstream"), c.StringSlice("bootstrap"), logger) if err != nil {