Release Argo Tunnel Client 2018.5.7

This commit is contained in:
cloudflare-warp-bot 2018-05-30 22:26:09 +00:00
parent 7acd1b1fc8
commit 200ea2bfc6
12 changed files with 202 additions and 131 deletions

View File

@ -8,6 +8,6 @@ import (
cli "gopkg.in/urfave/cli.v2" cli "gopkg.in/urfave/cli.v2"
) )
func runApp(app *cli.App, shutdownC chan struct{}) { func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Run(os.Args) app.Run(os.Args)
} }

View File

@ -10,7 +10,7 @@ import (
cli "gopkg.in/urfave/cli.v2" cli "gopkg.in/urfave/cli.v2"
) )
func runApp(app *cli.App, shutdownC chan struct{}) { func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{ app.Commands = append(app.Commands, &cli.Command{
Name: "service", Name: "service",
Usage: "Manages the Argo Tunnel system service", Usage: "Manages the Argo Tunnel system service",

View File

@ -15,7 +15,7 @@ const (
launchdIdentifier = "com.cloudflare.cloudflared" launchdIdentifier = "com.cloudflare.cloudflared"
) )
func runApp(app *cli.App, shutdownC chan struct{}) { func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{ app.Commands = append(app.Commands, &cli.Command{
Name: "service", Name: "service",
Usage: "Manages the Argo Tunnel launch agent", Usage: "Manages the Argo Tunnel launch agent",

View File

@ -40,9 +40,12 @@ func main() {
raven.SetDSN(sentryDSN) raven.SetDSN(sentryDSN)
raven.SetRelease(Version) raven.SetRelease(Version)
// Shutdown channel used by the app. When closed, app must terminate. // Force shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner. // Windows service manager closes this channel when it receives shutdown command.
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
// Graceful shutdown channel used by the app. When closed, app must terminate.
// Windows service manager closes this channel when it receives stop command.
graceShutdownC := make(chan struct{})
app := &cli.App{} app := &cli.App{}
app.Name = "cloudflared" app.Name = "cloudflared"
@ -280,7 +283,7 @@ func main() {
tags := make(map[string]string) tags := make(map[string]string)
tags["hostname"] = c.String("hostname") tags["hostname"] = c.String("hostname")
raven.SetTagsContext(tags) raven.SetTagsContext(tags)
raven.CapturePanicAndWait(func() { err = startServer(c, shutdownC) }, nil) raven.CapturePanicAndWait(func() { err = startServer(c, shutdownC, graceShutdownC) }, nil)
if err != nil { if err != nil {
raven.CaptureErrorAndWait(err, nil) raven.CaptureErrorAndWait(err, nil)
} }
@ -374,16 +377,15 @@ func main() {
ArgsUsage: " ", // can't be the empty string or we get the default output ArgsUsage: " ", // can't be the empty string or we get the default output
}, },
} }
runApp(app, shutdownC) runApp(app, shutdownC, graceShutdownC)
} }
func startServer(c *cli.Context, shutdownC chan struct{}) error { func startServer(c *cli.Context, shutdownC, graceShutdownC chan struct{}) error {
var wg sync.WaitGroup var wg sync.WaitGroup
listeners := gracenet.Net{} listeners := gracenet.Net{}
errC := make(chan error) errC := make(chan error)
connectedSignal := make(chan struct{}) connectedSignal := make(chan struct{})
dnsReadySignal := make(chan struct{}) dnsReadySignal := make(chan struct{})
graceShutdownSignal := make(chan struct{})
// check whether client provides enough flags or env variables. If not, print help. // check whether client provides enough flags or env variables. If not, print help.
if ok := enoughOptionsSet(c); !ok { if ok := enoughOptionsSet(c); !ok {
@ -457,7 +459,7 @@ func startServer(c *cli.Context, shutdownC chan struct{}) error {
if dnsProxyStandAlone(c) { if dnsProxyStandAlone(c) {
close(connectedSignal) close(connectedSignal)
// no grace period, handle SIGINT/SIGTERM immediately // no grace period, handle SIGINT/SIGTERM immediately
return waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, 0) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0)
} }
if c.IsSet("hello-world") { if c.IsSet("hello-world") {
@ -483,23 +485,23 @@ func startServer(c *cli.Context, shutdownC chan struct{}) error {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownSignal, connectedSignal) errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownC, connectedSignal)
}() }()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, c.Duration("grace-period")) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))
} }
func waitToShutdown(wg *sync.WaitGroup, func waitToShutdown(wg *sync.WaitGroup,
errC chan error, errC chan error,
shutdownC, graceShutdownSignal chan struct{}, shutdownC, graceShutdownC chan struct{},
gracePeriod time.Duration, gracePeriod time.Duration,
) error { ) error {
var err error var err error
if gracePeriod > 0 { if gracePeriod > 0 {
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownSignal, gracePeriod) err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownC, gracePeriod)
} else { } else {
err = waitForSignal(errC, shutdownC) err = waitForSignal(errC, shutdownC)
close(graceShutdownSignal) close(graceShutdownC)
} }
if err != nil { if err != nil {

View File

@ -7,6 +7,9 @@ import (
"time" "time"
) )
// waitForSignal notifies all routines to shutdownC immediately by closing the
// shutdownC when one of the routines in main exits, or when this process receives
// SIGTERM/SIGINT
func waitForSignal(errC chan error, shutdownC chan struct{}) error { func waitForSignal(errC chan error, shutdownC chan struct{}) error {
signals := make(chan os.Signal, 10) signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
@ -23,18 +26,44 @@ func waitForSignal(errC chan error, shutdownC chan struct{}) error {
return nil return nil
} }
func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSignal chan struct{}, gracePeriod time.Duration) error { // waitForSignalWithGraceShutdown notifies all routines to shutdown immediately
// by closing the shutdownC when one of the routines in main exits.
// When this process recieves SIGTERM/SIGINT, it closes the graceShutdownC to
// notify certain routines to start graceful shutdown. When grace period is over,
// or when some routine exits, it notifies the rest of the routines to shutdown
// immediately by closing shutdownC.
// In the case of handling commands from Windows Service Manager, closing graceShutdownC
// initiate graceful shutdown.
func waitForSignalWithGraceShutdown(errC chan error,
shutdownC, graceShutdownC chan struct{},
gracePeriod time.Duration,
) error {
signals := make(chan os.Signal, 10) signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals) defer signal.Stop(signals)
select { select {
case err := <-errC: case err := <-errC:
close(graceShutdownSignal) close(graceShutdownC)
close(shutdownC) close(shutdownC)
return err return err
case <-signals: case <-signals:
close(graceShutdownSignal) close(graceShutdownC)
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
case <-graceShutdownC:
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
case <-shutdownC:
close(graceShutdownC)
}
return nil
}
func waitForGracePeriod(signals chan os.Signal,
errC chan error,
shutdownC chan struct{},
gracePeriod time.Duration,
) {
logger.Infof("Initiating graceful shutdown...") logger.Infof("Initiating graceful shutdown...")
// Unregister signal handler early, so the client can send a second SIGTERM/SIGINT // Unregister signal handler early, so the client can send a second SIGTERM/SIGINT
// to force shutdown cloudflared // to force shutdown cloudflared
@ -47,9 +76,4 @@ func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSig
case <-errC: case <-errC:
} }
close(shutdownC) close(shutdownC)
case <-shutdownC:
close(graceShutdownSignal)
}
return nil
} }

