diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 9e940126..e2a6488c 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "runtime/trace" + "strings" "sync" "time" @@ -421,7 +422,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan return err } - reconnectCh := make(chan struct{}, 1) + reconnectCh := make(chan origin.ReconnectSignal, 1) if c.IsSet("stdin-control") { logger.Warn("Enabling control through stdin") go stdinControl(reconnectCh) @@ -1112,17 +1113,34 @@ func tunnelFlags(shouldHide bool) []cli.Flag { } } -func stdinControl(reconnectCh chan struct{}) { +func stdinControl(reconnectCh chan origin.ReconnectSignal) { for { scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { command := scanner.Text() + parts := strings.SplitN(command, " ", 2) - switch command { + switch parts[0] { + case "": + break case "reconnect": - reconnectCh <- struct{}{} + var reconnect origin.ReconnectSignal + if len(parts) > 1 { + var err error + if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil { + logger.Error(err.Error()) + continue + } + } + logger.Infof("Sending reconnect signal %+v", reconnect) + reconnectCh <- reconnect default: logger.Warn("Unknown command: ", command) + fallthrough + case "help": + logger.Info(`Supported command: +reconnect [delay] +- restarts one randomly chosen connection with optional delay before reconnect`) } } } diff --git a/origin/external_control.go b/origin/external_control.go new file mode 100644 index 00000000..d59759ed --- /dev/null +++ b/origin/external_control.go @@ -0,0 +1,21 @@ +package origin + +import ( + "time" +) + +type ReconnectSignal struct { + // wait this many seconds before re-establish the connection + Delay time.Duration +} + +// Error allows us to use ReconnectSignal as a special error to force connection abort +func (r *ReconnectSignal) Error() string { + return "reconnect signal" +} + +func (r *ReconnectSignal) DelayBeforeReconnect() { + if r.Delay > 0 { + time.Sleep(r.Delay) + } +} diff --git a/origin/supervisor.go b/origin/supervisor.go index 37bb0577..cd21b57f 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -105,7 +105,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { }, nil } -func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error { +func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { logger := s.config.Logger if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { return err @@ -191,7 +191,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re } // Returns nil if initialization succeeded, else the initialization error. -func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error { +func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { logger := s.logger s.lastResolve = time.Now() @@ -221,7 +221,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // startTunnel starts the first tunnel connection. The resulting error will be sent on // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed -func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) { +func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) { var ( addr *net.TCPAddr err error @@ -265,7 +265,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. -func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan struct{}) { +func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) { var ( addr *net.TCPAddr err error diff --git a/origin/tunnel.go b/origin/tunnel.go index bc4b9c50..a373ee22 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -179,7 +179,7 @@ func (c *TunnelConfig) SupportedFeatures() []string { return basic } -func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan struct{}) error { +func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error { s, err := NewSupervisor(config, cloudflaredID) if err != nil { return err @@ -195,7 +195,7 @@ func ServeTunnelLoop(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID, bufferPool *buffer.Pool, - reconnectCh chan struct{}, + reconnectCh chan ReconnectSignal, ) error { connectionLogger := config.Logger.WithField("connectionID", connectionID) config.Metrics.incrementHaConnections() @@ -244,7 +244,7 @@ func ServeTunnel( backoff *BackoffHandler, u uuid.UUID, bufferPool *buffer.Pool, - reconnectCh chan struct{}, + reconnectCh chan ReconnectSignal, ) (err error, recoverable bool) { // Treat panics as recoverable errors defer func() { @@ -332,13 +332,14 @@ func ServeTunnel( }) errGroup.Go(func() error { - select { - case <-reconnectCh: - return fmt.Errorf("received disconnect signal") - case <-serveCtx.Done(): - return nil + for { + select { + case reconnect := <-reconnectCh: + return &reconnect + case <-serveCtx.Done(): + return nil + } } - }) errGroup.Go(func() error { @@ -372,7 +373,11 @@ func ServeTunnel( logger.WithError(castedErr.cause).Error("Register tunnel error on client side") return err, true case muxerShutdownError: - logger.Infof("Muxer shutdown") + logger.Info("Muxer shutdown") + return err, true + case *ReconnectSignal: + logger.Warnf("Restarting due to reconnect signal in %d seconds", castedErr.Delay) + castedErr.DelayBeforeReconnect() return err, true default: logger.WithError(err).Error("Serve tunnel error") diff --git a/sshlog/logger_test.go b/sshlog/logger_test.go index e0a047be..4e08a504 100644 --- a/sshlog/logger_test.go +++ b/sshlog/logger_test.go @@ -32,7 +32,7 @@ func createLogger(t *testing.T) *Logger { // }() // // logger.Write([]byte(testStr)) -// time.Sleep(2 * time.Millisecond) +// time.DelayBeforeReconnect(2 * time.Millisecond) // data, err := ioutil.ReadFile(logFileName) // if err != nil { // t.Fatal("couldn't read the log file!", err)