Release Argo Tunnel Client 2018.5.7
This commit is contained in:
parent
7acd1b1fc8
commit
200ea2bfc6
|
@ -8,6 +8,6 @@ import (
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func runApp(app *cli.App, shutdownC chan struct{}) {
|
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
|
||||||
app.Run(os.Args)
|
app.Run(os.Args)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func runApp(app *cli.App, shutdownC chan struct{}) {
|
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
|
||||||
app.Commands = append(app.Commands, &cli.Command{
|
app.Commands = append(app.Commands, &cli.Command{
|
||||||
Name: "service",
|
Name: "service",
|
||||||
Usage: "Manages the Argo Tunnel system service",
|
Usage: "Manages the Argo Tunnel system service",
|
||||||
|
|
|
@ -15,7 +15,7 @@ const (
|
||||||
launchdIdentifier = "com.cloudflare.cloudflared"
|
launchdIdentifier = "com.cloudflare.cloudflared"
|
||||||
)
|
)
|
||||||
|
|
||||||
func runApp(app *cli.App, shutdownC chan struct{}) {
|
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
|
||||||
app.Commands = append(app.Commands, &cli.Command{
|
app.Commands = append(app.Commands, &cli.Command{
|
||||||
Name: "service",
|
Name: "service",
|
||||||
Usage: "Manages the Argo Tunnel launch agent",
|
Usage: "Manages the Argo Tunnel launch agent",
|
||||||
|
|
|
@ -40,9 +40,12 @@ func main() {
|
||||||
raven.SetDSN(sentryDSN)
|
raven.SetDSN(sentryDSN)
|
||||||
raven.SetRelease(Version)
|
raven.SetRelease(Version)
|
||||||
|
|
||||||
// Shutdown channel used by the app. When closed, app must terminate.
|
// Force shutdown channel used by the app. When closed, app must terminate.
|
||||||
// May be closed by the Windows service runner.
|
// Windows service manager closes this channel when it receives shutdown command.
|
||||||
shutdownC := make(chan struct{})
|
shutdownC := make(chan struct{})
|
||||||
|
// Graceful shutdown channel used by the app. When closed, app must terminate.
|
||||||
|
// Windows service manager closes this channel when it receives stop command.
|
||||||
|
graceShutdownC := make(chan struct{})
|
||||||
|
|
||||||
app := &cli.App{}
|
app := &cli.App{}
|
||||||
app.Name = "cloudflared"
|
app.Name = "cloudflared"
|
||||||
|
@ -280,7 +283,7 @@ func main() {
|
||||||
tags := make(map[string]string)
|
tags := make(map[string]string)
|
||||||
tags["hostname"] = c.String("hostname")
|
tags["hostname"] = c.String("hostname")
|
||||||
raven.SetTagsContext(tags)
|
raven.SetTagsContext(tags)
|
||||||
raven.CapturePanicAndWait(func() { err = startServer(c, shutdownC) }, nil)
|
raven.CapturePanicAndWait(func() { err = startServer(c, shutdownC, graceShutdownC) }, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
raven.CaptureErrorAndWait(err, nil)
|
raven.CaptureErrorAndWait(err, nil)
|
||||||
}
|
}
|
||||||
|
@ -374,16 +377,15 @@ func main() {
|
||||||
ArgsUsage: " ", // can't be the empty string or we get the default output
|
ArgsUsage: " ", // can't be the empty string or we get the default output
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
runApp(app, shutdownC)
|
runApp(app, shutdownC, graceShutdownC)
|
||||||
}
|
}
|
||||||
|
|
||||||
func startServer(c *cli.Context, shutdownC chan struct{}) error {
|
func startServer(c *cli.Context, shutdownC, graceShutdownC chan struct{}) error {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
listeners := gracenet.Net{}
|
listeners := gracenet.Net{}
|
||||||
errC := make(chan error)
|
errC := make(chan error)
|
||||||
connectedSignal := make(chan struct{})
|
connectedSignal := make(chan struct{})
|
||||||
dnsReadySignal := make(chan struct{})
|
dnsReadySignal := make(chan struct{})
|
||||||
graceShutdownSignal := make(chan struct{})
|
|
||||||
|
|
||||||
// check whether client provides enough flags or env variables. If not, print help.
|
// check whether client provides enough flags or env variables. If not, print help.
|
||||||
if ok := enoughOptionsSet(c); !ok {
|
if ok := enoughOptionsSet(c); !ok {
|
||||||
|
@ -457,7 +459,7 @@ func startServer(c *cli.Context, shutdownC chan struct{}) error {
|
||||||
if dnsProxyStandAlone(c) {
|
if dnsProxyStandAlone(c) {
|
||||||
close(connectedSignal)
|
close(connectedSignal)
|
||||||
// no grace period, handle SIGINT/SIGTERM immediately
|
// no grace period, handle SIGINT/SIGTERM immediately
|
||||||
return waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, 0)
|
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.IsSet("hello-world") {
|
if c.IsSet("hello-world") {
|
||||||
|
@ -483,23 +485,23 @@ func startServer(c *cli.Context, shutdownC chan struct{}) error {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownSignal, connectedSignal)
|
errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownC, connectedSignal)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, c.Duration("grace-period"))
|
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitToShutdown(wg *sync.WaitGroup,
|
func waitToShutdown(wg *sync.WaitGroup,
|
||||||
errC chan error,
|
errC chan error,
|
||||||
shutdownC, graceShutdownSignal chan struct{},
|
shutdownC, graceShutdownC chan struct{},
|
||||||
gracePeriod time.Duration,
|
gracePeriod time.Duration,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
if gracePeriod > 0 {
|
if gracePeriod > 0 {
|
||||||
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownSignal, gracePeriod)
|
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownC, gracePeriod)
|
||||||
} else {
|
} else {
|
||||||
err = waitForSignal(errC, shutdownC)
|
err = waitForSignal(errC, shutdownC)
|
||||||
close(graceShutdownSignal)
|
close(graceShutdownC)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -7,6 +7,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// waitForSignal notifies all routines to shutdownC immediately by closing the
|
||||||
|
// shutdownC when one of the routines in main exits, or when this process receives
|
||||||
|
// SIGTERM/SIGINT
|
||||||
func waitForSignal(errC chan error, shutdownC chan struct{}) error {
|
func waitForSignal(errC chan error, shutdownC chan struct{}) error {
|
||||||
signals := make(chan os.Signal, 10)
|
signals := make(chan os.Signal, 10)
|
||||||
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
@ -23,18 +26,44 @@ func waitForSignal(errC chan error, shutdownC chan struct{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSignal chan struct{}, gracePeriod time.Duration) error {
|
// waitForSignalWithGraceShutdown notifies all routines to shutdown immediately
|
||||||
|
// by closing the shutdownC when one of the routines in main exits.
|
||||||
|
// When this process recieves SIGTERM/SIGINT, it closes the graceShutdownC to
|
||||||
|
// notify certain routines to start graceful shutdown. When grace period is over,
|
||||||
|
// or when some routine exits, it notifies the rest of the routines to shutdown
|
||||||
|
// immediately by closing shutdownC.
|
||||||
|
// In the case of handling commands from Windows Service Manager, closing graceShutdownC
|
||||||
|
// initiate graceful shutdown.
|
||||||
|
func waitForSignalWithGraceShutdown(errC chan error,
|
||||||
|
shutdownC, graceShutdownC chan struct{},
|
||||||
|
gracePeriod time.Duration,
|
||||||
|
) error {
|
||||||
signals := make(chan os.Signal, 10)
|
signals := make(chan os.Signal, 10)
|
||||||
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
||||||
defer signal.Stop(signals)
|
defer signal.Stop(signals)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errC:
|
case err := <-errC:
|
||||||
close(graceShutdownSignal)
|
close(graceShutdownC)
|
||||||
close(shutdownC)
|
close(shutdownC)
|
||||||
return err
|
return err
|
||||||
case <-signals:
|
case <-signals:
|
||||||
close(graceShutdownSignal)
|
close(graceShutdownC)
|
||||||
|
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
|
||||||
|
case <-graceShutdownC:
|
||||||
|
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
|
||||||
|
case <-shutdownC:
|
||||||
|
close(graceShutdownC)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForGracePeriod(signals chan os.Signal,
|
||||||
|
errC chan error,
|
||||||
|
shutdownC chan struct{},
|
||||||
|
gracePeriod time.Duration,
|
||||||
|
) {
|
||||||
logger.Infof("Initiating graceful shutdown...")
|
logger.Infof("Initiating graceful shutdown...")
|
||||||
// Unregister signal handler early, so the client can send a second SIGTERM/SIGINT
|
// Unregister signal handler early, so the client can send a second SIGTERM/SIGINT
|
||||||
// to force shutdown cloudflared
|
// to force shutdown cloudflared
|
||||||
|
@ -47,9 +76,4 @@ func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSig
|
||||||
case <-errC:
|
case <-errC:
|
||||||
}
|
}
|
||||||
close(shutdownC)
|
close(shutdownC)
|
||||||
case <-shutdownC:
|
|
||||||
close(graceShutdownSignal)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,6 +89,15 @@ func TestWaitForSignalWithGraceShutdown(t *testing.T) {
|
||||||
testChannelClosed(t, shutdownC)
|
testChannelClosed(t, shutdownC)
|
||||||
testChannelClosed(t, graceshutdownC)
|
testChannelClosed(t, graceshutdownC)
|
||||||
|
|
||||||
|
// graceshutdownC closed, shutdownC should also be closed and no error
|
||||||
|
errC = make(chan error)
|
||||||
|
shutdownC = make(chan struct{})
|
||||||
|
graceshutdownC = make(chan struct{})
|
||||||
|
close(graceshutdownC)
|
||||||
|
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
testChannelClosed(t, shutdownC)
|
||||||
|
testChannelClosed(t, graceshutdownC)
|
||||||
|
|
||||||
// Test handling SIGTERM & SIGINT
|
// Test handling SIGTERM & SIGINT
|
||||||
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
|
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
|
||||||
|
|
|
@ -31,7 +31,7 @@ const (
|
||||||
serviceConfigFailureActionsFlag = 4
|
serviceConfigFailureActionsFlag = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
func runApp(app *cli.App, shutdownC chan struct{}) {
|
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
|
||||||
app.Commands = append(app.Commands, &cli.Command{
|
app.Commands = append(app.Commands, &cli.Command{
|
||||||
Name: "service",
|
Name: "service",
|
||||||
Usage: "Manages the Argo Tunnel Windows service",
|
Usage: "Manages the Argo Tunnel Windows service",
|
||||||
|
@ -70,7 +70,7 @@ func runApp(app *cli.App, shutdownC chan struct{}) {
|
||||||
// Run executes service name by calling windowsService which is a Handler
|
// Run executes service name by calling windowsService which is a Handler
|
||||||
// interface that implements Execute method.
|
// interface that implements Execute method.
|
||||||
// It will set service status to stop after Execute returns
|
// It will set service status to stop after Execute returns
|
||||||
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC})
|
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC, graceShutdownC: graceShutdownC})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
|
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
|
||||||
return
|
return
|
||||||
|
@ -82,6 +82,7 @@ type windowsService struct {
|
||||||
app *cli.App
|
app *cli.App
|
||||||
elog *eventlog.Log
|
elog *eventlog.Log
|
||||||
shutdownC chan struct{}
|
shutdownC chan struct{}
|
||||||
|
graceShutdownC chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// called by the package code at the start of the service
|
// called by the package code at the start of the service
|
||||||
|
@ -103,7 +104,7 @@ func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, stat
|
||||||
statusChan <- c.CurrentStatus
|
statusChan <- c.CurrentStatus
|
||||||
case svc.Stop:
|
case svc.Stop:
|
||||||
s.elog.Info(1, "received stop control request")
|
s.elog.Info(1, "received stop control request")
|
||||||
close(s.shutdownC)
|
close(s.graceShutdownC)
|
||||||
statusChan <- svc.Status{State: svc.StopPending}
|
statusChan <- svc.Status{State: svc.StopPending}
|
||||||
case svc.Shutdown:
|
case svc.Shutdown:
|
||||||
s.elog.Info(1, "received shutdown control request")
|
s.elog.Info(1, "received shutdown control request")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package h2mux
|
package h2mux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -256,31 +258,31 @@ func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Muxer) Serve() error {
|
func (m *Muxer) Serve(ctx context.Context) error {
|
||||||
logger := m.config.Logger.WithField("name", m.config.Name)
|
logger := m.config.Logger.WithField("name", m.config.Name)
|
||||||
errChan := make(chan error)
|
errGroup, _ := errgroup.WithContext(ctx)
|
||||||
go func() {
|
errGroup.Go(func() error {
|
||||||
errChan <- m.muxReader.run(logger)
|
err := m.muxReader.run(logger)
|
||||||
m.explicitShutdown.Fuse(false)
|
m.explicitShutdown.Fuse(false)
|
||||||
m.r.Close()
|
m.r.Close()
|
||||||
m.abort()
|
m.abort()
|
||||||
}()
|
return err
|
||||||
go func() {
|
})
|
||||||
errChan <- m.muxWriter.run(logger)
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
err := m.muxWriter.run(logger)
|
||||||
m.explicitShutdown.Fuse(false)
|
m.explicitShutdown.Fuse(false)
|
||||||
m.w.Close()
|
m.w.Close()
|
||||||
m.abort()
|
m.abort()
|
||||||
}()
|
return err
|
||||||
go func() {
|
})
|
||||||
errChan <- m.muxMetricsUpdater.run(logger)
|
|
||||||
}()
|
errGroup.Go(func() error {
|
||||||
err := <-errChan
|
err := m.muxMetricsUpdater.run(logger)
|
||||||
go func() {
|
return err
|
||||||
// discard errors as other handler and muxMetricsUpdater close
|
})
|
||||||
<-errChan
|
|
||||||
<-errChan
|
err := errGroup.Wait()
|
||||||
close(errChan)
|
|
||||||
}()
|
|
||||||
if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
|
if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package h2mux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -78,11 +79,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
|
func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
p.Handshake(t)
|
p.Handshake(t)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
err := p.EdgeMux.Serve()
|
err := p.EdgeMux.Serve(ctx)
|
||||||
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||||
t.Errorf("error in edge muxer Serve(): %s", err)
|
t.Errorf("error in edge muxer Serve(): %s", err)
|
||||||
}
|
}
|
||||||
|
@ -90,7 +92,7 @@ func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
err := p.OriginMux.Serve()
|
err := p.OriginMux.Serve(ctx)
|
||||||
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||||
t.Errorf("error in origin muxer Serve(): %s", err)
|
t.Errorf("error in origin muxer Serve(): %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -360,8 +360,6 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
|
||||||
t.concurrentRequestsLock.Lock()
|
t.concurrentRequestsLock.Lock()
|
||||||
if _, ok := t.concurrentRequests[connectionID]; ok {
|
if _, ok := t.concurrentRequests[connectionID]; ok {
|
||||||
t.concurrentRequests[connectionID] -= 1
|
t.concurrentRequests[connectionID] -= 1
|
||||||
} else {
|
|
||||||
logger.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
|
|
||||||
}
|
}
|
||||||
t.concurrentRequestsLock.Unlock()
|
t.concurrentRequestsLock.Unlock()
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,7 @@ func NewSupervisor(config *TunnelConfig) *Supervisor {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
|
||||||
|
logger := s.config.Logger
|
||||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -119,6 +120,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
||||||
|
logger := s.config.Logger
|
||||||
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Infof("ResolveEdgeIPs err")
|
logger.Infof("ResolveEdgeIPs err")
|
||||||
|
|
165
origin/tunnel.go
165
origin/tunnel.go
|
@ -9,10 +9,10 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
@ -27,8 +27,6 @@ import (
|
||||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger *logrus.Logger
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dialTimeout = 15 * time.Second
|
dialTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
@ -76,12 +74,28 @@ func (e dupConnRegisterTunnelError) Error() string {
|
||||||
return "already connected to this server"
|
return "already connected to this server"
|
||||||
}
|
}
|
||||||
|
|
||||||
type printableRegisterTunnelError struct {
|
type muxerShutdownError struct{}
|
||||||
|
|
||||||
|
func (e muxerShutdownError) Error() string {
|
||||||
|
return "muxer shutdown"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTunnel error from server
|
||||||
|
type serverRegisterTunnelError struct {
|
||||||
cause error
|
cause error
|
||||||
permanent bool
|
permanent bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e printableRegisterTunnelError) Error() string {
|
func (e serverRegisterTunnelError) Error() string {
|
||||||
|
return e.cause.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTunnel error from client
|
||||||
|
type clientRegisterTunnelError struct {
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e clientRegisterTunnelError) Error() string {
|
||||||
return e.cause.Error()
|
return e.cause.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,7 +119,6 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
|
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
|
||||||
logger = config.Logger
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
go func() {
|
go func() {
|
||||||
<-shutdownC
|
<-shutdownC
|
||||||
|
@ -129,6 +142,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
connectionID uint8,
|
connectionID uint8,
|
||||||
connectedSignal chan struct{},
|
connectedSignal chan struct{},
|
||||||
) error {
|
) error {
|
||||||
|
logger := config.Logger
|
||||||
config.Metrics.incrementHaConnections()
|
config.Metrics.incrementHaConnections()
|
||||||
defer config.Metrics.decrementHaConnections()
|
defer config.Metrics.decrementHaConnections()
|
||||||
backoff := BackoffHandler{MaxRetries: config.Retries}
|
backoff := BackoffHandler{MaxRetries: config.Retries}
|
||||||
|
@ -162,8 +176,6 @@ func ServeTunnel(
|
||||||
connectedFuse *h2mux.BooleanFuse,
|
connectedFuse *h2mux.BooleanFuse,
|
||||||
backoff *BackoffHandler,
|
backoff *BackoffHandler,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
|
@ -175,8 +187,16 @@ func ServeTunnel(
|
||||||
recoverable = true
|
recoverable = true
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
connectionTag := uint8ToString(connectionID)
|
||||||
|
logger := config.Logger.WithField("connectionID", connectionTag)
|
||||||
|
|
||||||
|
// additional tags to send other than hostname which is set in cloudflared main package
|
||||||
|
tags := make(map[string]string)
|
||||||
|
tags["ha"] = connectionTag
|
||||||
|
|
||||||
// Returns error from parsing the origin URL or handshake errors
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
|
handler, originLocalIP, err := NewTunnelHandler(ctx, logger, config, addr.String(), connectionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errLog := logger.WithError(err)
|
errLog := logger.WithError(err)
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
|
@ -190,59 +210,69 @@ func ServeTunnel(
|
||||||
}
|
}
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
serveCtx, serveCancel := context.WithCancel(ctx)
|
|
||||||
registerErrC := make(chan error, 1)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
errGroup.Go(func() error {
|
||||||
err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
|
err := RegisterTunnel(serveCtx, logger, handler.muxer, config, connectionID, originLocalIP)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
connectedFuse.Fuse(true)
|
connectedFuse.Fuse(true)
|
||||||
backoff.SetGracePeriod()
|
backoff.SetGracePeriod()
|
||||||
} else {
|
|
||||||
serveCancel()
|
|
||||||
}
|
}
|
||||||
registerErrC <- err
|
return err
|
||||||
}()
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
|
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
connectionTag := uint8ToString(connectionID)
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-serveCtx.Done():
|
case <-serveCtx.Done():
|
||||||
// UnregisterTunnel blocks until the RPC call returns
|
// UnregisterTunnel blocks until the RPC call returns
|
||||||
UnregisterTunnel(handler.muxer, config.GracePeriod)
|
err := UnregisterTunnel(logger, handler.muxer, config.GracePeriod)
|
||||||
handler.muxer.Shutdown()
|
handler.muxer.Shutdown()
|
||||||
return
|
return err
|
||||||
case <-updateMetricsTickC:
|
case <-updateMetricsTickC:
|
||||||
handler.UpdateMetrics(connectionTag)
|
handler.UpdateMetrics(connectionTag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
|
|
||||||
err = handler.muxer.Serve()
|
errGroup.Go(func() error {
|
||||||
serveCancel()
|
// All routines should stop when muxer finish serving. When muxer is shutdown
|
||||||
registerErr := <-registerErrC
|
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
||||||
wg.Wait()
|
// here to notify other routines to stop
|
||||||
|
err := handler.muxer.Serve(serveCtx);
|
||||||
|
if err == nil {
|
||||||
|
return muxerShutdownError{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
err = errGroup.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("Tunnel error")
|
switch castedErr := err.(type) {
|
||||||
return err, true
|
case dupConnRegisterTunnelError:
|
||||||
}
|
|
||||||
if registerErr != nil {
|
|
||||||
// Don't retry on errors like entitlement failure or version too old
|
|
||||||
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
|
||||||
logger.Error(e)
|
|
||||||
return e.cause, !e.permanent
|
|
||||||
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
|
|
||||||
logger.Info("Already connected to this server, selecting a different one")
|
logger.Info("Already connected to this server, selecting a different one")
|
||||||
return e, true
|
return err, true
|
||||||
}
|
case serverRegisterTunnelError:
|
||||||
// Only log errors to Sentry that may have been caused by the client side, to reduce dupes
|
logger.WithError(castedErr.cause).Error("Register tunnel error from server side")
|
||||||
raven.CaptureError(registerErr, nil)
|
// Don't send registration error return from server to Sentry. They are
|
||||||
logger.Error("Cannot register")
|
// logged on server side
|
||||||
|
return castedErr.cause, !castedErr.permanent
|
||||||
|
case clientRegisterTunnelError:
|
||||||
|
logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
|
||||||
|
raven.CaptureErrorAndWait(castedErr.cause, tags)
|
||||||
|
return err, true
|
||||||
|
case muxerShutdownError:
|
||||||
|
logger.Infof("Muxer shutdown")
|
||||||
|
return err, true
|
||||||
|
default:
|
||||||
|
logger.WithError(err).Error("Serve tunnel error")
|
||||||
|
raven.CaptureErrorAndWait(err, tags)
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
return nil, false
|
}
|
||||||
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
||||||
|
@ -255,8 +285,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
|
func RegisterTunnel(ctx context.Context, logger *logrus.Entry, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
|
||||||
logger := logger.WithField("subsystem", "rpc")
|
|
||||||
logger.Debug("initiating RPC stream to register")
|
logger.Debug("initiating RPC stream to register")
|
||||||
stream, err := muxer.OpenStream([]h2mux.Header{
|
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||||
{Name: ":method", Value: "RPC"},
|
{Name: ":method", Value: "RPC"},
|
||||||
|
@ -265,16 +294,14 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// RPC stream open error
|
// RPC stream open error
|
||||||
raven.CaptureError(err, nil)
|
return clientRegisterTunnelError{cause: err}
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
if !IsRPCStreamResponse(stream.Headers) {
|
if !IsRPCStreamResponse(stream.Headers) {
|
||||||
// stream response error
|
// stream response error
|
||||||
raven.CaptureError(err, nil)
|
return clientRegisterTunnelError{cause: err}
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
conn := rpc.NewConn(
|
conn := rpc.NewConn(
|
||||||
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
|
||||||
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
||||||
)
|
)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
@ -293,7 +320,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
|
||||||
LogServerInfo(logger, serverInfoPromise.Result(), connectionID, config.Metrics)
|
LogServerInfo(logger, serverInfoPromise.Result(), connectionID, config.Metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// RegisterTunnel RPC failure
|
// RegisterTunnel RPC failure
|
||||||
return err
|
return clientRegisterTunnelError{cause: err}
|
||||||
}
|
}
|
||||||
for _, logLine := range registration.LogLines {
|
for _, logLine := range registration.LogLines {
|
||||||
logger.Info(logLine)
|
logger.Info(logLine)
|
||||||
|
@ -301,7 +328,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
|
||||||
if registration.Err == DuplicateConnectionError {
|
if registration.Err == DuplicateConnectionError {
|
||||||
return dupConnRegisterTunnelError{}
|
return dupConnRegisterTunnelError{}
|
||||||
} else if registration.Err != "" {
|
} else if registration.Err != "" {
|
||||||
return printableRegisterTunnelError{
|
return serverRegisterTunnelError{
|
||||||
cause: fmt.Errorf("Server error: %s", registration.Err),
|
cause: fmt.Errorf("Server error: %s", registration.Err),
|
||||||
permanent: registration.PermanentFailure,
|
permanent: registration.PermanentFailure,
|
||||||
}
|
}
|
||||||
|
@ -310,8 +337,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration) error {
|
func UnregisterTunnel(logger *logrus.Entry, muxer *h2mux.Muxer, gracePeriod time.Duration) error {
|
||||||
logger := logger.WithField("subsystem", "rpc")
|
|
||||||
logger.Debug("initiating RPC stream to unregister")
|
logger.Debug("initiating RPC stream to unregister")
|
||||||
stream, err := muxer.OpenStream([]h2mux.Header{
|
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||||
{Name: ":method", Value: "RPC"},
|
{Name: ":method", Value: "RPC"},
|
||||||
|
@ -320,17 +346,15 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration) error {
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// RPC stream open error
|
// RPC stream open error
|
||||||
raven.CaptureError(err, nil)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !IsRPCStreamResponse(stream.Headers) {
|
if !IsRPCStreamResponse(stream.Headers) {
|
||||||
// stream response error
|
// stream response error
|
||||||
raven.CaptureError(err, nil)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn := rpc.NewConn(
|
conn := rpc.NewConn(
|
||||||
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
|
||||||
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
||||||
)
|
)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
@ -408,12 +432,18 @@ type TunnelHandler struct {
|
||||||
metrics *TunnelMetrics
|
metrics *TunnelMetrics
|
||||||
// connectionID is only used by metrics, and prometheus requires labels to be string
|
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||||
connectionID string
|
connectionID string
|
||||||
|
logger *logrus.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialer = net.Dialer{DualStack: true}
|
var dialer = net.Dialer{DualStack: true}
|
||||||
|
|
||||||
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
||||||
func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, connectionID uint8) (*TunnelHandler, string, error) {
|
func NewTunnelHandler(ctx context.Context,
|
||||||
|
logger *logrus.Entry,
|
||||||
|
config *TunnelConfig,
|
||||||
|
addr string,
|
||||||
|
connectionID uint8,
|
||||||
|
) (*TunnelHandler, string, error) {
|
||||||
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL)
|
return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL)
|
||||||
|
@ -425,6 +455,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co
|
||||||
tags: config.Tags,
|
tags: config.Tags,
|
||||||
metrics: config.Metrics,
|
metrics: config.Metrics,
|
||||||
connectionID: uint8ToString(connectionID),
|
connectionID: uint8ToString(connectionID),
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
if h.httpClient == nil {
|
if h.httpClient == nil {
|
||||||
h.httpClient = http.DefaultTransport
|
h.httpClient = http.DefaultTransport
|
||||||
|
@ -471,11 +502,11 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
h.metrics.incrementRequests(h.connectionID)
|
h.metrics.incrementRequests(h.connectionID)
|
||||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Panic("Unexpected error from http.NewRequest")
|
h.logger.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||||
}
|
}
|
||||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("invalid request received")
|
h.logger.WithError(err).Error("invalid request received")
|
||||||
}
|
}
|
||||||
h.AppendTagHeaders(req)
|
h.AppendTagHeaders(req)
|
||||||
cfRay := FindCfRayHeader(req)
|
cfRay := FindCfRayHeader(req)
|
||||||
|
@ -510,7 +541,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
|
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
|
||||||
logger.WithError(err).Error("HTTP request error")
|
h.logger.WithError(err).Error("HTTP request error")
|
||||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||||
stream.Write([]byte("502 Bad Gateway"))
|
stream.Write([]byte("502 Bad Gateway"))
|
||||||
h.metrics.incrementResponses(h.connectionID, "502")
|
h.metrics.incrementResponses(h.connectionID, "502")
|
||||||
|
@ -518,20 +549,20 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
|
||||||
|
|
||||||
func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
|
func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
|
||||||
if cfRay != "" {
|
if cfRay != "" {
|
||||||
logger.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
|
h.logger.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
|
h.logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
|
||||||
}
|
}
|
||||||
logger.Debugf("Request Headers %+v", req.Header)
|
h.logger.Debugf("Request Headers %+v", req.Header)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
|
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
|
||||||
if cfRay != "" {
|
if cfRay != "" {
|
||||||
logger.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
|
h.logger.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
|
||||||
} else {
|
} else {
|
||||||
logger.Infof("%s", r.Status)
|
h.logger.Infof("%s", r.Status)
|
||||||
}
|
}
|
||||||
logger.Debugf("Response Headers %+v", r.Header)
|
h.logger.Debugf("Response Headers %+v", r.Header)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) UpdateMetrics(connectionID string) {
|
func (h *TunnelHandler) UpdateMetrics(connectionID string) {
|
||||||
|
|
Loading…
Reference in New Issue