package tunnel import ( "fmt" "syscall" "testing" "time" "github.com/cloudflare/cloudflared/logger" "github.com/stretchr/testify/assert" ) const tick = 100 * time.Millisecond var ( serverErr = fmt.Errorf("server error") shutdownErr = fmt.Errorf("receive shutdown") graceShutdownErr = fmt.Errorf("receive grace shutdown") ) func testChannelClosed(t *testing.T, c chan struct{}) { select { case <-c: return default: t.Fatal("Channel should be closed") } } func TestWaitForSignal(t *testing.T) { logger := logger.NewOutputWriter(logger.NewMockWriteManager()) // Test handling server error errC := make(chan error) shutdownC := make(chan struct{}) go func() { errC <- serverErr }() // received error, shutdownC should be closed err := waitForSignal(errC, shutdownC, logger) assert.Equal(t, serverErr, err) testChannelClosed(t, shutdownC) // Test handling SIGTERM & SIGINT for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { errC = make(chan error) shutdownC = make(chan struct{}) go func(shutdownC chan struct{}) { <-shutdownC errC <- shutdownErr }(shutdownC) go func(sig syscall.Signal) { // sleep for a tick to prevent sending signal before calling waitForSignal time.Sleep(tick) syscall.Kill(syscall.Getpid(), sig) }(sig) err = waitForSignal(errC, shutdownC, logger) assert.Equal(t, nil, err) assert.Equal(t, shutdownErr, <-errC) testChannelClosed(t, shutdownC) } } func TestWaitForSignalWithGraceShutdown(t *testing.T) { // Test server returning error errC := make(chan error) shutdownC := make(chan struct{}) graceshutdownC := make(chan struct{}) go func() { errC <- serverErr }() logger := logger.NewOutputWriter(logger.NewMockWriteManager()) // received error, both shutdownC and graceshutdownC should be closed err := waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, logger) assert.Equal(t, serverErr, err) testChannelClosed(t, shutdownC) testChannelClosed(t, graceshutdownC) // shutdownC closed, graceshutdownC should also be closed and no error errC = make(chan error) shutdownC = make(chan struct{}) graceshutdownC = make(chan struct{}) close(shutdownC) err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, logger) assert.NoError(t, err) testChannelClosed(t, shutdownC) testChannelClosed(t, graceshutdownC) // graceshutdownC closed, shutdownC should also be closed and no error errC = make(chan error) shutdownC = make(chan struct{}) graceshutdownC = make(chan struct{}) close(graceshutdownC) err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, logger) assert.NoError(t, err) testChannelClosed(t, shutdownC) testChannelClosed(t, graceshutdownC) // Test handling SIGTERM & SIGINT for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { errC := make(chan error) shutdownC = make(chan struct{}) graceshutdownC = make(chan struct{}) go func(shutdownC, graceshutdownC chan struct{}) { <-graceshutdownC <-shutdownC errC <- graceShutdownErr }(shutdownC, graceshutdownC) go func(sig syscall.Signal) { // sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown time.Sleep(tick) syscall.Kill(syscall.Getpid(), sig) }(sig) err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, logger) assert.Equal(t, nil, err) assert.Equal(t, graceShutdownErr, <-errC) testChannelClosed(t, shutdownC) testChannelClosed(t, graceshutdownC) } // Test handling SIGTERM & SIGINT, server send error before end of grace period for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { errC := make(chan error) shutdownC = make(chan struct{}) graceshutdownC = make(chan struct{}) go func(shutdownC, graceshutdownC chan struct{}) { <-graceshutdownC errC <- graceShutdownErr <-shutdownC errC <- shutdownErr }(shutdownC, graceshutdownC) go func(sig syscall.Signal) { // sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown time.Sleep(tick) syscall.Kill(syscall.Getpid(), sig) }(sig) err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick, logger) assert.Equal(t, nil, err) assert.Equal(t, shutdownErr, <-errC) testChannelClosed(t, shutdownC) testChannelClosed(t, graceshutdownC) } }