TUN-2940: Added delay parameter to stdin reconnect command.

This commit is contained in:
Igor Postelnik 2020-04-30 00:02:08 -05:00
parent 41c358147c
commit dd0881f32b
5 changed files with 63 additions and 19 deletions

View File

@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"runtime/trace" "runtime/trace"
"strings"
"sync" "sync"
"time" "time"
@ -421,7 +422,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return err return err
} }
reconnectCh := make(chan struct{}, 1) reconnectCh := make(chan origin.ReconnectSignal, 1)
if c.IsSet("stdin-control") { if c.IsSet("stdin-control") {
logger.Warn("Enabling control through stdin") logger.Warn("Enabling control through stdin")
go stdinControl(reconnectCh) go stdinControl(reconnectCh)
@ -1112,17 +1113,34 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
} }
} }
func stdinControl(reconnectCh chan struct{}) { func stdinControl(reconnectCh chan origin.ReconnectSignal) {
for { for {
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() { for scanner.Scan() {
command := scanner.Text() command := scanner.Text()
parts := strings.SplitN(command, " ", 2)
switch command { switch parts[0] {
case "":
break
case "reconnect": case "reconnect":
reconnectCh <- struct{}{} var reconnect origin.ReconnectSignal
if len(parts) > 1 {
var err error
if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil {
logger.Error(err.Error())
continue
}
}
logger.Infof("Sending reconnect signal %+v", reconnect)
reconnectCh <- reconnect
default: default:
logger.Warn("Unknown command: ", command) logger.Warn("Unknown command: ", command)
fallthrough
case "help":
logger.Info(`Supported command:
reconnect [delay]
- restarts one randomly chosen connection with optional delay before reconnect`)
} }
} }
} }

View File

@ -0,0 +1,21 @@
package origin
import (
"time"
)
type ReconnectSignal struct {
// wait this many seconds before re-establish the connection
Delay time.Duration
}
// Error allows us to use ReconnectSignal as a special error to force connection abort
func (r *ReconnectSignal) Error() string {
return "reconnect signal"
}
func (r *ReconnectSignal) DelayBeforeReconnect() {
if r.Delay > 0 {
time.Sleep(r.Delay)
}
}

View File

@ -105,7 +105,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
}, nil }, nil
} }
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error { func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.config.Logger logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err return err
@ -191,7 +191,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
} }
// Returns nil if initialization succeeded, else the initialization error. // Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error { func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.logger logger := s.logger
s.lastResolve = time.Now() s.lastResolve = time.Now()
@ -221,7 +221,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startTunnel starts the first tunnel connection. The resulting error will be sent on // startTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) { func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
var ( var (
addr *net.TCPAddr addr *net.TCPAddr
err error err error
@ -265,7 +265,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
// startTunnel starts a new tunnel connection. The resulting error will be sent on // startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors. // s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan struct{}) { func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
var ( var (
addr *net.TCPAddr addr *net.TCPAddr
err error err error

View File

@ -179,7 +179,7 @@ func (c *TunnelConfig) SupportedFeatures() []string {
return basic return basic
} }
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan struct{}) error { func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
s, err := NewSupervisor(config, cloudflaredID) s, err := NewSupervisor(config, cloudflaredID)
if err != nil { if err != nil {
return err return err
@ -195,7 +195,7 @@ func ServeTunnelLoop(ctx context.Context,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan struct{}, reconnectCh chan ReconnectSignal,
) error { ) error {
connectionLogger := config.Logger.WithField("connectionID", connectionID) connectionLogger := config.Logger.WithField("connectionID", connectionID)
config.Metrics.incrementHaConnections() config.Metrics.incrementHaConnections()
@ -244,7 +244,7 @@ func ServeTunnel(
backoff *BackoffHandler, backoff *BackoffHandler,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan struct{}, reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {
@ -332,13 +332,14 @@ func ServeTunnel(
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
select { for {
case <-reconnectCh: select {
return fmt.Errorf("received disconnect signal") case reconnect := <-reconnectCh:
case <-serveCtx.Done(): return &reconnect
return nil case <-serveCtx.Done():
return nil
}
} }
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
@ -372,7 +373,11 @@ func ServeTunnel(
logger.WithError(castedErr.cause).Error("Register tunnel error on client side") logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
return err, true return err, true
case muxerShutdownError: case muxerShutdownError:
logger.Infof("Muxer shutdown") logger.Info("Muxer shutdown")
return err, true
case *ReconnectSignal:
logger.Warnf("Restarting due to reconnect signal in %d seconds", castedErr.Delay)
castedErr.DelayBeforeReconnect()
return err, true return err, true
default: default:
logger.WithError(err).Error("Serve tunnel error") logger.WithError(err).Error("Serve tunnel error")

View File

@ -32,7 +32,7 @@ func createLogger(t *testing.T) *Logger {
// }() // }()
// //
// logger.Write([]byte(testStr)) // logger.Write([]byte(testStr))
// time.Sleep(2 * time.Millisecond) // time.DelayBeforeReconnect(2 * time.Millisecond)
// data, err := ioutil.ReadFile(logFileName) // data, err := ioutil.ReadFile(logFileName)
// if err != nil { // if err != nil {
// t.Fatal("couldn't read the log file!", err) // t.Fatal("couldn't read the log file!", err)