View File

@ -89,6 +89,15 @@ func TestWaitForSignalWithGraceShutdown(t *testing.T) {
testChannelClosed(t, shutdownC) testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC) testChannelClosed(t, graceshutdownC)
// graceshutdownC closed, shutdownC should also be closed and no error
errC = make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
close(graceshutdownC)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.NoError(t, err)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
// Test handling SIGTERM & SIGINT // Test handling SIGTERM & SIGINT
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {

View File

@ -31,7 +31,7 @@ const (
serviceConfigFailureActionsFlag = 4 serviceConfigFailureActionsFlag = 4
) )
func runApp(app *cli.App, shutdownC chan struct{}) { func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{ app.Commands = append(app.Commands, &cli.Command{
Name: "service", Name: "service",
Usage: "Manages the Argo Tunnel Windows service", Usage: "Manages the Argo Tunnel Windows service",
@ -70,7 +70,7 @@ func runApp(app *cli.App, shutdownC chan struct{}) {
// Run executes service name by calling windowsService which is a Handler // Run executes service name by calling windowsService which is a Handler
// interface that implements Execute method. // interface that implements Execute method.
// It will set service status to stop after Execute returns // It will set service status to stop after Execute returns
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC}) err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC, graceShutdownC: graceShutdownC})
if err != nil { if err != nil {
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err)) elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
return return
@ -82,6 +82,7 @@ type windowsService struct {
app *cli.App app *cli.App
elog *eventlog.Log elog *eventlog.Log
shutdownC chan struct{} shutdownC chan struct{}
graceShutdownC chan struct{}
} }
// called by the package code at the start of the service // called by the package code at the start of the service
@ -103,7 +104,7 @@ func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, stat
statusChan <- c.CurrentStatus statusChan <- c.CurrentStatus
case svc.Stop: case svc.Stop:
s.elog.Info(1, "received stop control request") s.elog.Info(1, "received stop control request")
close(s.shutdownC) close(s.graceShutdownC)
statusChan <- svc.Status{State: svc.StopPending} statusChan <- svc.Status{State: svc.StopPending}
case svc.Shutdown: case svc.Shutdown:
s.elog.Info(1, "received shutdown control request") s.elog.Info(1, "received shutdown control request")

