TUN-3738: Refactor observer to avoid potential of blocking on tunnel notifications

This commit is contained in:
Igor Postelnik 2021-01-14 16:33:36 -06:00 committed by Arég Harutyunyan
parent 8c9d725eeb
commit 04b1e4f859
12 changed files with 201 additions and 111 deletions

View File

@ -328,13 +328,9 @@ func StartServer(
transportLog := logger.CreateTransportLoggerFromContext(c, isUIEnabled)
readinessCh := make(chan connection.Event, 16)
uiCh := make(chan connection.Event, 16)
eventChannels := []chan connection.Event{
readinessCh,
uiCh,
}
tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, transportLog, namedTunnel, isUIEnabled, eventChannels)
observer := connection.NewObserver(log, isUIEnabled)
tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, observer, namedTunnel)
if err != nil {
log.Err(err).Msg("Couldn't start tunnel")
return err
@ -349,7 +345,9 @@ func StartServer(
wg.Add(1)
go func() {
defer wg.Done()
errC <- metrics.ServeMetrics(metricsListener, shutdownC, readinessCh, log)
readinessServer := metrics.NewReadyServer(log)
observer.RegisterSink(readinessServer)
errC <- metrics.ServeMetrics(metricsListener, shutdownC, readinessServer, log)
}()
if err := ingressRules.StartOrigins(&wg, log, shutdownC, errC); err != nil {
@ -369,20 +367,15 @@ func StartServer(
}()
if isUIEnabled {
tunnelInfo := ui.NewUIModel(
tunnelUI := ui.NewUIModel(
version,
hostname,
metricsListener.Addr().String(),
&ingressRules,
tunnelConfig.HAConnections,
)
tunnelInfo.LaunchUI(ctx, log, transportLog, uiCh)
} else {
go func() {
for range uiCh {
// Consume UI events into a noop
}
}()
app := tunnelUI.Launch(ctx, log, transportLog)
observer.RegisterSink(app)
}
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log)

View File

@ -157,10 +157,8 @@ func prepareTunnelConfig(
buildInfo *buildinfo.BuildInfo,
version string,
log *zerolog.Logger,
transportLogger *zerolog.Logger,
observer *connection.Observer,
namedTunnel *connection.NamedTunnelConfig,
isUIEnabled bool,
eventChans []chan connection.Event,
) (*origin.TunnelConfig, ingress.Ingress, error) {
isNamedTunnel := namedTunnel != nil
@ -281,7 +279,7 @@ func prepareTunnelConfig(
LBPool: c.String("lb-pool"),
Tags: tags,
Log: log,
Observer: connection.NewObserver(transportLogger, eventChans, isUIEnabled),
Observer: observer,
ReportedVersion: version,
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
Retries: uint(c.Int("retries")),
@ -289,7 +287,6 @@ func prepareTunnelConfig(
NamedTunnel: namedTunnel,
ClassicTunnel: classicTunnel,
MuxerConfig: muxerConfig,
TunnelEventChans: eventChans,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
}, ingressRules, nil

View File

@ -48,11 +48,10 @@ func NewUIModel(version, hostname, metricsURL string, ing *ingress.Ingress, haCo
}
}
func (data *uiModel) LaunchUI(
func (data *uiModel) Launch(
ctx context.Context,
log, transportLog *zerolog.Logger,
tunnelEventChan <-chan connection.Event,
) {
) connection.EventSink {
// Configure the logger to stream logs into the textview
// Add TextView as a group to write output to
@ -114,28 +113,9 @@ func (data *uiModel) LaunchUI(
grid.AddItem(logFrame, 4, 0, 5, 2, 0, 0, false)
go func() {
for {
select {
case <-ctx.Done():
app.Stop()
return
case event := <-tunnelEventChan:
switch event.EventType {
case connection.Connected:
data.setConnTableCell(event, connTable, palette)
case connection.Disconnected, connection.Reconnecting:
data.changeConnStatus(event, connTable, log, palette)
case connection.SetURL:
tunnelHostText.SetText(event.URL)
data.edgeURL = event.URL
case connection.RegisteringTunnel:
if data.edgeURL == "" {
tunnelHostText.SetText("Registering tunnel...")
}
}
}
app.Draw()
}
<-ctx.Done()
app.Stop()
return
}()
go func() {
@ -143,6 +123,23 @@ func (data *uiModel) LaunchUI(
log.Error().Msgf("Error launching UI: %s", err)
}
}()
return connection.EventSinkFunc(func(event connection.Event) {
switch event.EventType {
case connection.Connected:
data.setConnTableCell(event, connTable, palette)
case connection.Disconnected, connection.Reconnecting:
data.changeConnStatus(event, connTable, log, palette)
case connection.SetURL:
tunnelHostText.SetText(event.URL)
data.edgeURL = event.URL
case connection.RegisteringTunnel:
if data.edgeURL == "" {
tunnelHostText.SetText("Registering tunnel...")
}
}
app.Draw()
})
}
func NewDynamicColorTextView() *tview.TextView {

View File

@ -27,13 +27,7 @@ var (
Scheme: "https",
Host: "connectiontest.argotunnel.com",
}
testTunnelEventChan = make(chan Event)
testObserver = &Observer{
&log,
m,
[]chan Event{testTunnelEventChan},
false,
}
testObserver = NewObserver(&log, false)
testLargeResp = make([]byte, largeFileSize)
)

View File

@ -299,7 +299,7 @@ func convertRTTMilliSec(t time.Duration) float64 {
}
// Metrics that can be collected without asking the edge
func newTunnelMetrics() *tunnelMetrics {
func initTunnelMetrics() *tunnelMetrics {
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
@ -403,3 +403,15 @@ func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
t.oldServerLocations[connectionID] = loc
}
// tunnelMetricsInternal guards one-time construction of the shared
// tunnelMetrics instance. NOTE(review): initTunnelMetrics builds
// Prometheus collectors, which presumably must be created/registered
// only once per process — hence the sync.Once; confirm against the
// registration site.
var tunnelMetricsInternal struct {
	sync.Once
	metrics *tunnelMetrics
}

// newTunnelMetrics returns the process-wide tunnelMetrics singleton,
// lazily building it on first call via initTunnelMetrics.
func newTunnelMetrics() *tunnelMetrics {
	tunnelMetricsInternal.Do(func() {
		tunnelMetricsInternal.metrics = initTunnelMetrics()
	})
	return tunnelMetricsInternal.metrics
}

View File

@ -10,22 +10,37 @@ import (
"github.com/rs/zerolog"
)
const (
	// LogFieldLocation is the structured-log field key for the edge
	// location name.
	LogFieldLocation = "location"

	// observerChannelBufferSize is the capacity of the observer's event
	// and sink-registration channels (see NewObserver).
	observerChannelBufferSize = 16
)
// Observer receives tunnel connection lifecycle notifications and fans
// them out to registered EventSinks through a buffered channel, so that
// emitting a notification never blocks the connection-handling code.
type Observer struct {
	log     *zerolog.Logger
	metrics *tunnelMetrics
	// tunnelEventChan carries events from sendEvent to the dispatch
	// goroutine started by NewObserver.
	tunnelEventChan chan Event
	uiEnabled       bool
	// addSinkChan delivers new sinks to the dispatch goroutine, which is
	// the only goroutine that touches the sink list.
	addSinkChan chan EventSink
}
func NewObserver(log *zerolog.Logger, tunnelEventChans []chan Event, uiEnabled bool) *Observer {
return &Observer{
log,
newTunnelMetrics(),
tunnelEventChans,
uiEnabled,
type EventSink interface {
OnTunnelEvent(event Event)
}
func NewObserver(log *zerolog.Logger, uiEnabled bool) *Observer {
o := &Observer{
log: log,
metrics: newTunnelMetrics(),
uiEnabled: uiEnabled,
tunnelEventChan: make(chan Event, observerChannelBufferSize),
addSinkChan: make(chan EventSink, observerChannelBufferSize),
}
go o.dispatchEvents()
return o
}
func (o *Observer) RegisterSink(sink EventSink) {
o.addSinkChan <- sink
}
func (o *Observer) logServerInfo(connIndex uint8, location, msg string) {
@ -105,7 +120,30 @@ func (o *Observer) SendDisconnect(connIndex uint8) {
}
func (o *Observer) sendEvent(e Event) {
for _, ch := range o.tunnelEventChans {
ch <- e
select {
case o.tunnelEventChan <- e:
break
default:
o.log.Warn().Msg("observer channel buffer is full")
}
}
// dispatchEvents is the observer's single dispatch loop: it owns the
// slice of registered sinks and delivers every received event to each
// sink in registration order. Only this goroutine touches the slice, so
// no locking is needed. The loop has no exit path and runs for the
// lifetime of the process.
func (o *Observer) dispatchEvents() {
	var registeredSinks []EventSink
	for {
		select {
		case newSink := <-o.addSinkChan:
			registeredSinks = append(registeredSinks, newSink)
		case event := <-o.tunnelEventChan:
			for _, s := range registeredSinks {
				s.OnTunnelEvent(event)
			}
		}
	}
}
// EventSinkFunc is an adapter allowing an ordinary function to be used
// as an EventSink, in the style of http.HandlerFunc.
type EventSinkFunc func(event Event)

// OnTunnelEvent calls f(event).
func (f EventSinkFunc) OnTunnelEvent(event Event) {
	f(event)
}

View File

@ -4,14 +4,13 @@ import (
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// can only be called once
var m = newTunnelMetrics()
func TestRegisterServerLocation(t *testing.T) {
m := newTunnelMetrics()
tunnels := 20
var wg sync.WaitGroup
wg.Add(tunnels)
@ -43,3 +42,27 @@ func TestRegisterServerLocation(t *testing.T) {
}
}
// TestObserverEventsDontBlock verifies that sending events never blocks
// the caller, even when the registered sink is stuck and the observer's
// channel buffer overflows.
func TestObserverEventsDontBlock(t *testing.T) {
	observer := NewObserver(&log, false)
	var mu sync.Mutex
	observer.RegisterSink(EventSinkFunc(func(_ Event) {
		// callback will block if lock is already held
		mu.Lock()
		mu.Unlock()
	}))

	timeout := time.AfterFunc(5*time.Second, func() {
		mu.Unlock() // release the callback on timer expiration
		// t.Fatal/FailNow must only be called from the test goroutine;
		// from this timer goroutine we may only use t.Error.
		t.Error("observer is blocked")
	})

	mu.Lock() // block the callback
	for i := 0; i < 2*observerChannelBufferSize; i++ {
		observer.sendRegisteringEvent()
	}
	if pending := timeout.Stop(); pending {
		// release the callback if timer hasn't expired yet
		mu.Unlock()
	}
}

View File

@ -10,8 +10,6 @@ import (
"sync"
"time"
"github.com/cloudflare/cloudflared/connection"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
@ -23,21 +21,22 @@ const (
startupTime = time.Millisecond * 500
)
func newMetricsHandler(connectionEvents <-chan connection.Event, log *zerolog.Logger) *http.ServeMux {
readyServer := NewReadyServer(connectionEvents, log)
func newMetricsHandler(readyServer *ReadyServer) *http.ServeMux {
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "OK\n")
})
mux.Handle("/ready", readyServer)
if readyServer != nil {
mux.Handle("/ready", readyServer)
}
return mux
}
func ServeMetrics(
l net.Listener,
shutdownC <-chan struct{},
connectionEvents <-chan connection.Event,
readyServer *ReadyServer,
log *zerolog.Logger,
) (err error) {
var wg sync.WaitGroup
@ -45,7 +44,7 @@ func ServeMetrics(
trace.AuthRequest = func(*http.Request) (bool, bool) { return true, true }
// TODO: parameterize ReadTimeout and WriteTimeout. The maximum time we can
// profile CPU usage depends on WriteTimeout
h := newMetricsHandler(connectionEvents, log)
h := newMetricsHandler(readyServer)
server := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,

View File

@ -19,30 +19,28 @@ type ReadyServer struct {
}
// NewReadyServer initializes a ReadyServer and starts listening for dis/connection events.
func NewReadyServer(connectionEvents <-chan conn.Event, log *zerolog.Logger) *ReadyServer {
rs := ReadyServer{
func NewReadyServer(log *zerolog.Logger) *ReadyServer {
return &ReadyServer{
isConnected: make(map[int]bool, 0),
log: log,
}
go func() {
for c := range connectionEvents {
switch c.EventType {
case conn.Connected:
rs.Lock()
rs.isConnected[int(c.Index)] = true
rs.Unlock()
case conn.Disconnected, conn.Reconnecting, conn.RegisteringTunnel:
rs.Lock()
rs.isConnected[int(c.Index)] = false
rs.Unlock()
case conn.SetURL:
continue
default:
rs.log.Error().Msgf("Unknown connection event case %v", c)
}
}
}()
return &rs
}
// OnTunnelEvent updates the per-connection readiness map from tunnel
// lifecycle events; it is the sink callback registered with the
// connection Observer.
func (rs *ReadyServer) OnTunnelEvent(c conn.Event) {
	switch c.EventType {
	case conn.Connected:
		rs.Lock()
		rs.isConnected[int(c.Index)] = true
		rs.Unlock()
	case conn.Disconnected, conn.Reconnecting, conn.RegisteringTunnel:
		rs.Lock()
		rs.isConnected[int(c.Index)] = false
		rs.Unlock()
	case conn.SetURL:
		// URL updates carry no readiness information; ignore.
		// (The bare `break` here was redundant — Go cases never fall
		// through.)
	default:
		rs.log.Error().Msgf("Unknown connection event case %v", c)
	}
}
type body struct {

View File

@ -3,6 +3,11 @@ package metrics
import (
"net/http"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/connection"
)
func TestReadyServer_makeResponse(t *testing.T) {
@ -56,3 +61,49 @@ func TestReadyServer_makeResponse(t *testing.T) {
})
}
}
// TestReadinessEventHandling walks the ReadyServer through a sequence of
// connect/reconnect/disconnect events and checks the reported readiness
// after each step.
func TestReadinessEventHandling(t *testing.T) {
	nopLogger := zerolog.Nop()
	rs := NewReadyServer(&nopLogger)

	// checkState asserts on the HTTP status code and the number of ready
	// connections reported by makeResponse.
	checkState := func(wantOK bool, wantReady int) {
		code, ready := rs.makeResponse()
		if wantOK {
			assert.EqualValues(t, http.StatusOK, code)
		} else {
			assert.NotEqualValues(t, http.StatusOK, code)
		}
		assert.EqualValues(t, wantReady, ready)
	}

	// No events seen yet: not ready.
	checkState(false, 0)

	// First connection comes up: ready.
	rs.OnTunnelEvent(connection.Event{Index: 1, EventType: connection.Connected})
	checkState(true, 1)

	// Second connection comes up: still ready.
	rs.OnTunnelEvent(connection.Event{Index: 2, EventType: connection.Connected})
	checkState(true, 2)

	// One connection reconnecting: the other keeps us ready.
	rs.OnTunnelEvent(connection.Event{Index: 2, EventType: connection.Reconnecting})
	checkState(true, 1)

	// Last healthy connection drops: not ready.
	rs.OnTunnelEvent(connection.Event{Index: 1, EventType: connection.Disconnected})
	checkState(false, 0)
}

View File

@ -12,7 +12,6 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
@ -27,9 +26,7 @@ import (
const (
dialTimeout = 15 * time.Second
muxerTimeout = 5 * time.Second
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
DuplicateConnectionError = "EDUPCONN"
FeatureSerializedHeaders = "serialized_headers"
FeatureQuickReconnects = "quick_reconnects"
)
@ -37,9 +34,7 @@ const (
type rpcName string
const (
register rpcName = "register"
reconnect rpcName = "reconnect"
unregister rpcName = "unregister"
authenticate rpcName = " authenticate"
)
@ -64,7 +59,6 @@ type TunnelConfig struct {
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
TunnelEventChans []chan connection.Event
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
}
@ -90,11 +84,6 @@ type clientRegisterTunnelError struct {
cause error
}
func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
counter.WithLabelValues(cause.Error(), string(name)).Inc()
return clientRegisterTunnelError{cause: cause}
}
func (e clientRegisterTunnelError) Error() string {
return e.cause.Error()
}
@ -466,5 +455,4 @@ func activeIncidentsMsg(incidents []Incident) string {
incidentStrings = append(incidentStrings, incidentString)
}
return preamble + " " + strings.Join(incidentStrings, "; ")
}

View File

@ -53,7 +53,7 @@ func TestWaitForBackoffFallback(t *testing.T) {
config := &TunnelConfig{
Log: &log,
ProtocolSelector: protocolSelector,
Observer: connection.NewObserver(nil, nil, false),
Observer: connection.NewObserver(nil, false),
}
connIndex := uint8(1)