TUN-1562: Refactor connectedSignal to be safe to close multiple times

This commit is contained in:
Adam Chalmers 2019-03-04 13:48:56 -06:00
parent fea3569956
commit 073c5bfdaa
7 changed files with 85 additions and 25 deletions

1
.gitignore vendored
View File

@ -8,3 +8,4 @@ guide/public
.vscode
\#*\#
cscope.*
cloudflared

View File

@ -20,6 +20,7 @@ import (
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/websocket"
"github.com/coreos/go-systemd/daemon"
@ -180,8 +181,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
var wg sync.WaitGroup
listeners := gracenet.Net{}
errC := make(chan error)
connectedSignal := make(chan struct{})
closeConnOnce := sync.Once{}
connectedSignal := signal.New(make(chan struct{}))
dnsReadySignal := make(chan struct{})
if c.String("config") == "" {
@ -281,7 +281,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
// Serve DNS proxy stand-alone if no hostname or tag or app is going to run
if dnsProxyStandAlone(c) {
closeConnOnce.Do(func() { close(connectedSignal) })
connectedSignal.Notify()
// no grace period, handle SIGINT/SIGTERM immediately
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0)
}
@ -315,7 +315,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
c.Set("url", "http://"+listener.Addr().String())
}
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger, &closeConnOnce)
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger)
if err != nil {
return err
}
@ -375,13 +375,13 @@ func waitToShutdown(wg *sync.WaitGroup,
return err
}
func notifySystemd(waitForSignal chan struct{}) {
<-waitForSignal
func notifySystemd(waitForSignal *signal.Signal) {
<-waitForSignal.Wait()
daemon.SdNotify(false, "READY=1")
}
func writePidFile(waitForSignal chan struct{}, pidFile string) {
<-waitForSignal
func writePidFile(waitForSignal *signal.Signal, pidFile string) {
<-waitForSignal.Wait()
file, err := os.Create(pidFile)
if err != nil {
logger.WithError(err).Errorf("Unable to write pid to %s", pidFile)

View File

@ -13,7 +13,6 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
@ -142,7 +141,6 @@ func prepareTunnelConfig(
buildInfo *origin.BuildInfo,
version string, logger,
transportLogger *logrus.Logger,
closeConnOnce *sync.Once,
) (*origin.TunnelConfig, error) {
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
@ -238,7 +236,6 @@ func prepareTunnelConfig(
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
CompressionQuality: c.Uint64("compression-quality"),
IncidentLookup: origin.NewIncidentLookup(),
CloseConnOnce: closeConnOnce,
}, nil
}

View File

@ -6,6 +6,8 @@ import (
"net"
"time"
"github.com/cloudflare/cloudflared/signal"
"github.com/google/uuid"
)
@ -51,7 +53,7 @@ func NewSupervisor(config *TunnelConfig) *Supervisor {
}
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error {
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal, u); err != nil {
return err
@ -120,7 +122,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}, u u
}
}
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error {
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
logger := s.config.Logger
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
if err != nil {
@ -143,11 +145,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct
return fmt.Errorf("context was canceled")
case tunnelError := <-s.tunnelErrors:
return tunnelError.err
case <-connectedSignal:
case <-connectedSignal.Wait():
}
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
go s.startTunnel(ctx, i, make(chan struct{}), u)
ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch, u)
time.Sleep(registrationInterval)
}
return nil
@ -155,7 +158,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct
// startTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) {
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
defer func() {
s.tunnelErrors <- tunnelError{index: 0, err: err}
@ -183,17 +186,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}, u uuid.UUID) {
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
s.tunnelErrors <- tunnelError{index: index, err: err}
}
func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
signal := make(chan struct{})
s.tunnelsConnecting[index] = signal
s.nextConnectedSignal = signal
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
sig := make(chan struct{})
s.tunnelsConnecting[index] = sig
s.nextConnectedSignal = sig
s.nextConnectedIndex = index
return signal
return signal.New(sig)
}
func (s *Supervisor) waitForNextTunnel(index int) bool {

View File

@ -15,6 +15,7 @@ import (
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
@ -127,7 +128,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
}
}
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal *signal.Signal) error {
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-shutdownC
@ -155,7 +156,7 @@ func ServeTunnelLoop(ctx context.Context,
config *TunnelConfig,
addr *net.TCPAddr,
connectionID uint8,
connectedSignal chan struct{},
connectedSignal *signal.Signal,
u uuid.UUID,
) error {
logger := config.Logger
@ -165,7 +166,7 @@ func ServeTunnelLoop(ctx context.Context,
connectedFuse := h2mux.NewBooleanFuse()
go func() {
if connectedFuse.Await() {
config.CloseConnOnce.Do(func() { close(connectedSignal) })
connectedSignal.Notify()
}
}()
// Ensure the above goroutine will terminate if we return without connecting

33
signal/safe_signal.go Normal file
View File

@ -0,0 +1,33 @@
package signal
import (
"sync"
)
// Signal lets goroutines signal that some event has occurred. Other goroutines can wait for the signal.
type Signal struct {
ch chan struct{}
once sync.Once
}
// New wraps a channel and turns it into a signal for a one-time event.
func New(ch chan struct{}) *Signal {
return &Signal{
ch: ch,
once: sync.Once{},
}
}
// Notify alerts any goroutines waiting on this signal that the event has occurred.
// After the first call to Notify(), future calls are no-op.
func (s *Signal) Notify() {
s.once.Do(func() {
close(s.ch)
})
}
// Wait returns a channel which will be written to when Notify() is called for the first time.
// This channel will never be written to a second time.
func (s *Signal) Wait() <-chan struct{} {
return s.ch
}

View File

@ -0,0 +1,25 @@
package signal
import (
"testing"
)
func TestMultiNotifyDoesntCrash(t *testing.T) {
sig := New(make(chan struct{}))
sig.Notify()
sig.Notify()
// If code has reached here without crashing, the test has passed.
}
func TestWait(t *testing.T) {
sig := New(make(chan struct{}))
sig.Notify()
select {
case <-sig.Wait():
// Test succeeds
return
default:
// sig.Wait() should have been read from, because sig.Notify() wrote to it.
t.Fail()
}
}