Merge branch 'master' into master
This commit is contained in:
commit
9042025902
|
@ -1,3 +1,19 @@
|
|||
2019.11.3
|
||||
- 2019-11-20 TUN-2562: Update Cloudflare Origin CA RSA root
|
||||
|
||||
2019.11.2
|
||||
- 2019-11-18 TUN-2567: AuthOutcome can be turned back into AuthResponse
|
||||
- 2019-11-18 TUN-2563: Exposes config_version metrics
|
||||
|
||||
2019.11.1
|
||||
- 2019-11-12 Add db-connect, a SQL over HTTPS server
|
||||
- 2019-11-12 TUN-2053: Add a /healthcheck endpoint to the metrics server
|
||||
- 2019-11-13 TUN-2178: public API to create new h2mux.MuxedStreamRequest
|
||||
- 2019-11-13 TUN-2490: respect original representation of HTTP request path
|
||||
- 2019-11-18 TUN-2547: TunnelRPC definitions for Authenticate flow
|
||||
- 2019-11-18 TUN-2551: TunnelRPC definitions for ReconnectTunnel flow
|
||||
- 2019-11-05 TUN-2506: Expose active streams metrics
|
||||
|
||||
2019.11.0
|
||||
- 2019-11-04 TUN-2502: Switch to go modules
|
||||
- 2019-11-04 TUN-2500: Don't send client registration errors to Sentry
|
||||
|
|
|
@ -977,6 +977,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
EnvVars: []string{"TUNNEL_INTENT"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "use-reconnect-token",
|
||||
Usage: "Test reestablishing connections with the new 'reconnect token' flow.",
|
||||
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: "dial-edge-timeout",
|
||||
Usage: "Maximum wait time to set up a connection with the edge",
|
||||
|
@ -1044,7 +1050,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Usage: "Absolute path of directory to save SSH host keys in",
|
||||
EnvVars: []string{"HOST_KEY_PATH"},
|
||||
Hidden: true,
|
||||
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -203,11 +203,14 @@ func prepareTunnelConfig(
|
|||
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")},
|
||||
}
|
||||
|
||||
dialContext := (&net.Dialer{
|
||||
dialer := &net.Dialer{
|
||||
Timeout: c.Duration("proxy-connect-timeout"),
|
||||
KeepAlive: c.Duration("proxy-tcp-keepalive"),
|
||||
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
|
||||
}).DialContext
|
||||
}
|
||||
if c.Bool("proxy-no-happy-eyeballs") {
|
||||
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
|
||||
}
|
||||
dialContext := dialer.DialContext
|
||||
|
||||
if c.IsSet("unix-socket") {
|
||||
unixSocket, err := config.ValidateUnixSocket(c)
|
||||
|
@ -272,6 +275,7 @@ func prepareTunnelConfig(
|
|||
TlsConfig: toEdgeTLSConfig,
|
||||
TransportLogger: transportLogger,
|
||||
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
|
||||
UseReconnectToken: c.Bool("use-reconnect-token"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,38 +2,26 @@ package connection
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
const (
|
||||
openStreamTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type dialError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e dialError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
id uuid.UUID
|
||||
muxer *h2mux.Muxer
|
||||
}
|
||||
|
||||
func newConnection(muxer *h2mux.Muxer, edgeIP *net.TCPAddr) (*Connection, error) {
|
||||
func newConnection(muxer *h2mux.Muxer) (*Connection, error) {
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -50,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Connect is used to establish connections with cloudflare's edge network
|
||||
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (pogs.ConnectResult, error) {
|
||||
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
rpcConn, err := c.newRPConn(openStreamCtx, logger)
|
||||
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) {
|
||||
tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot create new RPC connection")
|
||||
}
|
||||
defer rpcConn.Close()
|
||||
|
||||
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
|
||||
|
||||
defer tsClient.Close()
|
||||
return tsClient.Connect(ctx, parameters)
|
||||
}
|
||||
|
||||
func (c *Connection) Shutdown() {
|
||||
c.muxer.Shutdown()
|
||||
}
|
||||
|
||||
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
|
||||
stream, err := c.muxer.OpenRPCStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
|
||||
), nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DialEdge makes a TLS connection to a Cloudflare edge node
|
||||
func DialEdge(
|
||||
ctx context.Context,
|
||||
timeout time.Duration,
|
||||
tlsConfig *tls.Config,
|
||||
edgeTCPAddr *net.TCPAddr,
|
||||
) (net.Conn, error) {
|
||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
|
||||
defer dialCancel()
|
||||
|
||||
dialer := net.Dialer{}
|
||||
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String())
|
||||
if err != nil {
|
||||
return nil, newDialError(err, "DialContext error")
|
||||
}
|
||||
tlsEdgeConn := tls.Client(edgeConn, tlsConfig)
|
||||
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
|
||||
|
||||
if err = tlsEdgeConn.Handshake(); err != nil {
|
||||
return nil, newDialError(err, "Handshake with edge error")
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
tlsEdgeConn.SetDeadline(time.Time{})
|
||||
return tlsEdgeConn, nil
|
||||
}
|
||||
|
||||
// DialError is an error returned from DialEdge
|
||||
type DialError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func newDialError(err error, message string) error {
|
||||
return DialError{cause: errors.Wrap(err, message)}
|
||||
}
|
||||
|
||||
func (e DialError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (e DialError) Cause() error {
|
||||
return e.cause
|
||||
}
|
|
@ -4,19 +4,18 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -59,12 +58,12 @@ func newMetrics(namespace, subsystem string) *metrics {
|
|||
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
|
||||
type EdgeManagerConfigurable struct {
|
||||
TunnelHostnames []h2mux.TunnelHostname
|
||||
*pogs.EdgeConnectionConfig
|
||||
*tunnelpogs.EdgeConnectionConfig
|
||||
}
|
||||
|
||||
type CloudflaredConfig struct {
|
||||
CloudflaredID uuid.UUID
|
||||
Tags []pogs.Tag
|
||||
Tags []tunnelpogs.Tag
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
IntentLabel string
|
||||
}
|
||||
|
@ -127,13 +126,13 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab
|
|||
em.state.updateConfigurable(newConfigurable)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
|
||||
edgeIP := em.serviceDiscoverer.Addr()
|
||||
edgeConn, err := em.dialEdge(ctx, edgeIP)
|
||||
func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError {
|
||||
edgeTCPAddr := em.serviceDiscoverer.Addr()
|
||||
configurable := em.state.getConfigurable()
|
||||
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
|
||||
if err != nil {
|
||||
return retryConnection(fmt.Sprintf("dial edge error: %v", err))
|
||||
}
|
||||
configurable := em.state.getConfigurable()
|
||||
// Establish a muxed connection with the edge
|
||||
// Client mux handshake with agent server
|
||||
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||
|
@ -148,14 +147,14 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
|
|||
retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err))
|
||||
}
|
||||
|
||||
h2muxConn, err := newConnection(muxer, edgeIP)
|
||||
h2muxConn, err := newConnection(muxer)
|
||||
if err != nil {
|
||||
return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err))
|
||||
}
|
||||
|
||||
go em.serveConn(ctx, h2muxConn)
|
||||
|
||||
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
|
||||
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
|
||||
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
|
||||
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
|
||||
NumPreviousAttempts: 0,
|
||||
|
@ -196,28 +195,6 @@ func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
|
|||
em.state.closeConnection(conn)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) dialEdge(ctx context.Context, edgeIP *net.TCPAddr) (*tls.Conn, error) {
|
||||
timeout := em.state.getConfigurable().Timeout
|
||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
|
||||
defer dialCancel()
|
||||
|
||||
dialer := net.Dialer{DualStack: true}
|
||||
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
|
||||
if err != nil {
|
||||
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
|
||||
}
|
||||
tlsEdgeConn := tls.Client(edgeConn, em.tlsConfig)
|
||||
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
|
||||
|
||||
if err = tlsEdgeConn.Handshake(); err != nil {
|
||||
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
tlsEdgeConn.SetDeadline(time.Time{})
|
||||
return tlsEdgeConn, nil
|
||||
}
|
||||
|
||||
func (em *EdgeManager) noRetryMessage() string {
|
||||
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
|
||||
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
|
||||
|
@ -308,8 +285,8 @@ func (ems *edgeManagerState) getUserCredential() []byte {
|
|||
return ems.userCredential
|
||||
}
|
||||
|
||||
func retryConnection(cause string) *pogs.ConnectError {
|
||||
return &pogs.ConnectError{
|
||||
func retryConnection(cause string) *tunnelpogs.ConnectError {
|
||||
return &tunnelpogs.ConnectError{
|
||||
Cause: cause,
|
||||
RetryAfter: defaultRetryAfter,
|
||||
ShouldRetry: true,
|
||||
|
|
|
@ -4,15 +4,15 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// NewRPCClient creates and returns a new RPC client, which will communicate
|
||||
// using a stream on the given muxer
|
||||
func NewRPCClient(
|
||||
ctx context.Context,
|
||||
muxer *h2mux.Muxer,
|
||||
logger *logrus.Entry,
|
||||
openStreamTimeout time.Duration,
|
||||
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
|
||||
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||
defer openStreamCancel()
|
||||
stream, err := muxer.OpenRPCStream(openStreamCtx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !isRPCStreamResponse(stream.Headers) {
|
||||
stream.Close()
|
||||
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
|
||||
return
|
||||
}
|
||||
|
||||
conn := rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(logger),
|
||||
)
|
||||
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func isRPCStreamResponse(headers []h2mux.Header) bool {
|
||||
return len(headers) == 1 &&
|
||||
headers[0].Name == ":status" &&
|
||||
headers[0].Value == "200"
|
||||
}
|
|
@ -13,26 +13,28 @@ type activeStreamMap struct {
|
|||
sync.RWMutex
|
||||
// streams tracks open streams.
|
||||
streams map[uint32]*MuxedStream
|
||||
// streamsEmpty is a chan that should be closed when no more streams are open.
|
||||
streamsEmpty chan struct{}
|
||||
// nextStreamID is the next ID to use on our side of the connection.
|
||||
// This is odd for clients, even for servers.
|
||||
nextStreamID uint32
|
||||
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
|
||||
maxPeerStreamID uint32
|
||||
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
||||
activeStreams prometheus.Gauge
|
||||
|
||||
// ignoreNewStreams is true when the connection is being shut down. New streams
|
||||
// cannot be registered.
|
||||
ignoreNewStreams bool
|
||||
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
||||
activeStreams prometheus.Gauge
|
||||
// streamsEmpty is a chan that will be closed when no more streams are open.
|
||||
streamsEmptyChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
|
||||
m := &activeStreamMap{
|
||||
streams: make(map[uint32]*MuxedStream),
|
||||
streamsEmpty: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
activeStreams: activeStreams,
|
||||
streams: make(map[uint32]*MuxedStream),
|
||||
streamsEmptyChan: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
activeStreams: activeStreams,
|
||||
}
|
||||
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
|
||||
if !useClientStreamNumbers {
|
||||
|
@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga
|
|||
return m
|
||||
}
|
||||
|
||||
func (m *activeStreamMap) notifyStreamsEmpty() {
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.streamsEmptyChan)
|
||||
})
|
||||
}
|
||||
|
||||
// Len returns the number of active streams.
|
||||
func (m *activeStreamMap) Len() int {
|
||||
m.RLock()
|
||||
|
@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) {
|
|||
delete(m.streams, streamID)
|
||||
m.activeStreams.Dec()
|
||||
}
|
||||
if len(m.streams) == 0 && m.streamsEmpty != nil {
|
||||
close(m.streamsEmpty)
|
||||
m.streamsEmpty = nil
|
||||
if len(m.streams) == 0 {
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown blocks new streams from being created. It returns a channel that receives an event
|
||||
// once the last stream has closed, or nil if a shutdown is in progress.
|
||||
func (m *activeStreamMap) Shutdown() <-chan struct{} {
|
||||
// Shutdown blocks new streams from being created.
|
||||
// It returns `done`, a channel that is closed once the last stream has closed
|
||||
// and `progress`, whether a shutdown was already in progress
|
||||
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if m.ignoreNewStreams {
|
||||
// already shutting down
|
||||
return nil
|
||||
return m.streamsEmptyChan, true
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
done := make(chan struct{})
|
||||
if len(m.streams) == 0 {
|
||||
// nothing to shut down
|
||||
close(done)
|
||||
return done
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
m.streamsEmpty = done
|
||||
return done
|
||||
return m.streamsEmptyChan, false
|
||||
}
|
||||
|
||||
// AcquireLocalID acquires a new stream ID for a stream you're opening.
|
||||
|
@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() {
|
|||
stream.Close()
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
const numStreams = 1000
|
||||
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
||||
|
||||
// Add all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
stream := &MuxedStream{streamID: uint32(streamID)}
|
||||
ok := m.Set(stream)
|
||||
assert.True(t, ok)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||
|
||||
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
|
||||
default:
|
||||
}
|
||||
assert.False(t, alreadyInProgress)
|
||||
|
||||
shutdownChan2, alreadyInProgress2 := m.Shutdown()
|
||||
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
|
||||
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
|
||||
|
||||
// Delete all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
m.Delete(uint32(streamID))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
|
||||
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
default:
|
||||
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
|
||||
}
|
||||
}
|
||||
|
||||
type noopBuffer struct {
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (t *noopBuffer) Reset() {}
|
||||
func (t *noopBuffer) Len() int { return 0 }
|
||||
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
|
||||
func (t *noopBuffer) Closed() bool { return t.isClosed }
|
||||
|
||||
type noopReadyList struct{}
|
||||
|
||||
func (_ *noopReadyList) Signal(streamID uint32) {}
|
||||
|
||||
func TestAbort(t *testing.T) {
|
||||
const numStreams = 1000
|
||||
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
||||
|
||||
var openedStreams sync.Map
|
||||
|
||||
// Add all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
stream := &MuxedStream{
|
||||
streamID: uint32(streamID),
|
||||
readBuffer: &noopBuffer{},
|
||||
writeBuffer: &noopBuffer{},
|
||||
readyList: &noopReadyList{},
|
||||
}
|
||||
ok := m.Set(stream)
|
||||
assert.True(t, ok)
|
||||
|
||||
openedStreams.Store(stream.streamID, stream)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||
|
||||
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
|
||||
default:
|
||||
}
|
||||
assert.False(t, alreadyInProgress)
|
||||
|
||||
m.Abort()
|
||||
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
|
||||
openedStreams.Range(func(key interface{}, value interface{}) bool {
|
||||
stream := value.(*MuxedStream)
|
||||
readBuffer := stream.readBuffer.(*noopBuffer)
|
||||
writeBuffer := stream.writeBuffer.(*noopBuffer)
|
||||
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
|
||||
})
|
||||
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
default:
|
||||
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
|
||||
}
|
||||
|
||||
// multiple aborts shouldn't cause any issues
|
||||
m.Abort()
|
||||
m.Abort()
|
||||
m.Abort()
|
||||
}
|
|
@ -542,7 +542,10 @@ func (w *h2DictWriter) Write(p []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (w *h2DictWriter) Close() error {
|
||||
return w.comp.Close()
|
||||
if w.comp != nil {
|
||||
return w.comp.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// From http2/hpack
|
||||
|
|
|
@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Shutdown is called to initiate the "happy path" of muxer termination.
|
||||
func (m *Muxer) Shutdown() {
|
||||
// It blocks new streams from being created.
|
||||
// It returns a channel that is closed when the last stream has been closed.
|
||||
func (m *Muxer) Shutdown() <-chan struct{} {
|
||||
m.explicitShutdown.Fuse(true)
|
||||
m.muxReader.Shutdown()
|
||||
return m.muxReader.Shutdown()
|
||||
}
|
||||
|
||||
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
|
||||
|
@ -390,7 +392,7 @@ func isConnectionClosedError(err error) bool {
|
|||
// Called by proxy server and tunnel
|
||||
func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
|
||||
stream := m.NewStream(headers)
|
||||
if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream, body}); err != nil {
|
||||
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
||||
|
@ -401,7 +403,7 @@ func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader
|
|||
|
||||
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
|
||||
stream := m.NewStream(RPCHeaders())
|
||||
if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream: stream, body: nil}); err != nil {
|
||||
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
||||
|
|
|
@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
|
|||
DefaultWindowSize: (1 << 8) - 1,
|
||||
MaxWindowSize: (1 << 15) - 1,
|
||||
StreamWriteBufferMaxLen: 1024,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{
|
||||
|
@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
|
|||
DefaultWindowSize: (1 << 8) - 1,
|
||||
MaxWindowSize: (1 << 15) - 1,
|
||||
StreamWriteBufferMaxLen: 1024,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
|
@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
|
|||
Name: "origin",
|
||||
CompressionQuality: quality,
|
||||
Logger: log.NewEntry(log.New()),
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{
|
||||
|
@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
|
|||
Name: "edge",
|
||||
CompressionQuality: quality,
|
||||
Logger: log.NewEntry(log.New()),
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
|
|
|
@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
|
|||
Closed() bool
|
||||
}
|
||||
|
||||
// MuxedStreamDataSignaller is a write-only *ReadyList
|
||||
type MuxedStreamDataSignaller interface {
|
||||
// Non-blocking: call this when data is ready to be sent for the given stream ID.
|
||||
Signal(ID uint32)
|
||||
}
|
||||
|
||||
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
|
||||
type MuxedStream struct {
|
||||
streamID uint32
|
||||
|
@ -55,8 +61,8 @@ type MuxedStream struct {
|
|||
// This is the amount of bytes that are in the peer's receive window
|
||||
// (how much data we can send from this stream).
|
||||
sendWindow uint32
|
||||
// Reference to the muxer's readyList; signal this for stream data to be sent.
|
||||
readyList *ReadyList
|
||||
// The muxer's readyList
|
||||
readyList MuxedStreamDataSignaller
|
||||
// The headers that should be sent, and a flag so we only send them once.
|
||||
headersSent bool
|
||||
writeHeaders []Header
|
||||
|
@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool {
|
|||
return th != ""
|
||||
}
|
||||
|
||||
func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream {
|
||||
func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
|
||||
return &MuxedStream{
|
||||
responseHeadersReceived: make(chan struct{}),
|
||||
readBuffer: NewSharedBuffer(),
|
||||
|
|
|
@ -51,10 +51,12 @@ type MuxReader struct {
|
|||
dictionaries h2Dictionaries
|
||||
}
|
||||
|
||||
func (r *MuxReader) Shutdown() {
|
||||
done := r.streams.Shutdown()
|
||||
if done == nil {
|
||||
return
|
||||
// Shutdown blocks new streams from being created.
|
||||
// It returns a channel that is closed once the last stream has closed.
|
||||
func (r *MuxReader) Shutdown() <-chan struct{} {
|
||||
done, alreadyInProgress := r.streams.Shutdown()
|
||||
if alreadyInProgress {
|
||||
return done
|
||||
}
|
||||
r.sendGoAway(http2.ErrCodeNo)
|
||||
go func() {
|
||||
|
@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
|
|||
<-done
|
||||
r.r.Close()
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
||||
|
|
|
@ -54,6 +54,13 @@ type MuxedStreamRequest struct {
|
|||
body io.Reader
|
||||
}
|
||||
|
||||
func NewMuxedStreamRequest(stream *MuxedStream, body io.Reader) MuxedStreamRequest {
|
||||
return MuxedStreamRequest{
|
||||
stream: stream,
|
||||
body: body,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MuxedStreamRequest) flushBody() {
|
||||
io.Copy(r.stream, r.body)
|
||||
r.stream.CloseWrite()
|
||||
|
|
|
@ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
|
|||
}
|
||||
return b.BaseTime
|
||||
}
|
||||
|
||||
// Retries returns the number of retries consumed so far.
|
||||
func (b *BackoffHandler) Retries() int {
|
||||
return int(b.retries)
|
||||
}
|
||||
|
|
|
@ -2,16 +2,20 @@ package origin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/signal"
|
||||
|
||||
"github.com/google/uuid"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -21,11 +25,23 @@ const (
|
|||
resolveTTL = time.Hour
|
||||
// Interval between registering new tunnels
|
||||
registrationInterval = time.Second
|
||||
|
||||
subsystemRefreshAuth = "refresh_auth"
|
||||
// Maximum exponent for 'Authenticate' exponential backoff
|
||||
refreshAuthMaxBackoff = 10
|
||||
// Waiting time before retrying a failed 'Authenticate' connection
|
||||
refreshAuthRetryDuration = time.Second * 10
|
||||
)
|
||||
|
||||
var (
|
||||
errJWTUnset = errors.New("JWT unset")
|
||||
errEventDigestUnset = errors.New("event digest unset")
|
||||
)
|
||||
|
||||
type Supervisor struct {
|
||||
config *TunnelConfig
|
||||
edgeIPs []*net.TCPAddr
|
||||
cloudflaredUUID uuid.UUID
|
||||
config *TunnelConfig
|
||||
edgeIPs []*net.TCPAddr
|
||||
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
|
||||
nextUnusedEdgeIP int
|
||||
lastResolve time.Time
|
||||
|
@ -38,6 +54,12 @@ type Supervisor struct {
|
|||
nextConnectedSignal chan struct{}
|
||||
|
||||
logger *logrus.Entry
|
||||
|
||||
jwtLock *sync.RWMutex
|
||||
jwt []byte
|
||||
|
||||
eventDigestLock *sync.RWMutex
|
||||
eventDigest []byte
|
||||
}
|
||||
|
||||
type resolveResult struct {
|
||||
|
@ -50,18 +72,21 @@ type tunnelError struct {
|
|||
err error
|
||||
}
|
||||
|
||||
func NewSupervisor(config *TunnelConfig) *Supervisor {
|
||||
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
|
||||
return &Supervisor{
|
||||
cloudflaredUUID: u,
|
||||
config: config,
|
||||
tunnelErrors: make(chan tunnelError),
|
||||
tunnelsConnecting: map[int]chan struct{}{},
|
||||
logger: config.Logger.WithField("subsystem", "supervisor"),
|
||||
jwtLock: &sync.RWMutex{},
|
||||
eventDigestLock: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
|
||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||
logger := s.config.Logger
|
||||
if err := s.initialize(ctx, connectedSignal, u); err != nil {
|
||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||
return err
|
||||
}
|
||||
var tunnelsWaiting []int
|
||||
|
@ -69,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
|||
var backoffTimer <-chan time.Time
|
||||
tunnelsActive := s.config.HAConnections
|
||||
|
||||
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
||||
var refreshAuthBackoffTimer <-chan time.Time
|
||||
if s.config.UseReconnectToken {
|
||||
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
// Context cancelled
|
||||
|
@ -104,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
|||
case <-backoffTimer:
|
||||
backoffTimer = nil
|
||||
for _, index := range tunnelsWaiting {
|
||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
|
||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||
}
|
||||
tunnelsActive += len(tunnelsWaiting)
|
||||
tunnelsWaiting = nil
|
||||
// Time to call Authenticate
|
||||
case <-refreshAuthBackoffTimer:
|
||||
newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Authentication failed")
|
||||
// Permanent failure. Leave the `select` without setting the
|
||||
// channel to be non-null, so we'll never hit this case of the `select` again.
|
||||
continue
|
||||
}
|
||||
refreshAuthBackoffTimer = newTimer
|
||||
// Tunnel successfully connected
|
||||
case <-s.nextConnectedSignal:
|
||||
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
||||
|
@ -128,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
|
||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||
logger := s.logger
|
||||
|
||||
edgeIPs, err := s.resolveEdgeIPs()
|
||||
|
@ -145,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
|||
s.lastResolve = time.Now()
|
||||
// check entitlement and version too old error before attempting to register more tunnels
|
||||
s.nextUnusedEdgeIP = s.config.HAConnections
|
||||
go s.startFirstTunnel(ctx, connectedSignal, u)
|
||||
go s.startFirstTunnel(ctx, connectedSignal)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
<-s.tunnelErrors
|
||||
// Error can't be nil. A nil error signals that initialization succeed
|
||||
return fmt.Errorf("context was canceled")
|
||||
return ctx.Err()
|
||||
case tunnelError := <-s.tunnelErrors:
|
||||
return tunnelError.err
|
||||
case <-connectedSignal.Wait():
|
||||
|
@ -158,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
|||
// At least one successful connection, so start the rest
|
||||
for i := 1; i < s.config.HAConnections; i++ {
|
||||
ch := signal.New(make(chan struct{}))
|
||||
go s.startTunnel(ctx, i, ch, u)
|
||||
go s.startTunnel(ctx, i, ch)
|
||||
time.Sleep(registrationInterval)
|
||||
}
|
||||
return nil
|
||||
|
@ -166,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
|||
|
||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
|
||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
|
||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
|
||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||
defer func() {
|
||||
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
||||
}()
|
||||
|
@ -183,19 +224,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
|||
return
|
||||
// try the next address if it was a dialError(network problem) or
|
||||
// dupConnRegisterTunnelError
|
||||
case dialError, dupConnRegisterTunnelError:
|
||||
case connection.DialError, dupConnRegisterTunnelError:
|
||||
s.replaceEdgeIP(0)
|
||||
default:
|
||||
return
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
|
||||
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||
}
|
||||
}
|
||||
|
||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors.
|
||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
|
||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
|
||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
|
||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
|
||||
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||
}
|
||||
|
||||
|
@ -253,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
|
|||
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
||||
s.nextUnusedEdgeIP++
|
||||
}
|
||||
|
||||
func (s *Supervisor) ReconnectToken() ([]byte, error) {
|
||||
s.jwtLock.RLock()
|
||||
defer s.jwtLock.RUnlock()
|
||||
if s.jwt == nil {
|
||||
return nil, errJWTUnset
|
||||
}
|
||||
return s.jwt, nil
|
||||
}
|
||||
|
||||
func (s *Supervisor) SetReconnectToken(jwt []byte) {
|
||||
s.jwtLock.Lock()
|
||||
defer s.jwtLock.Unlock()
|
||||
s.jwt = jwt
|
||||
}
|
||||
|
||||
func (s *Supervisor) EventDigest() ([]byte, error) {
|
||||
s.eventDigestLock.RLock()
|
||||
defer s.eventDigestLock.RUnlock()
|
||||
if s.eventDigest == nil {
|
||||
return nil, errEventDigestUnset
|
||||
}
|
||||
return s.eventDigest, nil
|
||||
}
|
||||
|
||||
func (s *Supervisor) SetEventDigest(eventDigest []byte) {
|
||||
s.eventDigestLock.Lock()
|
||||
defer s.eventDigestLock.Unlock()
|
||||
s.eventDigest = eventDigest
|
||||
}
|
||||
|
||||
func (s *Supervisor) refreshAuth(
|
||||
ctx context.Context,
|
||||
backoff *BackoffHandler,
|
||||
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
|
||||
) (retryTimer <-chan time.Time, err error) {
|
||||
logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth)
|
||||
authOutcome, err := authenticate(ctx, backoff.Retries())
|
||||
if err != nil {
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||
logger.WithError(err).Warnf("Retrying in %v", duration)
|
||||
return backoff.BackoffTimer(), nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
// clear backoff timer
|
||||
backoff.SetGracePeriod()
|
||||
|
||||
switch outcome := authOutcome.(type) {
|
||||
case tunnelpogs.AuthSuccess:
|
||||
s.SetReconnectToken(outcome.JWT())
|
||||
return timeAfter(outcome.RefreshAfter()), nil
|
||||
case tunnelpogs.AuthUnknown:
|
||||
return timeAfter(outcome.RefreshAfter()), nil
|
||||
case tunnelpogs.AuthFail:
|
||||
return nil, outcome
|
||||
default:
|
||||
return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
|
||||
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
|
||||
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer edgeConn.Close()
|
||||
|
||||
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
|
||||
// This callback is invoked by h2mux when the edge initiates a stream.
|
||||
return nil // noop
|
||||
})
|
||||
muxerConfig := s.config.muxerConfig(handler)
|
||||
muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth)
|
||||
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go muxer.Serve(ctx)
|
||||
defer func() {
|
||||
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
|
||||
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
|
||||
<-muxer.Shutdown()
|
||||
}()
|
||||
|
||||
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tunnelServer.Close()
|
||||
|
||||
const arbitraryConnectionID = uint8(0)
|
||||
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
|
||||
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
||||
authResponse, err := tunnelServer.Authenticate(
|
||||
ctx,
|
||||
s.config.OriginCert,
|
||||
s.config.Hostname,
|
||||
registrationOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authResponse.Outcome(), nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
package origin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
func TestRefreshAuthBackoff(t *testing.T) {
|
||||
logger := logrus.New()
|
||||
logger.Level = logrus.ErrorLevel
|
||||
|
||||
var wait time.Duration
|
||||
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||
wait = d
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return nil, fmt.Errorf("authentication failure")
|
||||
}
|
||||
|
||||
// authentication failures should consume the backoff
|
||||
for i := uint(0); i < backoff.MaxRetries; i++ {
|
||||
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, retryChan)
|
||||
assert.Equal(t, (1<<i)*time.Second, wait)
|
||||
}
|
||||
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, retryChan)
|
||||
|
||||
// now we actually make contact with the remote server
|
||||
_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||
})
|
||||
|
||||
// The backoff timer should have been reset. To confirm this, make timeNow
|
||||
// return a value after the backoff timer's grace period
|
||||
timeNow = func() time.Time {
|
||||
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
|
||||
return time.Now().Add(expectedGracePeriod * 2)
|
||||
}
|
||||
_, ok := backoff.GetBackoffDuration(context.Background())
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestRefreshAuthSuccess(t *testing.T) {
|
||||
logger := logrus.New()
|
||||
logger.Level = logrus.ErrorLevel
|
||||
|
||||
var wait time.Duration
|
||||
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||
wait = d
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
|
||||
}
|
||||
|
||||
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, retryChan)
|
||||
assert.Equal(t, 19*time.Hour, wait)
|
||||
|
||||
token, err := s.ReconnectToken()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("jwt"), token)
|
||||
}
|
||||
|
||||
func TestRefreshAuthUnknown(t *testing.T) {
|
||||
logger := logrus.New()
|
||||
logger.Level = logrus.ErrorLevel
|
||||
|
||||
var wait time.Duration
|
||||
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||
wait = d
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||
}
|
||||
|
||||
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, retryChan)
|
||||
assert.Equal(t, 19*time.Hour, wait)
|
||||
|
||||
token, err := s.ReconnectToken()
|
||||
assert.Equal(t, errJWTUnset, err)
|
||||
assert.Nil(t, token)
|
||||
}
|
||||
|
||||
func TestRefreshAuthFail(t *testing.T) {
|
||||
logger := logrus.New()
|
||||
logger.Level = logrus.ErrorLevel
|
||||
|
||||
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||
backoff := &BackoffHandler{MaxRetries: 3}
|
||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
|
||||
}
|
||||
|
||||
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, retryChan)
|
||||
|
||||
token, err := s.ReconnectToken()
|
||||
assert.Equal(t, errJWTUnset, err)
|
||||
assert.Nil(t, token)
|
||||
}
|
159
origin/tunnel.go
159
origin/tunnel.go
|
@ -14,7 +14,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/signal"
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
|
@ -22,19 +29,12 @@ import (
|
|||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
_ "github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/errgroup"
|
||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 15 * time.Second
|
||||
openStreamTimeout = 30 * time.Second
|
||||
muxerTimeout = 5 * time.Second
|
||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||
DuplicateConnectionError = "EDUPCONN"
|
||||
|
@ -73,14 +73,9 @@ type TunnelConfig struct {
|
|||
WSGI bool
|
||||
// OriginUrl may not be used if a user specifies a unix socket.
|
||||
OriginUrl string
|
||||
}
|
||||
|
||||
type dialError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e dialError) Error() string {
|
||||
return e.cause.Error()
|
||||
// feature-flag to use new edge reconnect tokens
|
||||
UseReconnectToken bool
|
||||
}
|
||||
|
||||
type dupConnRegisterTunnelError struct{}
|
||||
|
@ -119,6 +114,18 @@ func (e clientRegisterTunnelError) Error() string {
|
|||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
|
||||
return h2mux.MuxerConfig{
|
||||
Timeout: muxerTimeout,
|
||||
Handler: handler,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: c.HeartbeatInterval,
|
||||
MaxHeartbeats: c.MaxHeartbeats,
|
||||
Logger: c.TransportLogger.WithFields(log.Fields{}),
|
||||
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
||||
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
||||
if c.HAConnections <= 1 && c.LBPool == "" {
|
||||
|
@ -141,7 +148,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
|||
}
|
||||
|
||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
|
||||
return NewSupervisor(config).Run(ctx, connectedSignal, cloudflaredID)
|
||||
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
|
||||
}
|
||||
|
||||
func ServeTunnelLoop(ctx context.Context,
|
||||
|
@ -213,11 +220,11 @@ func ServeTunnel(
|
|||
tags["ha"] = connectionTag
|
||||
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
|
||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID)
|
||||
if err != nil {
|
||||
errLog := logger.WithError(err)
|
||||
switch err.(type) {
|
||||
case dialError:
|
||||
case connection.DialError:
|
||||
errLog.Error("Unable to dial edge")
|
||||
case h2mux.MuxerHandshakeError:
|
||||
errLog.Error("Handshake failed with edge server")
|
||||
|
@ -295,16 +302,6 @@ func ServeTunnel(
|
|||
return nil, true
|
||||
}
|
||||
|
||||
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
||||
if len(headers) != 1 {
|
||||
return false
|
||||
}
|
||||
if headers[0].Name != ":status" || headers[0].Value != "200" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func RegisterTunnel(
|
||||
ctx context.Context,
|
||||
muxer *h2mux.Muxer,
|
||||
|
@ -315,43 +312,31 @@ func RegisterTunnel(
|
|||
uuid uuid.UUID,
|
||||
) error {
|
||||
config.TransportLogger.Debug("initiating RPC stream to register")
|
||||
stream, err := openStream(ctx, muxer)
|
||||
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
||||
}
|
||||
if !IsRPCStreamResponse(stream.Headers) {
|
||||
// stream response error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
||||
}
|
||||
conn := rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")),
|
||||
)
|
||||
defer conn.Close()
|
||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
||||
defer tunnelServer.Close()
|
||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
|
||||
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
return nil
|
||||
})
|
||||
registration, err := ts.RegisterTunnel(
|
||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
|
||||
registration := tunnelServer.RegisterTunnel(
|
||||
ctx,
|
||||
config.OriginCert,
|
||||
config.Hostname,
|
||||
config.RegistrationOptions(connectionID, originLocalIP, uuid),
|
||||
)
|
||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
|
||||
if err != nil {
|
||||
|
||||
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
||||
// RegisterTunnel RPC failure
|
||||
return newClientRegisterTunnelError(err, config.Metrics.regFail)
|
||||
}
|
||||
for _, logLine := range registration.LogLines {
|
||||
logger.Info(logLine)
|
||||
return processRegisterTunnelError(registrationErr, config.Metrics)
|
||||
}
|
||||
|
||||
if regErr := processRegisterTunnelError(registration.Err, registration.PermanentFailure, config.Metrics); regErr != nil {
|
||||
return regErr
|
||||
for _, logLine := range registration.LogLines {
|
||||
logger.Info(logLine)
|
||||
}
|
||||
|
||||
if registration.TunnelID != "" {
|
||||
|
@ -374,57 +359,34 @@ func RegisterTunnel(
|
|||
config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
|
||||
|
||||
logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
|
||||
config.Metrics.regSuccess.Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func processRegisterTunnelError(err string, permanentFailure bool, metrics *TunnelMetrics) error {
|
||||
if err == "" {
|
||||
metrics.regSuccess.Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
metrics.regFail.WithLabelValues(err).Inc()
|
||||
if err == DuplicateConnectionError {
|
||||
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
|
||||
if err.Error() == DuplicateConnectionError {
|
||||
metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
|
||||
return dupConnRegisterTunnelError{}
|
||||
}
|
||||
metrics.regFail.WithLabelValues("server_error").Inc()
|
||||
return serverRegisterTunnelError{
|
||||
cause: fmt.Errorf("Server error: %s", err),
|
||||
permanent: permanentFailure,
|
||||
cause: fmt.Errorf("Server error: %s", err.Error()),
|
||||
permanent: err.IsPermanent(),
|
||||
}
|
||||
}
|
||||
|
||||
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
|
||||
logger.Debug("initiating RPC stream to unregister")
|
||||
ctx := context.Background()
|
||||
stream, err := openStream(ctx, muxer)
|
||||
ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return err
|
||||
}
|
||||
if !IsRPCStreamResponse(stream.Headers) {
|
||||
// stream response error
|
||||
return err
|
||||
}
|
||||
conn := rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
||||
)
|
||||
defer conn.Close()
|
||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
||||
// gracePeriod is encoded in int64 using capnproto
|
||||
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
|
||||
}
|
||||
|
||||
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
|
||||
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||
defer cancel()
|
||||
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
{Name: ":scheme", Value: "capnp"},
|
||||
{Name: ":path", Value: "*"},
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func LogServerInfo(
|
||||
promise tunnelrpc.ServerInfo_Promise,
|
||||
connectionID uint8,
|
||||
|
@ -469,12 +431,12 @@ type TunnelHandler struct {
|
|||
noChunkedEncoding bool
|
||||
}
|
||||
|
||||
var dialer = net.Dialer{DualStack: true}
|
||||
var dialer = net.Dialer{}
|
||||
|
||||
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
||||
func NewTunnelHandler(ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
addr string,
|
||||
addr *net.TCPAddr,
|
||||
connectionID uint8,
|
||||
) (*TunnelHandler, string, error) {
|
||||
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
||||
|
@ -495,37 +457,18 @@ func NewTunnelHandler(ctx context.Context,
|
|||
if h.httpClient == nil {
|
||||
h.httpClient = http.DefaultTransport
|
||||
}
|
||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
||||
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
|
||||
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
|
||||
dialCancel()
|
||||
|
||||
edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
|
||||
if err != nil {
|
||||
return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")}
|
||||
return nil, "", err
|
||||
}
|
||||
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
|
||||
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
||||
err = edgeConn.Handshake()
|
||||
if err != nil {
|
||||
return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
edgeConn.SetDeadline(time.Time{})
|
||||
// Establish a muxed connection with the edge
|
||||
// Client mux handshake with agent server
|
||||
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
Handler: h,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: config.HeartbeatInterval,
|
||||
MaxHeartbeats: config.MaxHeartbeats,
|
||||
Logger: config.TransportLogger.WithFields(log.Fields{}),
|
||||
CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
|
||||
}, h.metrics.activeStreams)
|
||||
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
|
||||
if err != nil {
|
||||
return h, "", errors.New("TLS handshake error")
|
||||
return nil, "", errors.Wrap(err, "Handshake with edge error")
|
||||
}
|
||||
return h, edgeConn.LocalAddr().String(), err
|
||||
return h, edgeConn.LocalAddr().String(), nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
||||
|
|
|
@ -35,26 +35,49 @@ func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, erro
|
|||
return req, nil
|
||||
}
|
||||
|
||||
// H2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request
|
||||
// object. This includes conversion of the pseudo-headers into their closest
|
||||
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
|
||||
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||
for _, header := range h2 {
|
||||
switch header.Name {
|
||||
case ":method":
|
||||
h1.Method = header.Value
|
||||
case ":scheme":
|
||||
// noop - use the preexisting scheme from h1.URL
|
||||
case ":authority":
|
||||
// Otherwise the host header will be based on the origin URL
|
||||
h1.Host = header.Value
|
||||
case ":path":
|
||||
u, err := url.Parse(header.Value)
|
||||
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
|
||||
// However, this HTTP/1 Request object belongs to the Go standard library,
|
||||
// whose URL package makes some opinionated decisions about the encoding of
|
||||
// URL characters: see the docs of https://godoc.org/net/url#URL,
|
||||
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
|
||||
// which is always used when computing url.URL.String(), whether we'd like it or not.
|
||||
//
|
||||
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
|
||||
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
|
||||
// is treated differently when HTTP_PROXY is set!
|
||||
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
|
||||
//
|
||||
// This means we are subject to the behavior of net/url's function `shouldEscape`
|
||||
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
|
||||
|
||||
if header.Value == "*" {
|
||||
h1.URL.Path = "*"
|
||||
continue
|
||||
}
|
||||
// Due to the behavior of validation.ValidateUrl, h1.URL may
|
||||
// already have a partial value, with or without a trailing slash.
|
||||
base := h1.URL.String()
|
||||
base = strings.TrimRight(base, "/")
|
||||
// But we know :path begins with '/', because we handled '*' above - see RFC7540
|
||||
url, err := url.Parse(base + header.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unparseable path")
|
||||
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
|
||||
}
|
||||
resolved := h1.URL.ResolveReference(u)
|
||||
// prevent escaping base URL
|
||||
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
|
||||
return fmt.Errorf("invalid path %s", header.Value)
|
||||
}
|
||||
h1.URL = resolved
|
||||
h1.URL = url
|
||||
case "content-length":
|
||||
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,441 @@
|
|||
package streamhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{
|
||||
h2mux.Header{
|
||||
Name: "Mock header 1",
|
||||
Value: "Mock value 1",
|
||||
},
|
||||
h2mux.Header{
|
||||
Name: "Mock header 2",
|
||||
Value: "Mock value 2",
|
||||
},
|
||||
},
|
||||
request,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header 1": []string{"Mock value 1"},
|
||||
"Mock header 2": []string{"Mock value 2"},
|
||||
}, request.Header)
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{},
|
||||
request,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.Header{}, request.Header)
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{
|
||||
h2mux.Header{
|
||||
Name: ":path",
|
||||
Value: "//bad_path/",
|
||||
},
|
||||
h2mux.Header{
|
||||
Name: "Mock header",
|
||||
Value: "Mock value",
|
||||
},
|
||||
},
|
||||
request,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{
|
||||
h2mux.Header{
|
||||
Name: ":path",
|
||||
Value: "/?query=mock%20value",
|
||||
},
|
||||
h2mux.Header{
|
||||
Name: "Mock header",
|
||||
Value: "Mock value",
|
||||
},
|
||||
},
|
||||
request,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{
|
||||
h2mux.Header{
|
||||
Name: ":path",
|
||||
Value: "/mock%20path",
|
||||
},
|
||||
h2mux.Header{
|
||||
Name: "Mock header",
|
||||
Value: "Mock value",
|
||||
},
|
||||
},
|
||||
request,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
|
||||
type testCase struct {
|
||||
path string
|
||||
want string
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
path: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
path: "/",
|
||||
want: "/",
|
||||
},
|
||||
{
|
||||
path: "//",
|
||||
want: "//",
|
||||
},
|
||||
{
|
||||
path: "/test",
|
||||
want: "/test",
|
||||
},
|
||||
{
|
||||
path: "//test",
|
||||
want: "//test",
|
||||
},
|
||||
{
|
||||
// https://github.com/cloudflare/cloudflared/issues/81
|
||||
path: "//test/",
|
||||
want: "//test/",
|
||||
},
|
||||
{
|
||||
path: "/%2Ftest",
|
||||
want: "/%2Ftest",
|
||||
},
|
||||
{
|
||||
path: "//%20test",
|
||||
want: "//%20test",
|
||||
},
|
||||
{
|
||||
// https://github.com/cloudflare/cloudflared/issues/124
|
||||
path: "/test?get=somthing%20a",
|
||||
want: "/test?get=somthing%20a",
|
||||
},
|
||||
{
|
||||
path: "/%20",
|
||||
want: "/%20",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode ' '
|
||||
path: "/ ",
|
||||
want: "/%20",
|
||||
},
|
||||
{
|
||||
path: "/ a ",
|
||||
want: "/%20a%20",
|
||||
},
|
||||
{
|
||||
path: "/a%20b",
|
||||
want: "/a%20b",
|
||||
},
|
||||
{
|
||||
path: "/foo/bar;param?query#frag",
|
||||
want: "/foo/bar;param?query#frag",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
|
||||
path: "/a␠b",
|
||||
want: "/a%E2%90%A0b",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-ä",
|
||||
want: "/a-umlaut-%C3%A4",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-%C3%A4",
|
||||
want: "/a-umlaut-%C3%A4",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-%c3%a4",
|
||||
want: "/a-umlaut-%c3%a4",
|
||||
},
|
||||
{
|
||||
// here the second '#' is treated as part of the fragment
|
||||
path: "/a#b#c",
|
||||
want: "/a#b%23c",
|
||||
},
|
||||
{
|
||||
path: "/a#b␠c",
|
||||
want: "/a#b%E2%90%A0c",
|
||||
},
|
||||
{
|
||||
path: "/a#b%20c",
|
||||
want: "/a#b%20c",
|
||||
},
|
||||
{
|
||||
path: "/a#b c",
|
||||
want: "/a#b%20c",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode '\'
|
||||
path: "/\\",
|
||||
want: "/%5C",
|
||||
},
|
||||
{
|
||||
path: "/a\\",
|
||||
want: "/a%5C",
|
||||
},
|
||||
{
|
||||
path: "/a,b.c.",
|
||||
want: "/a,b.c.",
|
||||
},
|
||||
{
|
||||
path: "/.",
|
||||
want: "/.",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode '`'
|
||||
path: "/a`",
|
||||
want: "/a%60",
|
||||
},
|
||||
{
|
||||
path: "/a[0]",
|
||||
want: "/a[0]",
|
||||
},
|
||||
{
|
||||
path: "/?a[0]=5 &b[]=",
|
||||
want: "/?a[0]=5 &b[]=",
|
||||
},
|
||||
{
|
||||
path: "/?a=%22b%20%22",
|
||||
want: "/?a=%22b%20%22",
|
||||
},
|
||||
}
|
||||
|
||||
for index, testCase := range testCases {
|
||||
requestURL := "https://example.com"
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||
assert.NoError(t, err)
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{
|
||||
h2mux.Header{
|
||||
Name: ":path",
|
||||
Value: testCase.path,
|
||||
},
|
||||
h2mux.Header{
|
||||
Name: "Mock header",
|
||||
Value: "Mock value",
|
||||
},
|
||||
},
|
||||
request,
|
||||
)
|
||||
assert.NoError(t, headersConversionErr)
|
||||
|
||||
assert.Equal(t,
|
||||
http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
},
|
||||
request.Header)
|
||||
|
||||
assert.Equal(t,
|
||||
"https://example.com"+testCase.want,
|
||||
request.URL.String(),
|
||||
"Failed URL index: %v %#v", index, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
|
||||
config := &quick.Config{
|
||||
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
|
||||
},
|
||||
}
|
||||
|
||||
type testOrigin struct {
|
||||
url string
|
||||
|
||||
expectedScheme string
|
||||
expectedBasePath string
|
||||
}
|
||||
testOrigins := []testOrigin{
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/api",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/api/",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
{
|
||||
url: "https://origin.hostname.example.com:8080/api",
|
||||
expectedScheme: "https",
|
||||
expectedBasePath: "https://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
}
|
||||
|
||||
// use multiple schemes to demonstrate that the URL is based on the
|
||||
// origin's scheme, not the :scheme header
|
||||
for _, testScheme := range []string{"http", "https"} {
|
||||
for _, testOrigin := range testOrigins {
|
||||
assertion := func(testPath string) bool {
|
||||
const expectedMethod = "POST"
|
||||
const expectedHostname = "request.hostname.example.com"
|
||||
|
||||
h2 := []h2mux.Header{
|
||||
h2mux.Header{Name: ":method", Value: expectedMethod},
|
||||
h2mux.Header{Name: ":scheme", Value: testScheme},
|
||||
h2mux.Header{Name: ":authority", Value: expectedHostname},
|
||||
h2mux.Header{Name: ":path", Value: testPath},
|
||||
}
|
||||
h1, err := http.NewRequest("GET", testOrigin.url, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = H2RequestHeadersToH1Request(h2, h1)
|
||||
return assert.NoError(t, err) &&
|
||||
assert.Equal(t, expectedMethod, h1.Method) &&
|
||||
assert.Equal(t, expectedHostname, h1.Host) &&
|
||||
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
|
||||
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
|
||||
}
|
||||
err := quick.Check(assertion, config)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func randomASCIIPrintableChar(rand *rand.Rand) int {
|
||||
// smallest printable ASCII char is 32, largest is 126
|
||||
const startPrintable = 32
|
||||
const endPrintable = 127
|
||||
return startPrintable + rand.Intn(endPrintable-startPrintable)
|
||||
}
|
||||
|
||||
// randomASCIIText generates an ASCII string, some of whose characters may be
|
||||
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
|
||||
// between 1 and `maxLength`.
|
||||
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
|
||||
length := minLength + rand.Intn(maxLength)
|
||||
result := ""
|
||||
for i := 0; i < length; i++ {
|
||||
c := randomASCIIPrintableChar(rand)
|
||||
|
||||
// 1/4 chance of using percent encoding when not necessary
|
||||
if c == '%' || rand.Intn(4) == 0 {
|
||||
result += fmt.Sprintf("%%%02X", c)
|
||||
} else {
|
||||
result += string(c)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL path,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
regexp, err := regexp.Compile("[^/;,]*")
|
||||
require.NoError(t, err)
|
||||
return "/" + regexp.ReplaceAllStringFunc(text, url.PathEscape)
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL query,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Query(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
return "?" + strings.ReplaceAll(text, "#", "%23")
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
url, err := url.Parse("#" + text)
|
||||
require.NoError(t, err)
|
||||
return url.String()
|
||||
}
|
||||
|
||||
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
|
||||
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
|
||||
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
|
||||
result := randomHTTP1Path(t, rand, 1, 64)
|
||||
if rand.Intn(2) == 1 {
|
||||
result += randomHTTP1Query(t, rand, 1, 32)
|
||||
}
|
||||
if rand.Intn(2) == 1 {
|
||||
result += randomHTTP1Fragment(t, rand, 1, 16)
|
||||
}
|
||||
return result
|
||||
}
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -28,6 +29,27 @@ type Supervisor struct {
|
|||
useConfigResultChan chan<- *pogs.UseConfigurationResult
|
||||
state *state
|
||||
logger *logrus.Entry
|
||||
metrics metrics
|
||||
}
|
||||
|
||||
type metrics struct {
|
||||
configVersion prometheus.Gauge
|
||||
}
|
||||
|
||||
func newMetrics() metrics {
|
||||
configVersion := prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "supervisor",
|
||||
Subsystem: "supervisor",
|
||||
Name: "config_version",
|
||||
Help: "Latest configuration version received from Cloudflare",
|
||||
},
|
||||
)
|
||||
prometheus.MustRegister(
|
||||
configVersion,
|
||||
)
|
||||
return metrics{
|
||||
configVersion: configVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSupervisor(
|
||||
|
@ -70,6 +92,7 @@ func NewSupervisor(
|
|||
useConfigResultChan: useConfigResultChan,
|
||||
state: newState(defaultClientConfig),
|
||||
logger: logger.WithField("subsystem", "supervisor"),
|
||||
metrics: newMetrics(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -131,6 +154,7 @@ func (s *Supervisor) notifySubsystemsNewConfig(newConfig *pogs.ClientConfig) *po
|
|||
Success: true,
|
||||
}
|
||||
}
|
||||
s.metrics.configVersion.Set(float64(newConfig.Version))
|
||||
|
||||
s.state.updateConfig(newConfig)
|
||||
var tunnelHostnames []h2mux.TunnelHostname
|
||||
|
|
|
@ -26,28 +26,28 @@ mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw==
|
|||
-----END CERTIFICATE-----
|
||||
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
|
||||
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
|
||||
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
|
||||
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
|
||||
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
|
||||
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
|
||||
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
|
||||
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
|
||||
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
|
||||
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
|
||||
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
|
||||
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
|
||||
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
|
||||
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
|
||||
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
|
||||
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
|
||||
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
|
||||
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
|
||||
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
|
||||
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
|
||||
Fu6q54beR89jDc+oABmOgg==
|
||||
MIIEADCCAuigAwIBAgIID+rOSdTGfGcwDQYJKoZIhvcNAQELBQAwgYsxCzAJBgNV
|
||||
BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91
|
||||
ZEZsYXJlIE9yaWdpbiBTU0wgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQH
|
||||
Ew1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMB4XDTE5MDgyMzIx
|
||||
MDgwMFoXDTI5MDgxNTE3MDAwMFowgYsxCzAJBgNVBAYTAlVTMRkwFwYDVQQKExBD
|
||||
bG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91ZEZsYXJlIE9yaWdpbiBTU0wg
|
||||
Q2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRMw
|
||||
EQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
|
||||
AQEAwEiVZ/UoQpHmFsHvk5isBxRehukP8DG9JhFev3WZtG76WoTthvLJFRKFCHXm
|
||||
V6Z5/66Z4S09mgsUuFwvJzMnE6Ej6yIsYNCb9r9QORa8BdhrkNn6kdTly3mdnykb
|
||||
OomnwbUfLlExVgNdlP0XoRoeMwbQ4598foiHblO2B/LKuNfJzAMfS7oZe34b+vLB
|
||||
yrP/1bgCSLdc1AxQc1AC0EsQQhgcyTJNgnG4va1c7ogPlwKyhbDyZ4e59N5lbYPJ
|
||||
SmXI/cAe3jXj1FBLJZkwnoDKe0v13xeF+nF32smSH0qB7aJX2tBMW4TWtFPmzs5I
|
||||
lwrFSySWAdwYdgxw180yKU0dvwIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAQYwEgYD
|
||||
VR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUJOhTV118NECHqeuU27rhFnj8KaQw
|
||||
HwYDVR0jBBgwFoAUJOhTV118NECHqeuU27rhFnj8KaQwDQYJKoZIhvcNAQELBQAD
|
||||
ggEBAHwOf9Ur1l0Ar5vFE6PNrZWrDfQIMyEfdgSKofCdTckbqXNTiXdgbHs+TWoQ
|
||||
wAB0pfJDAHJDXOTCWRyTeXOseeOi5Btj5CnEuw3P0oXqdqevM1/+uWp0CM35zgZ8
|
||||
VD4aITxity0djzE6Qnx3Syzz+ZkoBgTnNum7d9A66/V636x4vTeqbZFBr9erJzgz
|
||||
hhurjcoacvRNhnjtDRM0dPeiCJ50CP3wEYuvUzDHUaowOsnLCjQIkWbR7Ni6KEIk
|
||||
MOz2U0OBSif3FTkhCgZWQKOOLo1P42jHC3ssUZAtVNXrCk3fw9/E15k8NPkBazZ6
|
||||
0iykLhH1trywrKRMVw67F44IE8Y=
|
||||
-----END CERTIFICATE-----
|
||||
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
|
||||
-----BEGIN CERTIFICATE-----
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuthenticateResponse is the serialized response from the Authenticate RPC.
|
||||
// It's a 1:1 representation of the capnp message, so it's not very useful for programmers.
|
||||
// Instead, you should call the `Outcome()` method to get a programmer-friendly sum type, with one
|
||||
// case for each possible outcome.
|
||||
type AuthenticateResponse struct {
|
||||
PermanentErr string
|
||||
RetryableErr string
|
||||
Jwt []byte
|
||||
HoursUntilRefresh uint8
|
||||
}
|
||||
|
||||
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
|
||||
func (ar AuthenticateResponse) Outcome() AuthOutcome {
|
||||
// If the user's authentication was unsuccessful, the server will return an error explaining why.
|
||||
// cloudflared should fatal with this error.
|
||||
if ar.PermanentErr != "" {
|
||||
return NewAuthFail(errors.New(ar.PermanentErr))
|
||||
}
|
||||
|
||||
// If there was a network error, then cloudflared should retry later,
|
||||
// because origintunneld couldn't prove whether auth was correct or not.
|
||||
if ar.RetryableErr != "" {
|
||||
return NewAuthUnknown(errors.New(ar.RetryableErr), ar.HoursUntilRefresh)
|
||||
}
|
||||
|
||||
// If auth succeeded, return the token and refresh it when instructed.
|
||||
if len(ar.Jwt) > 0 {
|
||||
return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh)
|
||||
}
|
||||
|
||||
// Otherwise the state got messed up.
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthOutcome is a programmer-friendly sum type denoting the possible outcomes of Authenticate.
|
||||
//go-sumtype:decl AuthOutcome
|
||||
type AuthOutcome interface {
|
||||
isAuthOutcome()
|
||||
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||
Serialize() AuthenticateResponse
|
||||
}
|
||||
|
||||
// AuthSuccess means the backend successfully authenticated this cloudflared.
|
||||
type AuthSuccess struct {
|
||||
jwt []byte
|
||||
hoursUntilRefresh uint8
|
||||
}
|
||||
|
||||
func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess {
|
||||
return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh}
|
||||
}
|
||||
|
||||
func (ao AuthSuccess) JWT() []byte {
|
||||
return ao.jwt
|
||||
}
|
||||
|
||||
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||
func (ao AuthSuccess) RefreshAfter() time.Duration {
|
||||
return hoursToTime(ao.hoursUntilRefresh)
|
||||
}
|
||||
|
||||
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||
func (ao AuthSuccess) Serialize() AuthenticateResponse {
|
||||
return AuthenticateResponse{
|
||||
Jwt: ao.jwt,
|
||||
HoursUntilRefresh: ao.hoursUntilRefresh,
|
||||
}
|
||||
}
|
||||
|
||||
func (ao AuthSuccess) isAuthOutcome() {}
|
||||
|
||||
// AuthFail means this cloudflared has the wrong auth and should exit.
|
||||
type AuthFail struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func NewAuthFail(err error) AuthFail {
|
||||
return AuthFail{err: err}
|
||||
}
|
||||
|
||||
func (ao AuthFail) Error() string {
|
||||
return ao.err.Error()
|
||||
}
|
||||
|
||||
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||
func (ao AuthFail) Serialize() AuthenticateResponse {
|
||||
return AuthenticateResponse{
|
||||
PermanentErr: ao.err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ao AuthFail) isAuthOutcome() {}
|
||||
|
||||
// AuthUnknown means the backend couldn't finish checking authentication. Try again later.
|
||||
type AuthUnknown struct {
|
||||
err error
|
||||
hoursUntilRefresh uint8
|
||||
}
|
||||
|
||||
func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown {
|
||||
return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh}
|
||||
}
|
||||
|
||||
func (ao AuthUnknown) Error() string {
|
||||
return ao.err.Error()
|
||||
}
|
||||
|
||||
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||
func (ao AuthUnknown) RefreshAfter() time.Duration {
|
||||
return hoursToTime(ao.hoursUntilRefresh)
|
||||
}
|
||||
|
||||
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||
func (ao AuthUnknown) Serialize() AuthenticateResponse {
|
||||
return AuthenticateResponse{
|
||||
RetryableErr: ao.err.Error(),
|
||||
HoursUntilRefresh: ao.hoursUntilRefresh,
|
||||
}
|
||||
}
|
||||
|
||||
func (ao AuthUnknown) isAuthOutcome() {}
|
||||
|
||||
func hoursToTime(hours uint8) time.Duration {
|
||||
return time.Duration(hours) * time.Hour
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
func (i TunnelServer_PogsImpl) Authenticate(p tunnelrpc.TunnelServer_authenticate) error {
|
||||
originCert, err := p.Params.OriginCert()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hostname, err := p.Params.Hostname()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
options, err := p.Params.Options()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pogsOptions, err := UnmarshalRegistrationOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
server.Ack(p.Options)
|
||||
resp, err := i.impl.Authenticate(p.Ctx, originCert, hostname, pogsOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalAuthenticateResponse(result, resp)
|
||||
}
|
||||
|
||||
func MarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse, p *AuthenticateResponse) error {
|
||||
return pogs.Insert(tunnelrpc.AuthenticateResponse_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.Authenticate(ctx, func(p tunnelrpc.TunnelServer_authenticate_Params) error {
|
||||
err := p.SetOriginCert(originCert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetHostname(hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registrationOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = MarshalRegistrationOptions(registrationOptions, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
retval, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return UnmarshalAuthenticateResponse(retval)
|
||||
}
|
||||
|
||||
func UnmarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse) (*AuthenticateResponse, error) {
|
||||
p := new(AuthenticateResponse)
|
||||
err := pogs.Extract(p, tunnelrpc.AuthenticateResponse_TypeID, s.Struct)
|
||||
return p, err
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
// Ensure the AuthOutcome sum is correct
|
||||
var _ AuthOutcome = &AuthSuccess{}
|
||||
var _ AuthOutcome = &AuthFail{}
|
||||
var _ AuthOutcome = &AuthUnknown{}
|
||||
|
||||
// Unit tests for AuthenticateResponse.Outcome()
|
||||
func TestAuthenticateResponseOutcome(t *testing.T) {
|
||||
type fields struct {
|
||||
PermanentErr string
|
||||
RetryableErr string
|
||||
Jwt []byte
|
||||
HoursUntilRefresh uint8
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want AuthOutcome
|
||||
}{
|
||||
{"success",
|
||||
fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
|
||||
AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6},
|
||||
},
|
||||
{"fail",
|
||||
fields{PermanentErr: "bad creds"},
|
||||
AuthFail{err: fmt.Errorf("bad creds")},
|
||||
},
|
||||
{"error",
|
||||
fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
|
||||
AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6},
|
||||
},
|
||||
{"nil (no fields are set)",
|
||||
fields{},
|
||||
nil,
|
||||
},
|
||||
{"nil (too few fields are set)",
|
||||
fields{HoursUntilRefresh: 6},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ar := AuthenticateResponse{
|
||||
PermanentErr: tt.fields.PermanentErr,
|
||||
RetryableErr: tt.fields.RetryableErr,
|
||||
Jwt: tt.fields.Jwt,
|
||||
HoursUntilRefresh: tt.fields.HoursUntilRefresh,
|
||||
}
|
||||
got := ar.Outcome()
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want)
|
||||
}
|
||||
if got != nil && !reflect.DeepEqual(got.Serialize(), ar) {
|
||||
t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSuccess(t *testing.T) {
|
||||
input := NewAuthSuccess([]byte("asdf"), 6)
|
||||
output, ok := input.Serialize().Outcome().(AuthSuccess)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, input, output)
|
||||
}
|
||||
|
||||
func TestAuthUnknown(t *testing.T) {
|
||||
input := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6)
|
||||
output, ok := input.Serialize().Outcome().(AuthUnknown)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, input, output)
|
||||
}
|
||||
|
||||
func TestAuthFail(t *testing.T) {
|
||||
input := NewAuthFail(fmt.Errorf("wrong creds"))
|
||||
output, ok := input.Serialize().Outcome().(AuthFail)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, input, output)
|
||||
}
|
||||
|
||||
func TestWhenToRefresh(t *testing.T) {
|
||||
expected := 4 * time.Hour
|
||||
actual := hoursToTime(4)
|
||||
if expected != actual {
|
||||
t.Fatalf("expected %v hours, got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that serializing and deserializing AuthenticationResponse undo each other.
|
||||
func TestSerializeAuthenticationResponse(t *testing.T) {
|
||||
|
||||
tests := []*AuthenticateResponse{
|
||||
&AuthenticateResponse{
|
||||
Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
|
||||
HoursUntilRefresh: 24,
|
||||
},
|
||||
&AuthenticateResponse{
|
||||
PermanentErr: "bad auth",
|
||||
},
|
||||
&AuthenticateResponse{
|
||||
RetryableErr: "bad connection",
|
||||
HoursUntilRefresh: 24,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range tests {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
capnpEntity, err := tunnelrpc.NewAuthenticateResponse(seg)
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatal("Couldn't initialize a new message")
|
||||
}
|
||||
err = MarshalAuthenticateResponse(capnpEntity, testCase)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
||||
continue
|
||||
}
|
||||
result, err := UnmarshalAuthenticateResponse(capnpEntity)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||
}
|
||||
}
|
|
@ -197,11 +197,14 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
dialContext := (&net.Dialer{
|
||||
dialer := &net.Dialer{
|
||||
Timeout: hc.ProxyConnectionTimeout,
|
||||
KeepAlive: hc.TCPKeepAlive,
|
||||
DualStack: hc.DialDualStack,
|
||||
}).DialContext
|
||||
}
|
||||
if !hc.DialDualStack {
|
||||
dialer.FallbackDelay = -1
|
||||
}
|
||||
dialContext := dialer.DialContext
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialContext,
|
||||
|
@ -270,7 +273,6 @@ func (*HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
|
|||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
func (i TunnelServer_PogsImpl) ReconnectTunnel(p tunnelrpc.TunnelServer_reconnectTunnel) error {
|
||||
jwt, err := p.Params.Jwt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
eventDigest, err := p.Params.EventDigest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hostname, err := p.Params.Hostname()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
options, err := p.Params.Options()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pogsOptions, err := UnmarshalRegistrationOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
server.Ack(p.Options)
|
||||
registration, err := i.impl.ReconnectTunnel(p.Ctx, jwt, eventDigest, hostname, pogsOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalTunnelRegistration(result, registration)
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) ReconnectTunnel(
|
||||
ctx context.Context,
|
||||
jwt,
|
||||
eventDigest []byte,
|
||||
hostname string,
|
||||
options *RegistrationOptions,
|
||||
) (*TunnelRegistration, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.ReconnectTunnel(ctx, func(p tunnelrpc.TunnelServer_reconnectTunnel_Params) error {
|
||||
err := p.SetJwt(jwt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetEventDigest(eventDigest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetHostname(hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registrationOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = MarshalRegistrationOptions(registrationOptions, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
retval, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return UnmarshalTunnelRegistration(retval)
|
||||
}
|
|
@ -9,13 +9,16 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRetryAfterSeconds = 15
|
||||
)
|
||||
|
||||
type Authentication struct {
|
||||
Key string
|
||||
Email string
|
||||
|
@ -33,11 +36,112 @@ func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error
|
|||
}
|
||||
|
||||
type TunnelRegistration struct {
|
||||
Err string
|
||||
Url string
|
||||
LogLines []string
|
||||
PermanentFailure bool
|
||||
TunnelID string `capnp:"tunnelID"`
|
||||
SuccessfulTunnelRegistration
|
||||
Err string
|
||||
PermanentFailure bool
|
||||
RetryAfterSeconds uint16
|
||||
}
|
||||
|
||||
type SuccessfulTunnelRegistration struct {
|
||||
Url string
|
||||
LogLines []string
|
||||
TunnelID string `capnp:"tunnelID"`
|
||||
EventDigest []byte
|
||||
}
|
||||
|
||||
func NewSuccessfulTunnelRegistration(
|
||||
url string,
|
||||
logLines []string,
|
||||
tunnelID string,
|
||||
eventDigest []byte,
|
||||
) *TunnelRegistration {
|
||||
// Marshal nil will result in an error
|
||||
if logLines == nil {
|
||||
logLines = []string{}
|
||||
}
|
||||
return &TunnelRegistration{
|
||||
SuccessfulTunnelRegistration: SuccessfulTunnelRegistration{
|
||||
Url: url,
|
||||
LogLines: logLines,
|
||||
TunnelID: tunnelID,
|
||||
EventDigest: eventDigest,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Not calling this function Error() to avoid confusion with implementing error interface
|
||||
func (tr TunnelRegistration) DeserializeError() TunnelRegistrationError {
|
||||
if tr.Err != "" {
|
||||
err := fmt.Errorf(tr.Err)
|
||||
if tr.PermanentFailure {
|
||||
return NewPermanentRegistrationError(err)
|
||||
}
|
||||
retryAfterSeconds := tr.RetryAfterSeconds
|
||||
if retryAfterSeconds < defaultRetryAfterSeconds {
|
||||
retryAfterSeconds = defaultRetryAfterSeconds
|
||||
}
|
||||
return NewRetryableRegistrationError(err, retryAfterSeconds)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TunnelRegistrationError interface {
|
||||
error
|
||||
Serialize() *TunnelRegistration
|
||||
IsPermanent() bool
|
||||
}
|
||||
|
||||
type PermanentRegistrationError struct {
|
||||
err string
|
||||
}
|
||||
|
||||
func NewPermanentRegistrationError(err error) TunnelRegistrationError {
|
||||
return &PermanentRegistrationError{
|
||||
err: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
func (pre *PermanentRegistrationError) Error() string {
|
||||
return pre.err
|
||||
}
|
||||
|
||||
func (pre *PermanentRegistrationError) Serialize() *TunnelRegistration {
|
||||
return &TunnelRegistration{
|
||||
Err: pre.err,
|
||||
PermanentFailure: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (*PermanentRegistrationError) IsPermanent() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type RetryableRegistrationError struct {
|
||||
err string
|
||||
retryAfterSeconds uint16
|
||||
}
|
||||
|
||||
func NewRetryableRegistrationError(err error, retryAfterSeconds uint16) TunnelRegistrationError {
|
||||
return &RetryableRegistrationError{
|
||||
err: err.Error(),
|
||||
retryAfterSeconds: retryAfterSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
func (rre *RetryableRegistrationError) Error() string {
|
||||
return rre.err
|
||||
}
|
||||
|
||||
func (rre *RetryableRegistrationError) Serialize() *TunnelRegistration {
|
||||
return &TunnelRegistration{
|
||||
Err: rre.err,
|
||||
PermanentFailure: false,
|
||||
RetryAfterSeconds: rre.retryAfterSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
func (*RetryableRegistrationError) IsPermanent() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
|
||||
|
@ -63,6 +167,7 @@ type RegistrationOptions struct {
|
|||
RunFromTerminal bool `capnp:"runFromTerminal"`
|
||||
CompressionQuality uint64 `capnp:"compressionQuality"`
|
||||
UUID string `capnp:"uuid"`
|
||||
NumPreviousAttempts uint8
|
||||
}
|
||||
|
||||
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
||||
|
@ -323,10 +428,12 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar
|
|||
}
|
||||
|
||||
type TunnelServer interface {
|
||||
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
|
||||
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration
|
||||
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
||||
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
|
||||
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
|
||||
Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error)
|
||||
ReconnectTunnel(ctx context.Context, jwt, eventDigest []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
|
||||
}
|
||||
|
||||
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
||||
|
@ -355,15 +462,12 @@ func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerT
|
|||
return err
|
||||
}
|
||||
server.Ack(p.Options)
|
||||
registration, err := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registration := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
|
||||
|
||||
result, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info(registration.TunnelID)
|
||||
return MarshalTunnelRegistration(result, registration)
|
||||
}
|
||||
|
||||
|
@ -416,7 +520,7 @@ func (c TunnelServer_PogsClient) Close() error {
|
|||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
|
||||
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
|
||||
err := p.SetOriginCert(originCert)
|
||||
|
@ -439,9 +543,13 @@ func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert
|
|||
})
|
||||
retval, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
|
||||
}
|
||||
return UnmarshalTunnelRegistration(retval)
|
||||
registration, err := UnmarshalTunnelRegistration(retval)
|
||||
if err != nil {
|
||||
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
|
||||
}
|
||||
return registration
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -11,6 +12,50 @@ import (
|
|||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
const (
|
||||
testURL = "tunnel.example.com"
|
||||
testTunnelID = "asdfghjkl;"
|
||||
testRetryAfterSeconds = 19
|
||||
)
|
||||
|
||||
var (
|
||||
testErr = fmt.Errorf("Invalid credential")
|
||||
testLogLines = []string{"all", "working"}
|
||||
testEventDigest = []byte("asdf")
|
||||
)
|
||||
|
||||
// *PermanentRegistrationError implements TunnelRegistrationError
|
||||
var _ TunnelRegistrationError = (*PermanentRegistrationError)(nil)
|
||||
|
||||
// *RetryableRegistrationError implements TunnelRegistrationError
|
||||
var _ TunnelRegistrationError = (*RetryableRegistrationError)(nil)
|
||||
|
||||
func TestTunnelRegistration(t *testing.T) {
|
||||
testCases := []*TunnelRegistration{
|
||||
NewSuccessfulTunnelRegistration(testURL, testLogLines, testTunnelID, testEventDigest),
|
||||
NewSuccessfulTunnelRegistration(testURL, nil, testTunnelID, testEventDigest),
|
||||
NewPermanentRegistrationError(testErr).Serialize(),
|
||||
NewRetryableRegistrationError(testErr, testRetryAfterSeconds).Serialize(),
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
capnpEntity, err := tunnelrpc.NewTunnelRegistration(seg)
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatal("Couldn't initialize a new message")
|
||||
}
|
||||
err = MarshalTunnelRegistration(capnpEntity, testCase)
|
||||
if !assert.NoError(t, err, "testCase #%v failed to marshal", i) {
|
||||
continue
|
||||
}
|
||||
result, err := UnmarshalTunnelRegistration(capnpEntity)
|
||||
if !assert.NoError(t, err, "testCase #%v failed to unmarshal", i) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestConnectResult(t *testing.T) {
|
||||
testCases := []ConnectResult{
|
||||
&ConnectError{
|
||||
|
|
|
@ -19,6 +19,10 @@ struct TunnelRegistration {
|
|||
permanentFailure @3 :Bool;
|
||||
# Displayed to user
|
||||
tunnelID @4 :Text;
|
||||
# How long should this connection wait to retry in seconds, if the error wasn't permanent
|
||||
retryAfterSeconds @5 :UInt16;
|
||||
# A unique ID used to reconnect this tunnel.
|
||||
eventDigest @6 :Data;
|
||||
}
|
||||
|
||||
struct RegistrationOptions {
|
||||
|
@ -44,6 +48,8 @@ struct RegistrationOptions {
|
|||
# cross stream compression setting, 0 - off, 3 - high
|
||||
compressionQuality @10 :UInt64;
|
||||
uuid @11 :Text;
|
||||
# number of previous attempts to send RegisterTunnel/ReconnectTunnel
|
||||
numPreviousAttempts @12 :UInt8;
|
||||
}
|
||||
|
||||
struct CapnpConnectParameters {
|
||||
|
@ -274,11 +280,20 @@ struct FailedConfig {
|
|||
reason @4 :Text;
|
||||
}
|
||||
|
||||
struct AuthenticateResponse {
|
||||
permanentErr @0 :Text;
|
||||
retryableErr @1 :Text;
|
||||
jwt @2 :Data;
|
||||
hoursUntilRefresh @3 :UInt8;
|
||||
}
|
||||
|
||||
interface TunnelServer {
|
||||
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||
getServerInfo @1 () -> (result :ServerInfo);
|
||||
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();
|
||||
connect @3 (parameters :CapnpConnectParameters) -> (result :ConnectResult);
|
||||
authenticate @4 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :AuthenticateResponse);
|
||||
reconnectTunnel @5 (jwt :Data, eventDigest :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||
}
|
||||
|
||||
interface ClientService {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -60,6 +60,12 @@ func ValidateHostname(hostname string) (string, error) {
|
|||
|
||||
}
|
||||
|
||||
// ValidateUrl returns a validated version of `originUrl` with a scheme prepended (by default http://).
|
||||
// Note: when originUrl contains a scheme, the path is removed:
|
||||
// ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
|
||||
// but when it does not, the path is preserved:
|
||||
// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
|
||||
// This is arguably a bug, but changing it might break some cloudflared users.
|
||||
func ValidateUrl(originUrl string) (string, error) {
|
||||
if originUrl == "" {
|
||||
return "", fmt.Errorf("URL should not be empty")
|
||||
|
@ -121,6 +127,8 @@ func ValidateUrl(originUrl string) (string, error) {
|
|||
if err != nil {
|
||||
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||
}
|
||||
// This is why the path is preserved when `originUrl` doesn't have a schema.
|
||||
// Using `parsedUrl.Port()` here, instead of `port`, would remove the path
|
||||
return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
|
||||
}
|
||||
}
|
||||
|
@ -182,10 +190,11 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
|
|||
_, secondErr := client.Do(secondRequest)
|
||||
if secondErr == nil { // Worked this time--advise the user to switch protocols
|
||||
return errors.Errorf(
|
||||
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
|
||||
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
|
||||
parsedURL.Host,
|
||||
oldScheme,
|
||||
parsedURL.Scheme,
|
||||
initialErr,
|
||||
parsedURL,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -53,98 +53,65 @@ func TestValidateHostname(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestValidateUrl(t *testing.T) {
|
||||
type testCase struct {
|
||||
input string
|
||||
expectedOutput string
|
||||
}
|
||||
testCases := []testCase{
|
||||
{"http://localhost", "http://localhost"},
|
||||
{"http://localhost/", "http://localhost"},
|
||||
{"http://localhost/api", "http://localhost"},
|
||||
{"http://localhost/api/", "http://localhost"},
|
||||
{"https://localhost", "https://localhost"},
|
||||
{"https://localhost/", "https://localhost"},
|
||||
{"https://localhost/api", "https://localhost"},
|
||||
{"https://localhost/api/", "https://localhost"},
|
||||
{"https://localhost:8080", "https://localhost:8080"},
|
||||
{"https://localhost:8080/", "https://localhost:8080"},
|
||||
{"https://localhost:8080/api", "https://localhost:8080"},
|
||||
{"https://localhost:8080/api/", "https://localhost:8080"},
|
||||
{"localhost", "http://localhost"},
|
||||
{"localhost/", "http://localhost/"},
|
||||
{"localhost/api", "http://localhost/api"},
|
||||
{"localhost/api/", "http://localhost/api/"},
|
||||
{"localhost:8080", "http://localhost:8080"},
|
||||
{"localhost:8080/", "http://localhost:8080/"},
|
||||
{"localhost:8080/api", "http://localhost:8080/api"},
|
||||
{"localhost:8080/api/", "http://localhost:8080/api/"},
|
||||
{"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
|
||||
{"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
|
||||
{"127.0.0.1:8080", "http://127.0.0.1:8080"},
|
||||
{"127.0.0.1", "http://127.0.0.1"},
|
||||
{"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
|
||||
{"[::1]:8080", "http://[::1]:8080"},
|
||||
{"http://[::1]", "http://[::1]"},
|
||||
{"http://[::1]:8080", "http://[::1]:8080"},
|
||||
{"[::1]", "http://[::1]"},
|
||||
{"https://example.com", "https://example.com"},
|
||||
{"example.com", "http://example.com"},
|
||||
{"http://hello.example.com", "http://hello.example.com"},
|
||||
{"hello.example.com", "http://hello.example.com"},
|
||||
{"hello.example.com:8080", "http://hello.example.com:8080"},
|
||||
{"https://hello.example.com:8080", "https://hello.example.com:8080"},
|
||||
{"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
|
||||
{"bücher.example.com", "http://xn--bcher-kva.example.com"},
|
||||
{"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
|
||||
{"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
validUrl, err := ValidateUrl(testCase.input)
|
||||
assert.NoError(t, err, "test case %v", i)
|
||||
assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i)
|
||||
}
|
||||
|
||||
validUrl, err := ValidateUrl("")
|
||||
assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
|
||||
assert.Empty(t, validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://localhost:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://localhost:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("localhost:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://localhost:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://localhost")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://localhost", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://127.0.0.1:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("127.0.0.1:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("127.0.0.1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://127.0.0.1:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://127.0.0.1:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("[::1]:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://[::1]")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://[::1]:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("[::1]")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://hello.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://hello.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("hello.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://hello.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("hello.example.com:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://hello.example.com:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://hello.example.com:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://bücher.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("bücher.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://hello.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
|
||||
assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
|
||||
assert.Empty(t, validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
||||
|
||||
}
|
||||
|
||||
func TestToggleProtocol(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue