TUN-2940: Added delay parameter to stdin reconnect command.
This commit is contained in:
parent
41c358147c
commit
dd0881f32b
|
@ -11,6 +11,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/trace"
|
"runtime/trace"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -421,7 +422,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
reconnectCh := make(chan struct{}, 1)
|
reconnectCh := make(chan origin.ReconnectSignal, 1)
|
||||||
if c.IsSet("stdin-control") {
|
if c.IsSet("stdin-control") {
|
||||||
logger.Warn("Enabling control through stdin")
|
logger.Warn("Enabling control through stdin")
|
||||||
go stdinControl(reconnectCh)
|
go stdinControl(reconnectCh)
|
||||||
|
@ -1112,17 +1113,34 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func stdinControl(reconnectCh chan struct{}) {
|
func stdinControl(reconnectCh chan origin.ReconnectSignal) {
|
||||||
for {
|
for {
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
scanner := bufio.NewScanner(os.Stdin)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
command := scanner.Text()
|
command := scanner.Text()
|
||||||
|
parts := strings.SplitN(command, " ", 2)
|
||||||
|
|
||||||
switch command {
|
switch parts[0] {
|
||||||
|
case "":
|
||||||
|
break
|
||||||
case "reconnect":
|
case "reconnect":
|
||||||
reconnectCh <- struct{}{}
|
var reconnect origin.ReconnectSignal
|
||||||
|
if len(parts) > 1 {
|
||||||
|
var err error
|
||||||
|
if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil {
|
||||||
|
logger.Error(err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Infof("Sending reconnect signal %+v", reconnect)
|
||||||
|
reconnectCh <- reconnect
|
||||||
default:
|
default:
|
||||||
logger.Warn("Unknown command: ", command)
|
logger.Warn("Unknown command: ", command)
|
||||||
|
fallthrough
|
||||||
|
case "help":
|
||||||
|
logger.Info(`Supported command:
|
||||||
|
reconnect [delay]
|
||||||
|
- restarts one randomly chosen connection with optional delay before reconnect`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ReconnectSignal struct {
|
||||||
|
// wait this many seconds before re-establish the connection
|
||||||
|
Delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error allows us to use ReconnectSignal as a special error to force connection abort
|
||||||
|
func (r *ReconnectSignal) Error() string {
|
||||||
|
return "reconnect signal"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReconnectSignal) DelayBeforeReconnect() {
|
||||||
|
if r.Delay > 0 {
|
||||||
|
time.Sleep(r.Delay)
|
||||||
|
}
|
||||||
|
}
|
|
@ -105,7 +105,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error {
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
||||||
logger := s.config.Logger
|
logger := s.config.Logger
|
||||||
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -191,7 +191,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns nil if initialization succeeded, else the initialization error.
|
// Returns nil if initialization succeeded, else the initialization error.
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
||||||
logger := s.logger
|
logger := s.logger
|
||||||
|
|
||||||
s.lastResolve = time.Now()
|
s.lastResolve = time.Now()
|
||||||
|
@ -221,7 +221,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
|
|
||||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) {
|
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
|
||||||
var (
|
var (
|
||||||
addr *net.TCPAddr
|
addr *net.TCPAddr
|
||||||
err error
|
err error
|
||||||
|
@ -265,7 +265,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
|
|
||||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors.
|
// s.tunnelErrors.
|
||||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan struct{}) {
|
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
|
||||||
var (
|
var (
|
||||||
addr *net.TCPAddr
|
addr *net.TCPAddr
|
||||||
err error
|
err error
|
||||||
|
|
|
@ -179,7 +179,7 @@ func (c *TunnelConfig) SupportedFeatures() []string {
|
||||||
return basic
|
return basic
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan struct{}) error {
|
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
|
||||||
s, err := NewSupervisor(config, cloudflaredID)
|
s, err := NewSupervisor(config, cloudflaredID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -195,7 +195,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
connectedSignal *signal.Signal,
|
connectedSignal *signal.Signal,
|
||||||
u uuid.UUID,
|
u uuid.UUID,
|
||||||
bufferPool *buffer.Pool,
|
bufferPool *buffer.Pool,
|
||||||
reconnectCh chan struct{},
|
reconnectCh chan ReconnectSignal,
|
||||||
) error {
|
) error {
|
||||||
connectionLogger := config.Logger.WithField("connectionID", connectionID)
|
connectionLogger := config.Logger.WithField("connectionID", connectionID)
|
||||||
config.Metrics.incrementHaConnections()
|
config.Metrics.incrementHaConnections()
|
||||||
|
@ -244,7 +244,7 @@ func ServeTunnel(
|
||||||
backoff *BackoffHandler,
|
backoff *BackoffHandler,
|
||||||
u uuid.UUID,
|
u uuid.UUID,
|
||||||
bufferPool *buffer.Pool,
|
bufferPool *buffer.Pool,
|
||||||
reconnectCh chan struct{},
|
reconnectCh chan ReconnectSignal,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -332,13 +332,14 @@ func ServeTunnel(
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
select {
|
for {
|
||||||
case <-reconnectCh:
|
select {
|
||||||
return fmt.Errorf("received disconnect signal")
|
case reconnect := <-reconnectCh:
|
||||||
case <-serveCtx.Done():
|
return &reconnect
|
||||||
return nil
|
case <-serveCtx.Done():
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
|
@ -372,7 +373,11 @@ func ServeTunnel(
|
||||||
logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
|
logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
|
||||||
return err, true
|
return err, true
|
||||||
case muxerShutdownError:
|
case muxerShutdownError:
|
||||||
logger.Infof("Muxer shutdown")
|
logger.Info("Muxer shutdown")
|
||||||
|
return err, true
|
||||||
|
case *ReconnectSignal:
|
||||||
|
logger.Warnf("Restarting due to reconnect signal in %d seconds", castedErr.Delay)
|
||||||
|
castedErr.DelayBeforeReconnect()
|
||||||
return err, true
|
return err, true
|
||||||
default:
|
default:
|
||||||
logger.WithError(err).Error("Serve tunnel error")
|
logger.WithError(err).Error("Serve tunnel error")
|
||||||
|
|
|
@ -32,7 +32,7 @@ func createLogger(t *testing.T) *Logger {
|
||||||
// }()
|
// }()
|
||||||
//
|
//
|
||||||
// logger.Write([]byte(testStr))
|
// logger.Write([]byte(testStr))
|
||||||
// time.Sleep(2 * time.Millisecond)
|
// time.DelayBeforeReconnect(2 * time.Millisecond)
|
||||||
// data, err := ioutil.ReadFile(logFileName)
|
// data, err := ioutil.ReadFile(logFileName)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// t.Fatal("couldn't read the log file!", err)
|
// t.Fatal("couldn't read the log file!", err)
|
||||||
|
|
Loading…
Reference in New Issue