View File

@ -1,6 +1,7 @@
package h2mux package h2mux
import ( import (
"context"
"io" "io"
"strings" "strings"
"sync" "sync"
@ -9,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"golang.org/x/sync/errgroup"
) )
const ( const (
@ -256,31 +258,31 @@ func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.
return nil return nil
} }
func (m *Muxer) Serve() error { func (m *Muxer) Serve(ctx context.Context) error {
logger := m.config.Logger.WithField("name", m.config.Name) logger := m.config.Logger.WithField("name", m.config.Name)
errChan := make(chan error) errGroup, _ := errgroup.WithContext(ctx)
go func() { errGroup.Go(func() error {
errChan <- m.muxReader.run(logger) err := m.muxReader.run(logger)
m.explicitShutdown.Fuse(false) m.explicitShutdown.Fuse(false)
m.r.Close() m.r.Close()
m.abort() m.abort()
}() return err
go func() { })
errChan <- m.muxWriter.run(logger)
errGroup.Go(func() error {
err := m.muxWriter.run(logger)
m.explicitShutdown.Fuse(false) m.explicitShutdown.Fuse(false)
m.w.Close() m.w.Close()
m.abort() m.abort()
}() return err
go func() { })
errChan <- m.muxMetricsUpdater.run(logger)
}() errGroup.Go(func() error {
err := <-errChan err := m.muxMetricsUpdater.run(logger)
go func() { return err
// discard errors as other handler and muxMetricsUpdater close })
<-errChan
<-errChan err := errGroup.Wait()
close(errChan)
}()
if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) { if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
return err return err
} }

View File

@ -2,6 +2,7 @@ package h2mux
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -78,11 +79,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) {
} }
func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) { func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
ctx := context.Background()
p.Handshake(t) p.Handshake(t)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
err := p.EdgeMux.Serve() err := p.EdgeMux.Serve(ctx)
if err != nil && err != io.EOF && err != io.ErrClosedPipe { if err != nil && err != io.EOF && err != io.ErrClosedPipe {
t.Errorf("error in edge muxer Serve(): %s", err) t.Errorf("error in edge muxer Serve(): %s", err)
} }
@ -90,7 +92,7 @@ func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
wg.Done() wg.Done()
}() }()
go func() { go func() {
err := p.OriginMux.Serve() err := p.OriginMux.Serve(ctx)
if err != nil && err != io.EOF && err != io.ErrClosedPipe { if err != nil && err != io.EOF && err != io.ErrClosedPipe {
t.Errorf("error in origin muxer Serve(): %s", err) t.Errorf("error in origin muxer Serve(): %s", err)
} }

View File

@ -360,8 +360,6 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
t.concurrentRequestsLock.Lock() t.concurrentRequestsLock.Lock()
if _, ok := t.concurrentRequests[connectionID]; ok { if _, ok := t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID] -= 1 t.concurrentRequests[connectionID] -= 1
} else {
logger.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
} }
t.concurrentRequestsLock.Unlock() t.concurrentRequestsLock.Unlock()

View File

@ -51,6 +51,7 @@ func NewSupervisor(config *TunnelConfig) *Supervisor {
} }
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error { func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal); err != nil { if err := s.initialize(ctx, connectedSignal); err != nil {
return err return err
} }
@ -119,6 +120,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
} }
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
logger := s.config.Logger
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
if err != nil { if err != nil {
logger.Infof("ResolveEdgeIPs err") logger.Infof("ResolveEdgeIPs err")

View File

@ -9,10 +9,10 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
@ -27,8 +27,6 @@ import (
rpc "zombiezen.com/go/capnproto2/rpc" rpc "zombiezen.com/go/capnproto2/rpc"
) )
var logger *logrus.Logger
const ( const (
dialTimeout = 15 * time.Second dialTimeout = 15 * time.Second
@ -76,12 +74,28 @@ func (e dupConnRegisterTunnelError) Error() string {
return "already connected to this server" return "already connected to this server"
} }
type printableRegisterTunnelError struct { type muxerShutdownError struct{}
func (e muxerShutdownError) Error() string {
return "muxer shutdown"
}
// RegisterTunnel error from server
type serverRegisterTunnelError struct {
cause error cause error
permanent bool permanent bool
} }
func (e printableRegisterTunnelError) Error() string { func (e serverRegisterTunnelError) Error() string {
return e.cause.Error()
}
// RegisterTunnel error from client
type clientRegisterTunnelError struct {
cause error
}
func (e clientRegisterTunnelError) Error() string {
return e.cause.Error() return e.cause.Error()
} }
@ -105,7 +119,6 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
} }
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error { func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
logger = config.Logger
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go func() { go func() {
<-shutdownC <-shutdownC
@ -129,6 +142,7 @@ func ServeTunnelLoop(ctx context.Context,
connectionID uint8, connectionID uint8,
connectedSignal chan struct{}, connectedSignal chan struct{},
) error { ) error {
logger := config.Logger
config.Metrics.incrementHaConnections() config.Metrics.incrementHaConnections()
defer config.Metrics.decrementHaConnections() defer config.Metrics.decrementHaConnections()
backoff := BackoffHandler{MaxRetries: config.Retries} backoff := BackoffHandler{MaxRetries: config.Retries}
@ -162,8 +176,6 @@ func ServeTunnel(
connectedFuse *h2mux.BooleanFuse, connectedFuse *h2mux.BooleanFuse,
backoff *BackoffHandler, backoff *BackoffHandler,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
var wg sync.WaitGroup
wg.Add(2)
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -175,8 +187,16 @@ func ServeTunnel(
recoverable = true recoverable = true
} }
}() }()
connectionTag := uint8ToString(connectionID)
logger := config.Logger.WithField("connectionID", connectionTag)
// additional tags to send other than hostname which is set in cloudflared main package
tags := make(map[string]string)
tags["ha"] = connectionTag
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID) handler, originLocalIP, err := NewTunnelHandler(ctx, logger, config, addr.String(), connectionID)
if err != nil { if err != nil {
errLog := logger.WithError(err) errLog := logger.WithError(err)
switch err.(type) { switch err.(type) {
@ -190,59 +210,69 @@ func ServeTunnel(
} }
return err, true return err, true
} }
serveCtx, serveCancel := context.WithCancel(ctx)
registerErrC := make(chan error, 1) errGroup, serveCtx := errgroup.WithContext(ctx)
go func() {
defer wg.Done() errGroup.Go(func() error {
err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP) err := RegisterTunnel(serveCtx, logger, handler.muxer, config, connectionID, originLocalIP)
if err == nil { if err == nil {
connectedFuse.Fuse(true) connectedFuse.Fuse(true)
backoff.SetGracePeriod() backoff.SetGracePeriod()
} else {
serveCancel()
} }
registerErrC <- err return err
}() })
errGroup.Go(func() error {
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq) updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
go func() {
defer wg.Done()
connectionTag := uint8ToString(connectionID)
for { for {
select { select {
case <-serveCtx.Done(): case <-serveCtx.Done():
// UnregisterTunnel blocks until the RPC call returns // UnregisterTunnel blocks until the RPC call returns
UnregisterTunnel(handler.muxer, config.GracePeriod) err := UnregisterTunnel(logger, handler.muxer, config.GracePeriod)
handler.muxer.Shutdown() handler.muxer.Shutdown()
return return err
case <-updateMetricsTickC: case <-updateMetricsTickC:
handler.UpdateMetrics(connectionTag) handler.UpdateMetrics(connectionTag)
} }
} }
}() })
err = handler.muxer.Serve() errGroup.Go(func() error {
serveCancel() // All routines should stop when muxer finish serving. When muxer is shutdown
registerErr := <-registerErrC // gracefully, it doesn't return an error, so we need to return errMuxerShutdown
wg.Wait() // here to notify other routines to stop
err := handler.muxer.Serve(serveCtx);
if err == nil {
return muxerShutdownError{}
}
return err
})
err = errGroup.Wait()
if err != nil { if err != nil {
logger.WithError(err).Error("Tunnel error") switch castedErr := err.(type) {
return err, true case dupConnRegisterTunnelError:
}
if registerErr != nil {
// Don't retry on errors like entitlement failure or version too old
if e, ok := registerErr.(printableRegisterTunnelError); ok {
logger.Error(e)
return e.cause, !e.permanent
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
logger.Info("Already connected to this server, selecting a different one") logger.Info("Already connected to this server, selecting a different one")
return e, true return err, true
} case serverRegisterTunnelError:
// Only log errors to Sentry that may have been caused by the client side, to reduce dupes logger.WithError(castedErr.cause).Error("Register tunnel error from server side")
raven.CaptureError(registerErr, nil) // Don't send registration error return from server to Sentry. They are
logger.Error("Cannot register") // logged on server side
return castedErr.cause, !castedErr.permanent
case clientRegisterTunnelError:
logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
raven.CaptureErrorAndWait(castedErr.cause, tags)
return err, true
case muxerShutdownError:
logger.Infof("Muxer shutdown")
return err, true
default:
logger.WithError(err).Error("Serve tunnel error")
raven.CaptureErrorAndWait(err, tags)
return err, true return err, true
} }
return nil, false }
return nil, true
} }
func IsRPCStreamResponse(headers []h2mux.Header) bool { func IsRPCStreamResponse(headers []h2mux.Header) bool {
@ -255,8 +285,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool {
return true return true
} }
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error { func RegisterTunnel(ctx context.Context, logger *logrus.Entry, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
logger := logger.WithField("subsystem", "rpc")
logger.Debug("initiating RPC stream to register") logger.Debug("initiating RPC stream to register")
stream, err := muxer.OpenStream([]h2mux.Header{ stream, err := muxer.OpenStream([]h2mux.Header{
{Name: ":method", Value: "RPC"}, {Name: ":method", Value: "RPC"},
@ -265,16 +294,14 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
}, nil) }, nil)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
raven.CaptureError(err, nil) return clientRegisterTunnelError{cause: err}
return err
} }
if !IsRPCStreamResponse(stream.Headers) { if !IsRPCStreamResponse(stream.Headers) {
// stream response error // stream response error
raven.CaptureError(err, nil) return clientRegisterTunnelError{cause: err}
return err
} }
conn := rpc.NewConn( conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
) )
defer conn.Close() defer conn.Close()
@ -293,7 +320,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
LogServerInfo(logger, serverInfoPromise.Result(), connectionID, config.Metrics) LogServerInfo(logger, serverInfoPromise.Result(), connectionID, config.Metrics)
if err != nil { if err != nil {
// RegisterTunnel RPC failure // RegisterTunnel RPC failure
return err return clientRegisterTunnelError{cause: err}
} }
for _, logLine := range registration.LogLines { for _, logLine := range registration.LogLines {
logger.Info(logLine) logger.Info(logLine)
@ -301,7 +328,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
if registration.Err == DuplicateConnectionError { if registration.Err == DuplicateConnectionError {
return dupConnRegisterTunnelError{} return dupConnRegisterTunnelError{}
} else if registration.Err != "" { } else if registration.Err != "" {
return printableRegisterTunnelError{ return serverRegisterTunnelError{
cause: fmt.Errorf("Server error: %s", registration.Err), cause: fmt.Errorf("Server error: %s", registration.Err),
permanent: registration.PermanentFailure, permanent: registration.PermanentFailure,
} }
@ -310,8 +337,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
return nil return nil
} }
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration) error { func UnregisterTunnel(logger *logrus.Entry, muxer *h2mux.Muxer, gracePeriod time.Duration) error {
logger := logger.WithField("subsystem", "rpc")
logger.Debug("initiating RPC stream to unregister") logger.Debug("initiating RPC stream to unregister")
stream, err := muxer.OpenStream([]h2mux.Header{ stream, err := muxer.OpenStream([]h2mux.Header{
{Name: ":method", Value: "RPC"}, {Name: ":method", Value: "RPC"},
@ -320,17 +346,15 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration) error {
}, nil) }, nil)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
raven.CaptureError(err, nil)
return err return err
} }
if !IsRPCStreamResponse(stream.Headers) { if !IsRPCStreamResponse(stream.Headers) {
// stream response error // stream response error
raven.CaptureError(err, nil)
return err return err
} }
ctx := context.Background() ctx := context.Background()
conn := rpc.NewConn( conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
) )
defer conn.Close() defer conn.Close()
@ -408,12 +432,18 @@ type TunnelHandler struct {
metrics *TunnelMetrics metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string connectionID string
logger *logrus.Entry
} }
var dialer = net.Dialer{DualStack: true} var dialer = net.Dialer{DualStack: true}
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, connectionID uint8) (*TunnelHandler, string, error) { func NewTunnelHandler(ctx context.Context,
logger *logrus.Entry,
config *TunnelConfig,
addr string,
connectionID uint8,
) (*TunnelHandler, string, error) {
originURL, err := validation.ValidateUrl(config.OriginUrl) originURL, err := validation.ValidateUrl(config.OriginUrl)
if err != nil { if err != nil {
return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL) return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL)
@ -425,6 +455,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co
tags: config.Tags, tags: config.Tags,
metrics: config.Metrics, metrics: config.Metrics,
connectionID: uint8ToString(connectionID), connectionID: uint8ToString(connectionID),
logger: logger,
} }
if h.httpClient == nil { if h.httpClient == nil {
h.httpClient = http.DefaultTransport h.httpClient = http.DefaultTransport
@ -471,11 +502,11 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
h.metrics.incrementRequests(h.connectionID) h.metrics.incrementRequests(h.connectionID)
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil { if err != nil {
logger.WithError(err).Panic("Unexpected error from http.NewRequest") h.logger.WithError(err).Panic("Unexpected error from http.NewRequest")
} }
err = H2RequestHeadersToH1Request(stream.Headers, req) err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil { if err != nil {
logger.WithError(err).Error("invalid request received") h.logger.WithError(err).Error("invalid request received")
} }
h.AppendTagHeaders(req) h.AppendTagHeaders(req)
cfRay := FindCfRayHeader(req) cfRay := FindCfRayHeader(req)
@ -510,7 +541,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
} }
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
logger.WithError(err).Error("HTTP request error") h.logger.WithError(err).Error("HTTP request error")
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
stream.Write([]byte("502 Bad Gateway")) stream.Write([]byte("502 Bad Gateway"))
h.metrics.incrementResponses(h.connectionID, "502") h.metrics.incrementResponses(h.connectionID, "502")
@ -518,20 +549,20 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) { func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
if cfRay != "" { if cfRay != "" {
logger.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto) h.logger.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
} else { } else {
logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto) h.logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
} }
logger.Debugf("Request Headers %+v", req.Header) h.logger.Debugf("Request Headers %+v", req.Header)
} }
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) { func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
if cfRay != "" { if cfRay != "" {
logger.WithField("CF-RAY", cfRay).Infof("%s", r.Status) h.logger.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
} else { } else {
logger.Infof("%s", r.Status) h.logger.Infof("%s", r.Status)
} }
logger.Debugf("Response Headers %+v", r.Header) h.logger.Debugf("Response Headers %+v", r.Header)
} }
func (h *TunnelHandler) UpdateMetrics(connectionID string) { func (h *TunnelHandler) UpdateMetrics(connectionID string) {