From 200ea2bfc61920b2ca2aea2a4a94d56ab435fdeb Mon Sep 17 00:00:00 2001 From: cloudflare-warp-bot Date: Wed, 30 May 2018 22:26:09 +0000 Subject: [PATCH] Release Argo Tunnel Client 2018.5.7 --- cmd/cloudflared/generic_service.go | 2 +- cmd/cloudflared/linux_service.go | 2 +- cmd/cloudflared/macos_service.go | 2 +- cmd/cloudflared/main.go | 28 ++--- cmd/cloudflared/signal.go | 56 +++++++--- cmd/cloudflared/signal_test.go | 9 ++ cmd/cloudflared/windows_service.go | 13 ++- h2mux/h2mux.go | 38 ++++--- h2mux/h2mux_test.go | 6 +- origin/metrics.go | 2 - origin/supervisor.go | 2 + origin/tunnel.go | 173 +++++++++++++++++------------ 12 files changed, 202 insertions(+), 131 deletions(-) diff --git a/cmd/cloudflared/generic_service.go b/cmd/cloudflared/generic_service.go index a2fbf494..abb24111 100644 --- a/cmd/cloudflared/generic_service.go +++ b/cmd/cloudflared/generic_service.go @@ -8,6 +8,6 @@ import ( cli "gopkg.in/urfave/cli.v2" ) -func runApp(app *cli.App, shutdownC chan struct{}) { +func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) { app.Run(os.Args) } diff --git a/cmd/cloudflared/linux_service.go b/cmd/cloudflared/linux_service.go index 19185171..2c09160f 100644 --- a/cmd/cloudflared/linux_service.go +++ b/cmd/cloudflared/linux_service.go @@ -10,7 +10,7 @@ import ( cli "gopkg.in/urfave/cli.v2" ) -func runApp(app *cli.App, shutdownC chan struct{}) { +func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", Usage: "Manages the Argo Tunnel system service", diff --git a/cmd/cloudflared/macos_service.go b/cmd/cloudflared/macos_service.go index 48102e55..d4fd3f20 100644 --- a/cmd/cloudflared/macos_service.go +++ b/cmd/cloudflared/macos_service.go @@ -15,7 +15,7 @@ const ( launchdIdentifier = "com.cloudflare.cloudflared" ) -func runApp(app *cli.App, shutdownC chan struct{}) { +func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", Usage: "Manages the Argo Tunnel launch agent", diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index 7e45b9da..4daf0689 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -40,9 +40,12 @@ func main() { raven.SetDSN(sentryDSN) raven.SetRelease(Version) - // Shutdown channel used by the app. When closed, app must terminate. - // May be closed by the Windows service runner. + // Force shutdown channel used by the app. When closed, app must terminate. + // Windows service manager closes this channel when it receives shutdown command. shutdownC := make(chan struct{}) + // Graceful shutdown channel used by the app. When closed, app must terminate. + // Windows service manager closes this channel when it receives stop command. + graceShutdownC := make(chan struct{}) app := &cli.App{} app.Name = "cloudflared" @@ -280,7 +283,7 @@ func main() { tags := make(map[string]string) tags["hostname"] = c.String("hostname") raven.SetTagsContext(tags) - raven.CapturePanicAndWait(func() { err = startServer(c, shutdownC) }, nil) + raven.CapturePanicAndWait(func() { err = startServer(c, shutdownC, graceShutdownC) }, nil) if err != nil { raven.CaptureErrorAndWait(err, nil) } @@ -374,16 +377,15 @@ func main() { ArgsUsage: " ", // can't be the empty string or we get the default output }, } - runApp(app, shutdownC) + runApp(app, shutdownC, graceShutdownC) } -func startServer(c *cli.Context, shutdownC chan struct{}) error { +func startServer(c *cli.Context, shutdownC, graceShutdownC chan struct{}) error { var wg sync.WaitGroup listeners := gracenet.Net{} errC := make(chan error) connectedSignal := make(chan struct{}) dnsReadySignal := make(chan struct{}) - graceShutdownSignal := make(chan struct{}) // check whether client provides enough flags or env variables. If not, print help. if ok := enoughOptionsSet(c); !ok { @@ -430,7 +432,7 @@ func startServer(c *cli.Context, shutdownC chan struct{}) error { if isAutoupdateEnabled(c) { logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) wg.Add(1) - go func(){ + go func() { defer wg.Done() errC <- autoupdate(c.Duration("autoupdate-freq"), &listeners, shutdownC) }() @@ -457,7 +459,7 @@ func startServer(c *cli.Context, shutdownC chan struct{}) error { if dnsProxyStandAlone(c) { close(connectedSignal) // no grace period, handle SIGINT/SIGTERM immediately - return waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, 0) + return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0) } if c.IsSet("hello-world") { @@ -483,23 +485,23 @@ func startServer(c *cli.Context, shutdownC chan struct{}) error { wg.Add(1) go func() { defer wg.Done() - errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownSignal, connectedSignal) + errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownC, connectedSignal) }() - return waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, c.Duration("grace-period")) + return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period")) } func waitToShutdown(wg *sync.WaitGroup, errC chan error, - shutdownC, graceShutdownSignal chan struct{}, + shutdownC, graceShutdownC chan struct{}, gracePeriod time.Duration, ) error { var err error if gracePeriod > 0 { - err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownSignal, gracePeriod) + err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownC, gracePeriod) } else { err = waitForSignal(errC, shutdownC) - close(graceShutdownSignal) + close(graceShutdownC) } if err != nil { diff --git a/cmd/cloudflared/signal.go b/cmd/cloudflared/signal.go index 2fb22ff1..8de703b3 100644 --- a/cmd/cloudflared/signal.go +++ b/cmd/cloudflared/signal.go @@ -7,6 +7,9 @@ import ( "time" ) +// waitForSignal notifies all routines to shutdownC immediately by closing the +// shutdownC when one of the routines in main exits, or when this process receives +// SIGTERM/SIGINT func waitForSignal(errC chan error, shutdownC chan struct{}) error { signals := make(chan os.Signal, 10) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) @@ -23,33 +26,54 @@ func waitForSignal(errC chan error, shutdownC chan struct{}) error { return nil } -func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSignal chan struct{}, gracePeriod time.Duration) error { +// waitForSignalWithGraceShutdown notifies all routines to shutdown immediately +// by closing the shutdownC when one of the routines in main exits. +// When this process recieves SIGTERM/SIGINT, it closes the graceShutdownC to +// notify certain routines to start graceful shutdown. When grace period is over, +// or when some routine exits, it notifies the rest of the routines to shutdown +// immediately by closing shutdownC. +// In the case of handling commands from Windows Service Manager, closing graceShutdownC +// initiate graceful shutdown. +func waitForSignalWithGraceShutdown(errC chan error, + shutdownC, graceShutdownC chan struct{}, + gracePeriod time.Duration, +) error { signals := make(chan os.Signal, 10) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) defer signal.Stop(signals) select { case err := <-errC: - close(graceShutdownSignal) + close(graceShutdownC) close(shutdownC) return err case <-signals: - close(graceShutdownSignal) - logger.Infof("Initiating graceful shutdown...") - // Unregister signal handler early, so the client can send a second SIGTERM/SIGINT - // to force shutdown cloudflared - signal.Stop(signals) - graceTimerTick := time.Tick(gracePeriod) - // send close signal via shutdownC when grace period expires or when an - // error is encountered. - select { - case <-graceTimerTick: - case <-errC: - } - close(shutdownC) + close(graceShutdownC) + waitForGracePeriod(signals, errC, shutdownC, gracePeriod) + case <-graceShutdownC: + waitForGracePeriod(signals, errC, shutdownC, gracePeriod) case <-shutdownC: - close(graceShutdownSignal) + close(graceShutdownC) } return nil } + +func waitForGracePeriod(signals chan os.Signal, + errC chan error, + shutdownC chan struct{}, + gracePeriod time.Duration, +) { + logger.Infof("Initiating graceful shutdown...") + // Unregister signal handler early, so the client can send a second SIGTERM/SIGINT + // to force shutdown cloudflared + signal.Stop(signals) + graceTimerTick := time.Tick(gracePeriod) + // send close signal via shutdownC when grace period expires or when an + // error is encountered. + select { + case <-graceTimerTick: + case <-errC: + } + close(shutdownC) +} diff --git a/cmd/cloudflared/signal_test.go b/cmd/cloudflared/signal_test.go index c0e0d546..4e06d001 100644 --- a/cmd/cloudflared/signal_test.go +++ b/cmd/cloudflared/signal_test.go @@ -89,6 +89,15 @@ func TestWaitForSignalWithGraceShutdown(t *testing.T) { testChannelClosed(t, shutdownC) testChannelClosed(t, graceshutdownC) + // graceshutdownC closed, shutdownC should also be closed and no error + errC = make(chan error) + shutdownC = make(chan struct{}) + graceshutdownC = make(chan struct{}) + close(graceshutdownC) + err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick) + assert.NoError(t, err) + testChannelClosed(t, shutdownC) + testChannelClosed(t, graceshutdownC) // Test handling SIGTERM & SIGINT for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { diff --git a/cmd/cloudflared/windows_service.go b/cmd/cloudflared/windows_service.go index be8b2e12..0ce2dfea 100644 --- a/cmd/cloudflared/windows_service.go +++ b/cmd/cloudflared/windows_service.go @@ -31,7 +31,7 @@ const ( serviceConfigFailureActionsFlag = 4 ) -func runApp(app *cli.App, shutdownC chan struct{}) { +func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", Usage: "Manages the Argo Tunnel Windows service", @@ -70,7 +70,7 @@ func runApp(app *cli.App, shutdownC chan struct{}) { // Run executes service name by calling windowsService which is a Handler // interface that implements Execute method. // It will set service status to stop after Execute returns - err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC}) + err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC, graceShutdownC: graceShutdownC}) if err != nil { elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err)) return @@ -79,9 +79,10 @@ func runApp(app *cli.App, shutdownC chan struct{}) { } type windowsService struct { - app *cli.App - elog *eventlog.Log - shutdownC chan struct{} + app *cli.App + elog *eventlog.Log + shutdownC chan struct{} + graceShutdownC chan struct{} } // called by the package code at the start of the service @@ -103,7 +104,7 @@ func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, stat statusChan <- c.CurrentStatus case svc.Stop: s.elog.Info(1, "received stop control request") - close(s.shutdownC) + close(s.graceShutdownC) statusChan <- svc.Status{State: svc.StopPending} case svc.Shutdown: s.elog.Info(1, "received shutdown control request") diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index aeef502a..dd17df3f 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -1,6 +1,7 @@ package h2mux import ( + "context" "io" "strings" "sync" @@ -9,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "golang.org/x/sync/errgroup" ) const ( @@ -256,31 +258,31 @@ func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time. return nil } -func (m *Muxer) Serve() error { +func (m *Muxer) Serve(ctx context.Context) error { logger := m.config.Logger.WithField("name", m.config.Name) - errChan := make(chan error) - go func() { - errChan <- m.muxReader.run(logger) + errGroup, _ := errgroup.WithContext(ctx) + errGroup.Go(func() error { + err := m.muxReader.run(logger) m.explicitShutdown.Fuse(false) m.r.Close() m.abort() - }() - go func() { - errChan <- m.muxWriter.run(logger) + return err + }) + + errGroup.Go(func() error { + err := m.muxWriter.run(logger) m.explicitShutdown.Fuse(false) m.w.Close() m.abort() - }() - go func() { - errChan <- m.muxMetricsUpdater.run(logger) - }() - err := <-errChan - go func() { - // discard errors as other handler and muxMetricsUpdater close - <-errChan - <-errChan - close(errChan) - }() + return err + }) + + errGroup.Go(func() error { + err := m.muxMetricsUpdater.run(logger) + return err + }) + + err := errGroup.Wait() if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) { return err } diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 8e566b5c..5dcaa679 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -2,6 +2,7 @@ package h2mux import ( "bytes" + "context" "fmt" "io" "io/ioutil" @@ -78,11 +79,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) { } func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) { + ctx := context.Background() p.Handshake(t) var wg sync.WaitGroup wg.Add(2) go func() { - err := p.EdgeMux.Serve() + err := p.EdgeMux.Serve(ctx) if err != nil && err != io.EOF && err != io.ErrClosedPipe { t.Errorf("error in edge muxer Serve(): %s", err) } @@ -90,7 +92,7 @@ func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) { wg.Done() }() go func() { - err := p.OriginMux.Serve() + err := p.OriginMux.Serve(ctx) if err != nil && err != io.EOF && err != io.ErrClosedPipe { t.Errorf("error in origin muxer Serve(): %s", err) } diff --git a/origin/metrics.go b/origin/metrics.go index 8ee560c1..3ae30b56 100644 --- a/origin/metrics.go +++ b/origin/metrics.go @@ -360,8 +360,6 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { t.concurrentRequestsLock.Lock() if _, ok := t.concurrentRequests[connectionID]; ok { t.concurrentRequests[connectionID] -= 1 - } else { - logger.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.") } t.concurrentRequestsLock.Unlock() diff --git a/origin/supervisor.go b/origin/supervisor.go index 9a315ad8..297d5153 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -51,6 +51,7 @@ func NewSupervisor(config *TunnelConfig) *Supervisor { } func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error { + logger := s.config.Logger if err := s.initialize(ctx, connectedSignal); err != nil { return err } @@ -119,6 +120,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err } func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { + logger := s.config.Logger edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) if err != nil { logger.Infof("ResolveEdgeIPs err") diff --git a/origin/tunnel.go b/origin/tunnel.go index a7d0c579..17cc2a4d 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -9,10 +9,10 @@ import ( "net/url" "strconv" "strings" - "sync" "time" "golang.org/x/net/context" + "golang.org/x/sync/errgroup" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/tunnelrpc" @@ -27,8 +27,6 @@ import ( rpc "zombiezen.com/go/capnproto2/rpc" ) -var logger *logrus.Logger - const ( dialTimeout = 15 * time.Second @@ -76,12 +74,28 @@ func (e dupConnRegisterTunnelError) Error() string { return "already connected to this server" } -type printableRegisterTunnelError struct { +type muxerShutdownError struct{} + +func (e muxerShutdownError) Error() string { + return "muxer shutdown" +} + +// RegisterTunnel error from server +type serverRegisterTunnelError struct { cause error permanent bool } -func (e printableRegisterTunnelError) Error() string { +func (e serverRegisterTunnelError) Error() string { + return e.cause.Error() +} + +// RegisterTunnel error from client +type clientRegisterTunnelError struct { + cause error +} + +func (e clientRegisterTunnelError) Error() string { return e.cause.Error() } @@ -105,7 +119,6 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str } func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error { - logger = config.Logger ctx, cancel := context.WithCancel(context.Background()) go func() { <-shutdownC @@ -129,6 +142,7 @@ func ServeTunnelLoop(ctx context.Context, connectionID uint8, connectedSignal chan struct{}, ) error { + logger := config.Logger config.Metrics.incrementHaConnections() defer config.Metrics.decrementHaConnections() backoff := BackoffHandler{MaxRetries: config.Retries} @@ -162,8 +176,6 @@ func ServeTunnel( connectedFuse *h2mux.BooleanFuse, backoff *BackoffHandler, ) (err error, recoverable bool) { - var wg sync.WaitGroup - wg.Add(2) // Treat panics as recoverable errors defer func() { if r := recover(); r != nil { @@ -175,8 +187,16 @@ func ServeTunnel( recoverable = true } }() + + connectionTag := uint8ToString(connectionID) + logger := config.Logger.WithField("connectionID", connectionTag) + + // additional tags to send other than hostname which is set in cloudflared main package + tags := make(map[string]string) + tags["ha"] = connectionTag + // Returns error from parsing the origin URL or handshake errors - handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID) + handler, originLocalIP, err := NewTunnelHandler(ctx, logger, config, addr.String(), connectionID) if err != nil { errLog := logger.WithError(err) switch err.(type) { @@ -190,59 +210,69 @@ func ServeTunnel( } return err, true } - serveCtx, serveCancel := context.WithCancel(ctx) - registerErrC := make(chan error, 1) - go func() { - defer wg.Done() - err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP) + + errGroup, serveCtx := errgroup.WithContext(ctx) + + errGroup.Go(func() error { + err := RegisterTunnel(serveCtx, logger, handler.muxer, config, connectionID, originLocalIP) if err == nil { connectedFuse.Fuse(true) backoff.SetGracePeriod() - } else { - serveCancel() } - registerErrC <- err - }() - updateMetricsTickC := time.Tick(config.MetricsUpdateFreq) - go func() { - defer wg.Done() - connectionTag := uint8ToString(connectionID) + return err + }) + + errGroup.Go(func() error { + updateMetricsTickC := time.Tick(config.MetricsUpdateFreq) for { select { - case <-serveCtx.Done(): + case <-serveCtx.Done(): // UnregisterTunnel blocks until the RPC call returns - UnregisterTunnel(handler.muxer, config.GracePeriod) + err := UnregisterTunnel(logger, handler.muxer, config.GracePeriod) handler.muxer.Shutdown() - return + return err case <-updateMetricsTickC: handler.UpdateMetrics(connectionTag) } } - }() + }) - err = handler.muxer.Serve() - serveCancel() - registerErr := <-registerErrC - wg.Wait() - if err != nil { - logger.WithError(err).Error("Tunnel error") - return err, true - } - if registerErr != nil { - // Don't retry on errors like entitlement failure or version too old - if e, ok := registerErr.(printableRegisterTunnelError); ok { - logger.Error(e) - return e.cause, !e.permanent - } else if e, ok := registerErr.(dupConnRegisterTunnelError); ok { - logger.Info("Already connected to this server, selecting a different one") - return e, true + errGroup.Go(func() error { + // All routines should stop when muxer finish serving. When muxer is shutdown + // gracefully, it doesn't return an error, so we need to return errMuxerShutdown + // here to notify other routines to stop + err := handler.muxer.Serve(serveCtx); + if err == nil { + return muxerShutdownError{} + } + return err + }) + + err = errGroup.Wait() + if err != nil { + switch castedErr := err.(type) { + case dupConnRegisterTunnelError: + logger.Info("Already connected to this server, selecting a different one") + return err, true + case serverRegisterTunnelError: + logger.WithError(castedErr.cause).Error("Register tunnel error from server side") + // Don't send registration error return from server to Sentry. They are + // logged on server side + return castedErr.cause, !castedErr.permanent + case clientRegisterTunnelError: + logger.WithError(castedErr.cause).Error("Register tunnel error on client side") + raven.CaptureErrorAndWait(castedErr.cause, tags) + return err, true + case muxerShutdownError: + logger.Infof("Muxer shutdown") + return err, true + default: + logger.WithError(err).Error("Serve tunnel error") + raven.CaptureErrorAndWait(err, tags) + return err, true } - // Only log errors to Sentry that may have been caused by the client side, to reduce dupes - raven.CaptureError(registerErr, nil) - logger.Error("Cannot register") - return err, true } - return nil, false + return nil, true } func IsRPCStreamResponse(headers []h2mux.Header) bool { @@ -255,8 +285,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool { return true } -func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error { - logger := logger.WithField("subsystem", "rpc") +func RegisterTunnel(ctx context.Context, logger *logrus.Entry, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error { logger.Debug("initiating RPC stream to register") stream, err := muxer.OpenStream([]h2mux.Header{ {Name: ":method", Value: "RPC"}, @@ -265,16 +294,14 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi }, nil) if err != nil { // RPC stream open error - raven.CaptureError(err, nil) - return err + return clientRegisterTunnelError{cause: err} } if !IsRPCStreamResponse(stream.Headers) { // stream response error - raven.CaptureError(err, nil) - return err + return clientRegisterTunnelError{cause: err} } conn := rpc.NewConn( - tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), + tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)), tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), ) defer conn.Close() @@ -293,7 +320,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi LogServerInfo(logger, serverInfoPromise.Result(), connectionID, config.Metrics) if err != nil { // RegisterTunnel RPC failure - return err + return clientRegisterTunnelError{cause: err} } for _, logLine := range registration.LogLines { logger.Info(logLine) @@ -301,7 +328,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi if registration.Err == DuplicateConnectionError { return dupConnRegisterTunnelError{} } else if registration.Err != "" { - return printableRegisterTunnelError{ + return serverRegisterTunnelError{ cause: fmt.Errorf("Server error: %s", registration.Err), permanent: registration.PermanentFailure, } @@ -310,8 +337,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi return nil } -func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration) error { - logger := logger.WithField("subsystem", "rpc") +func UnregisterTunnel(logger *logrus.Entry, muxer *h2mux.Muxer, gracePeriod time.Duration) error { logger.Debug("initiating RPC stream to unregister") stream, err := muxer.OpenStream([]h2mux.Header{ {Name: ":method", Value: "RPC"}, @@ -320,17 +346,15 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration) error { }, nil) if err != nil { // RPC stream open error - raven.CaptureError(err, nil) return err } if !IsRPCStreamResponse(stream.Headers) { // stream response error - raven.CaptureError(err, nil) return err } ctx := context.Background() conn := rpc.NewConn( - tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), + tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)), tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), ) defer conn.Close() @@ -408,12 +432,18 @@ type TunnelHandler struct { metrics *TunnelMetrics // connectionID is only used by metrics, and prometheus requires labels to be string connectionID string + logger *logrus.Entry } var dialer = net.Dialer{DualStack: true} // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error -func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, connectionID uint8) (*TunnelHandler, string, error) { +func NewTunnelHandler(ctx context.Context, + logger *logrus.Entry, + config *TunnelConfig, + addr string, + connectionID uint8, +) (*TunnelHandler, string, error) { originURL, err := validation.ValidateUrl(config.OriginUrl) if err != nil { return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL) @@ -425,6 +455,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co tags: config.Tags, metrics: config.Metrics, connectionID: uint8ToString(connectionID), + logger: logger, } if h.httpClient == nil { h.httpClient = http.DefaultTransport @@ -471,11 +502,11 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { h.metrics.incrementRequests(h.connectionID) req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { - logger.WithError(err).Panic("Unexpected error from http.NewRequest") + h.logger.WithError(err).Panic("Unexpected error from http.NewRequest") } err = H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { - logger.WithError(err).Error("invalid request received") + h.logger.WithError(err).Error("invalid request received") } h.AppendTagHeaders(req) cfRay := FindCfRayHeader(req) @@ -510,7 +541,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { } func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { - logger.WithError(err).Error("HTTP request error") + h.logger.WithError(err).Error("HTTP request error") stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) stream.Write([]byte("502 Bad Gateway")) h.metrics.incrementResponses(h.connectionID, "502") @@ -518,20 +549,20 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) { if cfRay != "" { - logger.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto) + h.logger.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto) } else { - logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto) + h.logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto) } - logger.Debugf("Request Headers %+v", req.Header) + h.logger.Debugf("Request Headers %+v", req.Header) } func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) { if cfRay != "" { - logger.WithField("CF-RAY", cfRay).Infof("%s", r.Status) + h.logger.WithField("CF-RAY", cfRay).Infof("%s", r.Status) } else { - logger.Infof("%s", r.Status) + h.logger.Infof("%s", r.Status) } - logger.Debugf("Response Headers %+v", r.Header) + h.logger.Debugf("Response Headers %+v", r.Header) } func (h *TunnelHandler) UpdateMetrics(connectionID string) {