TUN-2846: Trigger debug reconnects from stdin commands, not SIGUSR1

This commit is contained in:
Areg Harutyunyan 2020-03-27 14:39:59 +00:00
parent 42246f986c
commit ae374c0463
3 changed files with 36 additions and 14 deletions

View File

@ -1,18 +1,17 @@
package tunnel package tunnel
import ( import (
"bufio"
"context" "context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
"net/url" "net/url"
"os" "os"
ossig "os/signal"
"reflect" "reflect"
"runtime" "runtime"
"runtime/trace" "runtime/trace"
"sync" "sync"
"syscall"
"time" "time"
"github.com/cloudflare/cloudflared/awsuploader" "github.com/cloudflare/cloudflared/awsuploader"
@ -401,9 +400,11 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return err return err
} }
// When the user sends SIGUSR1, disconnect all connections. reconnectCh := make(chan struct{}, 1)
reconnectCh := make(chan os.Signal, 1) if c.IsSet("stdin-control") {
ossig.Notify(reconnectCh, syscall.SIGUSR1) logger.Warn("Enabling control through stdin")
go stdinControl(reconnectCh)
}
wg.Add(1) wg.Add(1)
go func() { go func() {
@ -1066,5 +1067,28 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"HOST_KEY_PATH"}, EnvVars: []string{"HOST_KEY_PATH"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "stdin-control",
Usage: "Control the process using commands sent through stdin",
EnvVars: []string{"STDIN-CONTROL"},
Hidden: true,
Value: false,
}),
}
}
func stdinControl(reconnectCh chan struct{}) {
for {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
command := scanner.Text()
switch command {
case "reconnect":
reconnectCh <- struct{}{}
default:
logger.Warn("Unknown command: ", command)
}
}
} }
} }

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"time" "time"
@ -106,7 +105,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
}, nil }, nil
} }
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error { func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error {
logger := s.config.Logger logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err return err
@ -192,7 +191,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
} }
// Returns nil if initialization succeeded, else the initialization error. // Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error { func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) error {
logger := s.logger logger := s.logger
s.lastResolve = time.Now() s.lastResolve = time.Now()
@ -222,7 +221,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startTunnel starts the first tunnel connection. The resulting error will be sent on // startTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) { func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan struct{}) {
var ( var (
addr *net.TCPAddr addr *net.TCPAddr
err error err error
@ -266,7 +265,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
// startTunnel starts a new tunnel connection. The resulting error will be sent on // startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors. // s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) { func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan struct{}) {
var ( var (
addr *net.TCPAddr addr *net.TCPAddr
err error err error

View File

@ -9,7 +9,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -170,7 +169,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
} }
} }
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan os.Signal) error { func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan struct{}) error {
s, err := NewSupervisor(config, cloudflaredID) s, err := NewSupervisor(config, cloudflaredID)
if err != nil { if err != nil {
return err return err
@ -186,7 +185,7 @@ func ServeTunnelLoop(ctx context.Context,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan os.Signal, reconnectCh chan struct{},
) error { ) error {
connectionLogger := config.Logger.WithField("connectionID", connectionID) connectionLogger := config.Logger.WithField("connectionID", connectionID)
config.Metrics.incrementHaConnections() config.Metrics.incrementHaConnections()
@ -235,7 +234,7 @@ func ServeTunnel(
backoff *BackoffHandler, backoff *BackoffHandler,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan os.Signal, reconnectCh chan struct{},
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {