TUN-2940: Added delay parameter to stdin reconnect command.
This commit is contained in:
		
							parent
							
								
									41c358147c
								
							
						
					
					
						commit
						dd0881f32b
					
				| 
						 | 
				
			
			@ -11,6 +11,7 @@ import (
 | 
			
		|||
	"reflect"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"runtime/trace"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -421,7 +422,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
 | 
			
		|||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reconnectCh := make(chan struct{}, 1)
 | 
			
		||||
	reconnectCh := make(chan origin.ReconnectSignal, 1)
 | 
			
		||||
	if c.IsSet("stdin-control") {
 | 
			
		||||
		logger.Warn("Enabling control through stdin")
 | 
			
		||||
		go stdinControl(reconnectCh)
 | 
			
		||||
| 
						 | 
				
			
			@ -1112,17 +1113,34 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func stdinControl(reconnectCh chan struct{}) {
 | 
			
		||||
func stdinControl(reconnectCh chan origin.ReconnectSignal) {
 | 
			
		||||
	for {
 | 
			
		||||
		scanner := bufio.NewScanner(os.Stdin)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			command := scanner.Text()
 | 
			
		||||
			parts := strings.SplitN(command, " ", 2)
 | 
			
		||||
 | 
			
		||||
			switch command {
 | 
			
		||||
			switch parts[0] {
 | 
			
		||||
			case "":
 | 
			
		||||
				break
 | 
			
		||||
			case "reconnect":
 | 
			
		||||
				reconnectCh <- struct{}{}
 | 
			
		||||
				var reconnect origin.ReconnectSignal
 | 
			
		||||
				if len(parts) > 1 {
 | 
			
		||||
					var err error
 | 
			
		||||
					if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil {
 | 
			
		||||
						logger.Error(err.Error())
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				logger.Infof("Sending reconnect signal %+v", reconnect)
 | 
			
		||||
				reconnectCh <- reconnect
 | 
			
		||||
			default:
 | 
			
		||||
				logger.Warn("Unknown command: ", command)
 | 
			
		||||
				fallthrough
 | 
			
		||||
			case "help":
 | 
			
		||||
				logger.Info(`Supported command: 
 | 
			
		||||
reconnect [delay] 
 | 
			
		||||
- restarts one randomly chosen connection with optional delay before reconnect`)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,21 @@
 | 
			
		|||
package origin
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ReconnectSignal struct {
 | 
			
		||||
	// wait this many seconds before re-establish the connection
 | 
			
		||||
	Delay time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Error allows us to use ReconnectSignal as a special error to force connection abort
 | 
			
		||||
func (r *ReconnectSignal) Error() string {
 | 
			
		||||
	return "reconnect signal"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *ReconnectSignal) DelayBeforeReconnect() {
 | 
			
		||||
	if r.Delay > 0 {
 | 
			
		||||
		time.Sleep(r.Delay)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -105,7 +105,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
 | 
			
		|||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error {
 | 
			
		||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
 | 
			
		||||
	logger := s.config.Logger
 | 
			
		||||
	if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
| 
						 | 
				
			
			@ -191,7 +191,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// Returns nil if initialization succeeded, else the initialization error.
 | 
			
		||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error {
 | 
			
		||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
 | 
			
		||||
	logger := s.logger
 | 
			
		||||
 | 
			
		||||
	s.lastResolve = time.Now()
 | 
			
		||||
| 
						 | 
				
			
			@ -221,7 +221,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
 | 
			
		|||
 | 
			
		||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
 | 
			
		||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
 | 
			
		||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) {
 | 
			
		||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
 | 
			
		||||
	var (
 | 
			
		||||
		addr *net.TCPAddr
 | 
			
		||||
		err  error
 | 
			
		||||
| 
						 | 
				
			
			@ -265,7 +265,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
 | 
			
		|||
 | 
			
		||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
 | 
			
		||||
// s.tunnelErrors.
 | 
			
		||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan struct{}) {
 | 
			
		||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
 | 
			
		||||
	var (
 | 
			
		||||
		addr *net.TCPAddr
 | 
			
		||||
		err  error
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -179,7 +179,7 @@ func (c *TunnelConfig) SupportedFeatures() []string {
 | 
			
		|||
	return basic
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan struct{}) error {
 | 
			
		||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
 | 
			
		||||
	s, err := NewSupervisor(config, cloudflaredID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
| 
						 | 
				
			
			@ -195,7 +195,7 @@ func ServeTunnelLoop(ctx context.Context,
 | 
			
		|||
	connectedSignal *signal.Signal,
 | 
			
		||||
	u uuid.UUID,
 | 
			
		||||
	bufferPool *buffer.Pool,
 | 
			
		||||
	reconnectCh chan struct{},
 | 
			
		||||
	reconnectCh chan ReconnectSignal,
 | 
			
		||||
) error {
 | 
			
		||||
	connectionLogger := config.Logger.WithField("connectionID", connectionID)
 | 
			
		||||
	config.Metrics.incrementHaConnections()
 | 
			
		||||
| 
						 | 
				
			
			@ -244,7 +244,7 @@ func ServeTunnel(
 | 
			
		|||
	backoff *BackoffHandler,
 | 
			
		||||
	u uuid.UUID,
 | 
			
		||||
	bufferPool *buffer.Pool,
 | 
			
		||||
	reconnectCh chan struct{},
 | 
			
		||||
	reconnectCh chan ReconnectSignal,
 | 
			
		||||
) (err error, recoverable bool) {
 | 
			
		||||
	// Treat panics as recoverable errors
 | 
			
		||||
	defer func() {
 | 
			
		||||
| 
						 | 
				
			
			@ -332,13 +332,14 @@ func ServeTunnel(
 | 
			
		|||
	})
 | 
			
		||||
 | 
			
		||||
	errGroup.Go(func() error {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-reconnectCh:
 | 
			
		||||
			return fmt.Errorf("received disconnect signal")
 | 
			
		||||
		case <-serveCtx.Done():
 | 
			
		||||
			return nil
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case reconnect := <-reconnectCh:
 | 
			
		||||
				return &reconnect
 | 
			
		||||
			case <-serveCtx.Done():
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	errGroup.Go(func() error {
 | 
			
		||||
| 
						 | 
				
			
			@ -372,7 +373,11 @@ func ServeTunnel(
 | 
			
		|||
			logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
 | 
			
		||||
			return err, true
 | 
			
		||||
		case muxerShutdownError:
 | 
			
		||||
			logger.Infof("Muxer shutdown")
 | 
			
		||||
			logger.Info("Muxer shutdown")
 | 
			
		||||
			return err, true
 | 
			
		||||
		case *ReconnectSignal:
 | 
			
		||||
			logger.Warnf("Restarting due to reconnect signal in %d seconds", castedErr.Delay)
 | 
			
		||||
			castedErr.DelayBeforeReconnect()
 | 
			
		||||
			return err, true
 | 
			
		||||
		default:
 | 
			
		||||
			logger.WithError(err).Error("Serve tunnel error")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -32,7 +32,7 @@ func createLogger(t *testing.T) *Logger {
 | 
			
		|||
//	}()
 | 
			
		||||
//
 | 
			
		||||
//	logger.Write([]byte(testStr))
 | 
			
		||||
//	time.Sleep(2 * time.Millisecond)
 | 
			
		||||
//	time.DelayBeforeReconnect(2 * time.Millisecond)
 | 
			
		||||
//	data, err := ioutil.ReadFile(logFileName)
 | 
			
		||||
//	if err != nil {
 | 
			
		||||
//		t.Fatal("couldn't read the log file!", err)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue