TUN-3792: Handle graceful shutdown correctly when running as a windows service. Only expose one shutdown channel globally, which now triggers the graceful shutdown sequence across all modes. Removed separate handling of zero-duration grace period, instead it's checked only when we need to wait for exit.

This commit is contained in:
Igor Postelnik 2021-01-25 15:51:58 -06:00
parent d87bfcbe55
commit 6cdd20e820
9 changed files with 143 additions and 243 deletions

View File

@ -55,12 +55,11 @@ const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b@sentry.io/189878"
var ( var (
shutdownC chan struct{} shutdownC chan struct{}
graceShutdownC chan struct{}
) )
// Init will initialize and store vars from the main program // Init will initialize and store vars from the main program
func Init(s, g chan struct{}) { func Init(shutdown chan struct{}) {
shutdownC, graceShutdownC = s, g shutdownC = shutdown
} }
// Flags return the global flags for Access related commands (hopefully none) // Flags return the global flags for Access related commands (hopefully none)

View File

@ -16,7 +16,7 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) { func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{ app.Commands = append(app.Commands, &cli.Command{
Name: "service", Name: "service",
Usage: "Manages the Argo Tunnel system service", Usage: "Manages the Argo Tunnel system service",

View File

@ -17,7 +17,7 @@ const (
launchdIdentifier = "com.cloudflare.cloudflared" launchdIdentifier = "com.cloudflare.cloudflared"
) )
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) { func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{ app.Commands = append(app.Commands, &cli.Command{
Name: "service", Name: "service",
Usage: "Manages the Argo Tunnel launch agent", Usage: "Manages the Argo Tunnel launch agent",

View File

@ -46,10 +46,7 @@ func main() {
metrics.RegisterBuildInfo(BuildTime, Version) metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetRelease(Version) raven.SetRelease(Version)
// Force shutdown channel used by the app. When closed, app must terminate. // Graceful shutdown channel used by the app. When closed, app must terminate gracefully.
// Windows service manager closes this channel when it receives shutdown command.
shutdownC := make(chan struct{})
// Graceful shutdown channel used by the app. When closed, app must terminate.
// Windows service manager closes this channel when it receives stop command. // Windows service manager closes this channel when it receives stop command.
graceShutdownC := make(chan struct{}) graceShutdownC := make(chan struct{})
@ -77,14 +74,14 @@ func main() {
See https://developers.cloudflare.com/argo-tunnel/ for more in-depth documentation.` See https://developers.cloudflare.com/argo-tunnel/ for more in-depth documentation.`
app.Flags = flags() app.Flags = flags()
app.Action = action(Version, shutdownC, graceShutdownC) app.Action = action(graceShutdownC)
app.Before = tunnel.SetFlagsFromConfigFile app.Before = tunnel.SetFlagsFromConfigFile
app.Commands = commands(cli.ShowVersion) app.Commands = commands(cli.ShowVersion)
tunnel.Init(Version, shutdownC, graceShutdownC) // we need this to support the tunnel sub command... tunnel.Init(Version, graceShutdownC) // we need this to support the tunnel sub command...
access.Init(shutdownC, graceShutdownC) access.Init(graceShutdownC)
updater.Init(Version) updater.Init(Version)
runApp(app, shutdownC, graceShutdownC) runApp(app, graceShutdownC)
} }
func commands(version func(c *cli.Context)) []*cli.Command { func commands(version func(c *cli.Context)) []*cli.Command {
@ -145,10 +142,10 @@ func isEmptyInvocation(c *cli.Context) bool {
return c.NArg() == 0 && c.NumFlags() == 0 return c.NArg() == 0 && c.NumFlags() == 0
} }
func action(version string, shutdownC, graceShutdownC chan struct{}) cli.ActionFunc { func action(graceShutdownC chan struct{}) cli.ActionFunc {
return cliutil.ErrorHandler(func(c *cli.Context) (err error) { return cliutil.ErrorHandler(func(c *cli.Context) (err error) {
if isEmptyInvocation(c) { if isEmptyInvocation(c) {
return handleServiceMode(c, shutdownC) return handleServiceMode(c, graceShutdownC)
} }
tags := make(map[string]string) tags := make(map[string]string)
tags["hostname"] = c.String("hostname") tags["hostname"] = c.String("hostname")

View File

@ -86,7 +86,6 @@ const (
) )
var ( var (
shutdownC chan struct{}
graceShutdownC chan struct{} graceShutdownC chan struct{}
version string version string
) )
@ -165,8 +164,8 @@ func TunnelCommand(c *cli.Context) error {
return runClassicTunnel(sc) return runClassicTunnel(sc)
} }
func Init(v string, s, g chan struct{}) { func Init(ver string, gracefulShutdown chan struct{}) {
version, shutdownC, graceShutdownC = v, s, g version, graceShutdownC = ver, gracefulShutdown
} }
// runAdhocNamedTunnel create, route and run a named tunnel in one command // runAdhocNamedTunnel create, route and run a named tunnel in one command
@ -222,8 +221,6 @@ func StartServer(
var wg sync.WaitGroup var wg sync.WaitGroup
listeners := gracenet.Net{} listeners := gracenet.Net{}
errC := make(chan error) errC := make(chan error)
connectedSignal := signal.New(make(chan struct{}))
dnsReadySignal := make(chan struct{})
if config.GetConfiguration().Source() == "" { if config.GetConfiguration().Source() == "" {
log.Info().Msg(config.ErrNoConfigFile.Error()) log.Info().Msg(config.ErrNoConfigFile.Error())
@ -266,30 +263,29 @@ func StartServer(
buildInfo.Log(log) buildInfo.Log(log)
logClientOptions(c, log) logClientOptions(c, log)
// this context drives the server, when it's cancelled tunnel and all other components (origins, dns, etc...) should stop
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go waitForSignal(graceShutdownC, log)
if c.IsSet("proxy-dns") { if c.IsSet("proxy-dns") {
dnsReadySignal := make(chan struct{})
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errC <- runDNSProxyServer(c, dnsReadySignal, shutdownC, log) errC <- runDNSProxyServer(c, dnsReadySignal, ctx.Done(), log)
}() }()
} else { // Wait for proxy-dns to come up (if used)
close(dnsReadySignal) <-dnsReadySignal
} }
// Wait for proxy-dns to come up (if used) connectedSignal := signal.New(make(chan struct{}))
<-dnsReadySignal
go notifySystemd(connectedSignal) go notifySystemd(connectedSignal)
if c.IsSet("pidfile") { if c.IsSet("pidfile") {
go writePidFile(connectedSignal, c.String("pidfile"), log) go writePidFile(connectedSignal, c.String("pidfile"), log)
} }
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-shutdownC
cancel()
}()
// update needs to be after DNS proxy is up to resolve equinox server address // update needs to be after DNS proxy is up to resolve equinox server address
if updater.IsAutoupdateEnabled(c, log) { if updater.IsAutoupdateEnabled(c, log) {
autoupdateFreq := c.Duration("autoupdate-freq") autoupdateFreq := c.Duration("autoupdate-freq")
@ -306,7 +302,7 @@ func StartServer(
if dnsProxyStandAlone(c) { if dnsProxyStandAlone(c) {
connectedSignal.Notify() connectedSignal.Notify()
// no grace period, handle SIGINT/SIGTERM immediately // no grace period, handle SIGINT/SIGTERM immediately
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0, log) return waitToShutdown(&wg, cancel, errC, graceShutdownC, 0, log)
} }
url := c.String("url") url := c.String("url")
@ -338,10 +334,10 @@ func StartServer(
defer wg.Done() defer wg.Done()
readinessServer := metrics.NewReadyServer(log) readinessServer := metrics.NewReadyServer(log)
observer.RegisterSink(readinessServer) observer.RegisterSink(readinessServer)
errC <- metrics.ServeMetrics(metricsListener, shutdownC, readinessServer, log) errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, log)
}() }()
if err := ingressRules.StartOrigins(&wg, log, shutdownC, errC); err != nil { if err := ingressRules.StartOrigins(&wg, log, ctx.Done(), errC); err != nil {
return err return err
} }
@ -353,7 +349,10 @@ func StartServer(
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer func() {
wg.Done()
log.Info().Msg("Tunnel server stopped")
}()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC) errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC)
}() }()
@ -369,7 +368,7 @@ func StartServer(
observer.RegisterSink(app) observer.RegisterSink(app)
} }
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log) return waitToShutdown(&wg, cancel, errC, graceShutdownC, c.Duration("grace-period"), log)
} }
func SetFlagsFromConfigFile(c *cli.Context) error { func SetFlagsFromConfigFile(c *cli.Context) error {
@ -393,30 +392,44 @@ func SetFlagsFromConfigFile(c *cli.Context) error {
} }
func waitToShutdown(wg *sync.WaitGroup, func waitToShutdown(wg *sync.WaitGroup,
errC chan error, cancelServerContext func(),
shutdownC, graceShutdownC chan struct{}, errC <-chan error,
graceShutdownC <-chan struct{},
gracePeriod time.Duration, gracePeriod time.Duration,
log *zerolog.Logger, log *zerolog.Logger,
) error { ) error {
var err error var err error
if gracePeriod > 0 { select {
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownC, gracePeriod, log) case err = <-errC:
} else { log.Error().Err(err).Msg("Initiating shutdown")
err = waitForSignal(errC, shutdownC, log) case <-graceShutdownC:
close(graceShutdownC) log.Debug().Msg("Graceful shutdown signalled")
if gracePeriod > 0 {
// wait for either grace period or service termination
select {
case <-time.Tick(gracePeriod):
case <-errC:
}
}
} }
if err != nil { // stop server context
log.Err(err).Msg("Quitting due to error") cancelServerContext()
} else {
log.Info().Msg("Quitting...") // Wait for clean exit, discarding all errors while we wait
} stopDiscarding := make(chan struct{})
// Wait for clean exit, discarding all errors
go func() { go func() {
for range errC { for {
select {
case <-errC: // ignore
case <-stopDiscarding:
return
}
} }
}() }()
wg.Wait() wg.Wait()
close(stopDiscarding)
return err return err
} }

View File

@ -8,7 +8,7 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
func runDNSProxyServer(c *cli.Context, dnsReadySignal, shutdownC chan struct{}, log *zerolog.Logger) error { func runDNSProxyServer(c *cli.Context, dnsReadySignal chan struct{}, shutdownC <-chan struct{}, log *zerolog.Logger) error {
port := c.Int("proxy-dns-port") port := c.Int("proxy-dns-port")
if port <= 0 || port > 65535 { if port <= 0 || port > 65535 {
return errors.New("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.") return errors.New("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.")
@ -26,5 +26,6 @@ func runDNSProxyServer(c *cli.Context, dnsReadySignal, shutdownC chan struct{},
} }
<-shutdownC <-shutdownC
_ = listener.Stop() _ = listener.Stop()
log.Info().Msg("DNS server stopped")
return nil return nil
} }

View File

@ -4,84 +4,20 @@ import (
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
const LogFieldSignal = "signal" // waitForSignal closes graceShutdownC to indicate that we should start graceful shutdown sequence
func waitForSignal(graceShutdownC chan struct{}, logger *zerolog.Logger) {
// waitForSignal notifies all routines to shutdownC immediately by closing the
// shutdownC when one of the routines in main exits, or when this process receives
// SIGTERM/SIGINT
func waitForSignal(errC chan error, shutdownC chan struct{}, log *zerolog.Logger) error {
signals := make(chan os.Signal, 10) signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals) defer signal.Stop(signals)
select { select {
case err := <-errC:
log.Err(err).Msg("terminating due to error")
close(shutdownC)
return err
case s := <-signals:
log.Info().Str(LogFieldSignal, s.String()).Msg("terminating due to signal")
close(shutdownC)
case <-shutdownC:
}
return nil
}
// waitForSignalWithGraceShutdown notifies all routines to shutdown immediately
// by closing the shutdownC when one of the routines in main exits.
// When this process recieves SIGTERM/SIGINT, it closes the graceShutdownC to
// notify certain routines to start graceful shutdown. When grace period is over,
// or when some routine exits, it notifies the rest of the routines to shutdown
// immediately by closing shutdownC.
// In the case of handling commands from Windows Service Manager, closing graceShutdownC
// initiate graceful shutdown.
func waitForSignalWithGraceShutdown(errC chan error,
shutdownC, graceShutdownC chan struct{},
gracePeriod time.Duration,
logger *zerolog.Logger,
) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
logger.Info().Msgf("Initiating shutdown due to %v ...", err)
close(graceShutdownC)
close(shutdownC)
return err
case s := <-signals: case s := <-signals:
logger.Info().Msgf("Initiating graceful shutdown due to signal %s ...", s) logger.Info().Msgf("Initiating graceful shutdown due to signal %s ...", s)
close(graceShutdownC) close(graceShutdownC)
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
case <-graceShutdownC: case <-graceShutdownC:
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
case <-shutdownC:
close(graceShutdownC)
} }
}
return nil
}
func waitForGracePeriod(signals chan os.Signal,
errC chan error,
shutdownC chan struct{},
gracePeriod time.Duration,
) {
// Unregister signal handler early, so the client can send a second SIGTERM/SIGINT
// to force shutdown cloudflared
signal.Stop(signals)
graceTimerTick := time.Tick(gracePeriod)
// send close signal via shutdownC when grace period expires or when an
// error is encountered.
select {
case <-graceTimerTick:
case <-errC:
}
close(shutdownC)
}

View File

@ -2,11 +2,12 @@ package tunnel
import ( import (
"fmt" "fmt"
"github.com/rs/zerolog" "sync"
"syscall" "syscall"
"testing" "testing"
"time" "time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -18,40 +19,21 @@ var (
graceShutdownErr = fmt.Errorf("receive grace shutdown") graceShutdownErr = fmt.Errorf("receive grace shutdown")
) )
func testChannelClosed(t *testing.T, c chan struct{}) { func channelClosed(c chan struct{}) bool {
select { select {
case <-c: case <-c:
return return true
default: default:
t.Fatal("Channel should be closed") return false
} }
} }
func TestWaitForSignal(t *testing.T) { func TestSignalShutdown(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
// Test handling server error
errC := make(chan error)
shutdownC := make(chan struct{})
go func() {
errC <- serverErr
}()
// received error, shutdownC should be closed
err := waitForSignal(errC, shutdownC, &log)
assert.Equal(t, serverErr, err)
testChannelClosed(t, shutdownC)
// Test handling SIGTERM & SIGINT // Test handling SIGTERM & SIGINT
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC = make(chan error) graceShutdownC := make(chan struct{})
shutdownC = make(chan struct{})
go func(shutdownC chan struct{}) {
<-shutdownC
errC <- shutdownErr
}(shutdownC)
go func(sig syscall.Signal) { go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignal // sleep for a tick to prevent sending signal before calling waitForSignal
@ -59,99 +41,64 @@ func TestWaitForSignal(t *testing.T) {
_ = syscall.Kill(syscall.Getpid(), sig) _ = syscall.Kill(syscall.Getpid(), sig)
}(sig) }(sig)
err = waitForSignal(errC, shutdownC, &log) time.AfterFunc(time.Second, func() {
assert.Equal(t, nil, err) select {
assert.Equal(t, shutdownErr, <-errC) case <-graceShutdownC:
testChannelClosed(t, shutdownC) default:
close(graceShutdownC)
t.Fatal("waitForSignal timed out")
}
})
waitForSignal(graceShutdownC, &log)
assert.True(t, channelClosed(graceShutdownC))
} }
} }
func TestWaitForSignalWithGraceShutdown(t *testing.T) { func TestWaitForShutdown(t *testing.T) {
// Test server returning error log := zerolog.Nop()
errC := make(chan error)
shutdownC := make(chan struct{})
graceshutdownC := make(chan struct{})
errC := make(chan error)
graceShutdownC := make(chan struct{})
const gracePeriod = 5 * time.Second
contextCancelled := false
cancel := func() {
contextCancelled = true
}
var wg sync.WaitGroup
// on, error stop immediately
contextCancelled = false
startTime := time.Now()
go func() { go func() {
errC <- serverErr errC <- serverErr
}() }()
err := waitToShutdown(&wg, cancel, errC, graceShutdownC, gracePeriod, &log)
log := zerolog.Nop()
// received error, both shutdownC and graceshutdownC should be closed
err := waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, &log)
assert.Equal(t, serverErr, err) assert.Equal(t, serverErr, err)
testChannelClosed(t, shutdownC) assert.True(t, contextCancelled)
testChannelClosed(t, graceshutdownC) assert.False(t, channelClosed(graceShutdownC))
assert.True(t, time.Now().Sub(startTime) < time.Second) // check that wait ended early
// shutdownC closed, graceshutdownC should also be closed and no error // on graceful shutdown, ignore error but stop as soon as an error arrives
errC = make(chan error) contextCancelled = false
shutdownC = make(chan struct{}) startTime = time.Now()
graceshutdownC = make(chan struct{}) go func() {
close(shutdownC) close(graceShutdownC)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, &log) time.Sleep(tick)
assert.NoError(t, err) errC <- serverErr
testChannelClosed(t, shutdownC) }()
testChannelClosed(t, graceshutdownC) err = waitToShutdown(&wg, cancel, errC, graceShutdownC, gracePeriod, &log)
assert.Nil(t, err)
assert.True(t, contextCancelled)
assert.True(t, time.Now().Sub(startTime) < time.Second) // check that wait ended early
// graceshutdownC closed, shutdownC should also be closed and no error // with graceShutdownC closed stop right away without grace period
errC = make(chan error) contextCancelled = false
shutdownC = make(chan struct{}) startTime = time.Now()
graceshutdownC = make(chan struct{}) err = waitToShutdown(&wg, cancel, errC, graceShutdownC, 0, &log)
close(graceshutdownC) assert.Nil(t, err)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, &log) assert.True(t, contextCancelled)
assert.NoError(t, err) assert.True(t, time.Now().Sub(startTime) < time.Second) // check that wait ended early
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
// Test handling SIGTERM & SIGINT
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC := make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
go func(shutdownC, graceshutdownC chan struct{}) {
<-graceshutdownC
<-shutdownC
errC <- graceShutdownErr
}(shutdownC, graceshutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
time.Sleep(tick)
_ = syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, &log)
assert.Equal(t, nil, err)
assert.Equal(t, graceShutdownErr, <-errC)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
}
// Test handling SIGTERM & SIGINT, server send error before end of grace period
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC := make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
go func(shutdownC, graceshutdownC chan struct{}) {
<-graceshutdownC
errC <- graceShutdownErr
<-shutdownC
errC <- shutdownErr
}(shutdownC, graceshutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
time.Sleep(tick)
_ = syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, &log)
assert.Equal(t, nil, err)
assert.Equal(t, shutdownErr, <-errC)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
}
} }

