diff --git a/.gitignore b/.gitignore index ea099fb6..cc85adaa 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ guide/public .vscode \#*\# cscope.* +cloudflared diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 46468567..472b3dd9 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -20,6 +20,7 @@ import ( "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunneldns" "github.com/cloudflare/cloudflared/websocket" "github.com/coreos/go-systemd/daemon" @@ -180,8 +181,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan var wg sync.WaitGroup listeners := gracenet.Net{} errC := make(chan error) - connectedSignal := make(chan struct{}) - closeConnOnce := sync.Once{} + connectedSignal := signal.New(make(chan struct{})) dnsReadySignal := make(chan struct{}) if c.String("config") == "" { @@ -281,7 +281,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan // Serve DNS proxy stand-alone if no hostname or tag or app is going to run if dnsProxyStandAlone(c) { - closeConnOnce.Do(func() { close(connectedSignal) }) + connectedSignal.Notify() // no grace period, handle SIGINT/SIGTERM immediately return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0) } @@ -315,7 +315,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan c.Set("url", "http://"+listener.Addr().String()) } - tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger, &closeConnOnce) + tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger) if err != nil { return err } @@ -375,13 +375,13 @@ func waitToShutdown(wg *sync.WaitGroup, return err } -func notifySystemd(waitForSignal chan struct{}) { - <-waitForSignal +func notifySystemd(waitForSignal *signal.Signal) { + <-waitForSignal.Wait() daemon.SdNotify(false, "READY=1") } -func writePidFile(waitForSignal chan struct{}, pidFile string) { - <-waitForSignal +func writePidFile(waitForSignal *signal.Signal, pidFile string) { + <-waitForSignal.Wait() file, err := os.Create(pidFile) if err != nil { logger.WithError(err).Errorf("Unable to write pid to %s", pidFile) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 0d3510d8..fd900bac 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -13,7 +13,6 @@ import ( "path/filepath" "runtime" "strings" - "sync" "time" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" @@ -142,7 +141,6 @@ func prepareTunnelConfig( buildInfo *origin.BuildInfo, version string, logger, transportLogger *logrus.Logger, - closeConnOnce *sync.Once, ) (*origin.TunnelConfig, error) { hostname, err := validation.ValidateHostname(c.String("hostname")) if err != nil { @@ -238,7 +236,6 @@ func prepareTunnelConfig( NoChunkedEncoding: c.Bool("no-chunked-encoding"), CompressionQuality: c.Uint64("compression-quality"), IncidentLookup: origin.NewIncidentLookup(), - CloseConnOnce: closeConnOnce, }, nil } diff --git a/origin/supervisor.go b/origin/supervisor.go index 36960b19..5f60a392 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -6,6 +6,8 @@ import ( "net" "time" + "github.com/cloudflare/cloudflared/signal" + "github.com/google/uuid" ) @@ -51,7 +53,7 @@ func NewSupervisor(config *TunnelConfig) *Supervisor { } } -func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error { +func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { logger := s.config.Logger if err := s.initialize(ctx, connectedSignal, u); err != nil { return err @@ -120,7 +122,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}, u u } } -func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error { +func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { logger := s.config.Logger edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) if err != nil { @@ -143,11 +145,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct return fmt.Errorf("context was canceled") case tunnelError := <-s.tunnelErrors: return tunnelError.err - case <-connectedSignal: + case <-connectedSignal.Wait(): } // At least one successful connection, so start the rest for i := 1; i < s.config.HAConnections; i++ { - go s.startTunnel(ctx, i, make(chan struct{}), u) + ch := signal.New(make(chan struct{})) + go s.startTunnel(ctx, i, ch, u) time.Sleep(registrationInterval) } return nil @@ -155,7 +158,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct // startTunnel starts the first tunnel connection. The resulting error will be sent on // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed -func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) { +func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) { err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) defer func() { s.tunnelErrors <- tunnelError{index: 0, err: err} @@ -183,17 +186,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. -func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}, u uuid.UUID) { +func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) { err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u) s.tunnelErrors <- tunnelError{index: index, err: err} } -func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} { - signal := make(chan struct{}) - s.tunnelsConnecting[index] = signal - s.nextConnectedSignal = signal +func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { + sig := make(chan struct{}) + s.tunnelsConnecting[index] = sig + s.nextConnectedSignal = sig s.nextConnectedIndex = index - return signal + return signal.New(sig) } func (s *Supervisor) waitForNextTunnel(index int) bool { diff --git a/origin/tunnel.go b/origin/tunnel.go index ea373e0d..fa022f00 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -15,6 +15,7 @@ import ( "time" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/validation" @@ -127,7 +128,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str } } -func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error { +func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal *signal.Signal) error { ctx, cancel := context.WithCancel(context.Background()) go func() { <-shutdownC @@ -155,7 +156,7 @@ func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAddr, connectionID uint8, - connectedSignal chan struct{}, + connectedSignal *signal.Signal, u uuid.UUID, ) error { logger := config.Logger @@ -165,7 +166,7 @@ func ServeTunnelLoop(ctx context.Context, connectedFuse := h2mux.NewBooleanFuse() go func() { if connectedFuse.Await() { - config.CloseConnOnce.Do(func() { close(connectedSignal) }) + connectedSignal.Notify() } }() // Ensure the above goroutine will terminate if we return without connecting diff --git a/signal/safe_signal.go b/signal/safe_signal.go new file mode 100644 index 00000000..6ab3930c --- /dev/null +++ b/signal/safe_signal.go @@ -0,0 +1,33 @@ +package signal + +import ( + "sync" +) + +// Signal lets goroutines signal that some event has occurred. Other goroutines can wait for the signal. +type Signal struct { + ch chan struct{} + once sync.Once +} + +// New wraps a channel and turns it into a signal for a one-time event. +func New(ch chan struct{}) *Signal { + return &Signal{ + ch: ch, + once: sync.Once{}, + } +} + +// Notify alerts any goroutines waiting on this signal that the event has occurred. +// After the first call to Notify(), future calls are no-op. +func (s *Signal) Notify() { + s.once.Do(func() { + close(s.ch) + }) +} + +// Wait returns a channel which will be written to when Notify() is called for the first time. +// This channel will never be written to a second time. +func (s *Signal) Wait() <-chan struct{} { + return s.ch +} diff --git a/signal/safe_signal_test.go b/signal/safe_signal_test.go new file mode 100644 index 00000000..5ad96223 --- /dev/null +++ b/signal/safe_signal_test.go @@ -0,0 +1,25 @@ +package signal + +import ( + "testing" +) + +func TestMultiNotifyDoesntCrash(t *testing.T) { + sig := New(make(chan struct{})) + sig.Notify() + sig.Notify() + // If code has reached here without crashing, the test has passed. +} + +func TestWait(t *testing.T) { + sig := New(make(chan struct{})) + sig.Notify() + select { + case <-sig.Wait(): + // Test succeeds + return + default: + // sig.Wait() should have been read from, because sig.Notify() wrote to it. + t.Fail() + } +}