View File

@ -40,21 +40,21 @@ const (
LogFieldWindowsServiceName = "windowsServiceName" LogFieldWindowsServiceName = "windowsServiceName"
) )
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) { func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{ app.Commands = append(app.Commands, &cli.Command{
Name: "service", Name: "service",
Usage: "Manages the Argo Tunnel Windows service", Usage: "Manages the Argo Tunnel Windows service",
Subcommands: []*cli.Command{ Subcommands: []*cli.Command{
&cli.Command{ {
Name: "install", Name: "install",
Usage: "Install Argo Tunnel as a Windows service", Usage: "Install Argo Tunnel as a Windows service",
Action: installWindowsService, Action: installWindowsService,
}, },
&cli.Command{ {
Name: "uninstall", Name: "uninstall",
Usage: "Uninstall the Argo Tunnel service", Usage: "Uninstall the Argo Tunnel service",
Action: uninstallWindowsService, Action: uninstallWindowsService,
}, },
}, },
}) })
@ -82,7 +82,7 @@ func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
// Run executes service name by calling windowsService which is a Handler // Run executes service name by calling windowsService which is a Handler
// interface that implements Execute method. // interface that implements Execute method.
// It will set service status to stop after Execute returns // It will set service status to stop after Execute returns
err = svc.Run(windowsServiceName, &windowsService{app: app, shutdownC: shutdownC, graceShutdownC: graceShutdownC}) err = svc.Run(windowsServiceName, &windowsService{app: app, graceShutdownC: graceShutdownC})
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok && int(errno) == serviceControllerConnectionFailure { if errno, ok := err.(syscall.Errno); ok && int(errno) == serviceControllerConnectionFailure {
// Hack: assume this is a false negative from the IsAnInteractiveSession() check above. // Hack: assume this is a false negative from the IsAnInteractiveSession() check above.
@ -96,11 +96,11 @@ func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
type windowsService struct { type windowsService struct {
app *cli.App app *cli.App
shutdownC chan struct{}
graceShutdownC chan struct{} graceShutdownC chan struct{}
} }
// called by the package code at the start of the service // Execute is called by the service manager when service starts, the state
// of the service will be set to Stopped when this function returns.
func (s *windowsService) Execute(serviceArgs []string, r <-chan svc.ChangeRequest, statusChan chan<- svc.Status) (ssec bool, errno uint32) { func (s *windowsService) Execute(serviceArgs []string, r <-chan svc.ChangeRequest, statusChan chan<- svc.Status) (ssec bool, errno uint32) {
log := logger.Create(nil) log := logger.Create(nil)
elog, err := eventlog.Open(windowsServiceName) elog, err := eventlog.Open(windowsServiceName)
@ -128,13 +128,12 @@ func (s *windowsService) Execute(serviceArgs []string, r <-chan svc.ChangeReques
} }
elog.Info(1, fmt.Sprintf("%s service arguments: %v", windowsServiceName, args)) elog.Info(1, fmt.Sprintf("%s service arguments: %v", windowsServiceName, args))
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
statusChan <- svc.Status{State: svc.StartPending} statusChan <- svc.Status{State: svc.StartPending}
errC := make(chan error) errC := make(chan error)
go func() { go func() {
errC <- s.app.Run(args) errC <- s.app.Run(args)
}() }()
statusChan <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} statusChan <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
for { for {
select { select {
@ -143,17 +142,25 @@ func (s *windowsService) Execute(serviceArgs []string, r <-chan svc.ChangeReques
case svc.Interrogate: case svc.Interrogate:
statusChan <- c.CurrentStatus statusChan <- c.CurrentStatus
case svc.Stop, svc.Shutdown: case svc.Stop, svc.Shutdown:
close(s.graceShutdownC) if s.graceShutdownC != nil {
statusChan <- svc.Status{State: svc.Stopped, Accepts: cmdsAccepted} // start graceful shutdown
elog.Info(1, "cloudflared starting graceful shutdown")
close(s.graceShutdownC)
s.graceShutdownC = nil
statusChan <- svc.Status{State: svc.StopPending}
continue
}
// repeated attempts at graceful shutdown forces immediate stop
elog.Info(1, "cloudflared terminating immediately")
statusChan <- svc.Status{State: svc.StopPending} statusChan <- svc.Status{State: svc.StopPending}
return return false, 0
default: default:
elog.Error(1, fmt.Sprintf("unexpected control request #%d", c)) elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
} }
case err := <-errC: case err := <-errC:
ssec = true
if err != nil { if err != nil {
elog.Error(1, fmt.Sprintf("cloudflared terminated with error %v", err)) elog.Error(1, fmt.Sprintf("cloudflared terminated with error %v", err))
ssec = true
errno = 1 errno = 1
} else { } else {
elog.Info(1, "cloudflared terminated without error") elog.Info(1, "cloudflared terminated without error")