Merge branch 'master' into master

This commit is contained in:
Kyle Harding 2019-12-12 08:50:18 -05:00 committed by GitHub
commit 9042025902
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 2805 additions and 601 deletions

View File

@ -1,3 +1,19 @@
2019.11.3
- 2019-11-20 TUN-2562: Update Cloudflare Origin CA RSA root
2019.11.2
- 2019-11-18 TUN-2567: AuthOutcome can be turned back into AuthResponse
- 2019-11-18 TUN-2563: Exposes config_version metrics
2019.11.1
- 2019-11-12 Add db-connect, a SQL over HTTPS server
- 2019-11-12 TUN-2053: Add a /healthcheck endpoint to the metrics server
- 2019-11-13 TUN-2178: public API to create new h2mux.MuxedStreamRequest
- 2019-11-13 TUN-2490: respect original representation of HTTP request path
- 2019-11-18 TUN-2547: TunnelRPC definitions for Authenticate flow
- 2019-11-18 TUN-2551: TunnelRPC definitions for ReconnectTunnel flow
- 2019-11-05 TUN-2506: Expose active streams metrics
2019.11.0
- 2019-11-04 TUN-2502: Switch to go modules
- 2019-11-04 TUN-2500: Don't send client registration errors to Sentry

View File

@ -977,6 +977,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_INTENT"},
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-reconnect-token",
Usage: "Test reestablishing connections with the new 'reconnect token' flow.",
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "dial-edge-timeout",
Usage: "Maximum wait time to set up a connection with the edge",
@ -1044,7 +1050,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Usage: "Absolute path of directory to save SSH host keys in",
EnvVars: []string{"HOST_KEY_PATH"},
Hidden: true,
}),
}
}
}

View File

@ -203,11 +203,14 @@ func prepareTunnelConfig(
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")},
}
dialContext := (&net.Dialer{
dialer := &net.Dialer{
Timeout: c.Duration("proxy-connect-timeout"),
KeepAlive: c.Duration("proxy-tcp-keepalive"),
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
}).DialContext
}
if c.Bool("proxy-no-happy-eyeballs") {
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
}
dialContext := dialer.DialContext
if c.IsSet("unix-socket") {
unixSocket, err := config.ValidateUnixSocket(c)
@ -272,6 +275,7 @@ func prepareTunnelConfig(
TlsConfig: toEdgeTLSConfig,
TransportLogger: transportLogger,
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
UseReconnectToken: c.Bool("use-reconnect-token"),
}, nil
}

View File

@ -2,38 +2,26 @@ package connection
import (
"context"
"net"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
openStreamTimeout = 30 * time.Second
)
type dialError struct {
cause error
}
func (e dialError) Error() string {
return e.cause.Error()
}
type Connection struct {
id uuid.UUID
muxer *h2mux.Muxer
}
func newConnection(muxer *h2mux.Muxer, edgeIP *net.TCPAddr) (*Connection, error) {
func newConnection(muxer *h2mux.Muxer) (*Connection, error) {
id, err := uuid.NewRandom()
if err != nil {
return nil, err
@ -50,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error {
}
// Connect is used to establish connections with cloudflare's edge network
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (pogs.ConnectResult, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
rpcConn, err := c.newRPConn(openStreamCtx, logger)
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) {
tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout)
if err != nil {
return nil, errors.Wrap(err, "cannot create new RPC connection")
}
defer rpcConn.Close()
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
defer tsClient.Close()
return tsClient.Connect(ctx, parameters)
}
func (c *Connection) Shutdown() {
c.muxer.Shutdown()
}
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
stream, err := c.muxer.OpenRPCStream(ctx)
if err != nil {
return nil, err
}
return rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
), nil
}

54
connection/dial.go Normal file
View File

@ -0,0 +1,54 @@
package connection
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/pkg/errors"
)
// DialEdge makes a TLS connection to a Cloudflare edge node
func DialEdge(
ctx context.Context,
timeout time.Duration,
tlsConfig *tls.Config,
edgeTCPAddr *net.TCPAddr,
) (net.Conn, error) {
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
defer dialCancel()
dialer := net.Dialer{}
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String())
if err != nil {
return nil, newDialError(err, "DialContext error")
}
tlsEdgeConn := tls.Client(edgeConn, tlsConfig)
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
if err = tlsEdgeConn.Handshake(); err != nil {
return nil, newDialError(err, "Handshake with edge error")
}
// clear the deadline on the conn; h2mux has its own timeouts
tlsEdgeConn.SetDeadline(time.Time{})
return tlsEdgeConn, nil
}
// DialError is an error returned from DialEdge
type DialError struct {
cause error
}
func newDialError(err error, message string) error {
return DialError{cause: errors.Wrap(err, message)}
}
func (e DialError) Error() string {
return e.cause.Error()
}
func (e DialError) Cause() error {
return e.cause
}

View File

@ -4,19 +4,18 @@ import (
"context"
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/prometheus/client_golang/prometheus"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@ -59,12 +58,12 @@ func newMetrics(namespace, subsystem string) *metrics {
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
type EdgeManagerConfigurable struct {
TunnelHostnames []h2mux.TunnelHostname
*pogs.EdgeConnectionConfig
*tunnelpogs.EdgeConnectionConfig
}
type CloudflaredConfig struct {
CloudflaredID uuid.UUID
Tags []pogs.Tag
Tags []tunnelpogs.Tag
BuildInfo *buildinfo.BuildInfo
IntentLabel string
}
@ -127,13 +126,13 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab
em.state.updateConfigurable(newConfigurable)
}
func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
edgeIP := em.serviceDiscoverer.Addr()
edgeConn, err := em.dialEdge(ctx, edgeIP)
func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError {
edgeTCPAddr := em.serviceDiscoverer.Addr()
configurable := em.state.getConfigurable()
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
if err != nil {
return retryConnection(fmt.Sprintf("dial edge error: %v", err))
}
configurable := em.state.getConfigurable()
// Establish a muxed connection with the edge
// Client mux handshake with agent server
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
@ -148,14 +147,14 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err))
}
h2muxConn, err := newConnection(muxer, edgeIP)
h2muxConn, err := newConnection(muxer)
if err != nil {
return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err))
}
go em.serveConn(ctx, h2muxConn)
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
NumPreviousAttempts: 0,
@ -196,28 +195,6 @@ func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
em.state.closeConnection(conn)
}
func (em *EdgeManager) dialEdge(ctx context.Context, edgeIP *net.TCPAddr) (*tls.Conn, error) {
timeout := em.state.getConfigurable().Timeout
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
defer dialCancel()
dialer := net.Dialer{DualStack: true}
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
if err != nil {
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
}
tlsEdgeConn := tls.Client(edgeConn, em.tlsConfig)
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
if err = tlsEdgeConn.Handshake(); err != nil {
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
}
// clear the deadline on the conn; h2mux has its own timeouts
tlsEdgeConn.SetDeadline(time.Time{})
return tlsEdgeConn, nil
}
func (em *EdgeManager) noRetryMessage() string {
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
@ -308,8 +285,8 @@ func (ems *edgeManagerState) getUserCredential() []byte {
return ems.userCredential
}
func retryConnection(cause string) *pogs.ConnectError {
return &pogs.ConnectError{
func retryConnection(cause string) *tunnelpogs.ConnectError {
return &tunnelpogs.ConnectError{
Cause: cause,
RetryAfter: defaultRetryAfter,
ShouldRetry: true,

View File

@ -4,15 +4,15 @@ import (
"testing"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
var (

49
connection/rpc.go Normal file
View File

@ -0,0 +1,49 @@
package connection
import (
"context"
"fmt"
"time"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// NewRPCClient creates and returns a new RPC client, which will communicate
// using a stream on the given muxer
func NewRPCClient(
ctx context.Context,
muxer *h2mux.Muxer,
logger *logrus.Entry,
openStreamTimeout time.Duration,
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return
}
if !isRPCStreamResponse(stream.Headers) {
stream.Close()
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
return
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger),
)
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
return client, nil
}
func isRPCStreamResponse(headers []h2mux.Header) bool {
return len(headers) == 1 &&
headers[0].Name == ":status" &&
headers[0].Value == "200"
}

View File

@ -13,26 +13,28 @@ type activeStreamMap struct {
sync.RWMutex
// streams tracks open streams.
streams map[uint32]*MuxedStream
// streamsEmpty is a chan that should be closed when no more streams are open.
streamsEmpty chan struct{}
// nextStreamID is the next ID to use on our side of the connection.
// This is odd for clients, even for servers.
nextStreamID uint32
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
maxPeerStreamID uint32
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
// ignoreNewStreams is true when the connection is being shut down. New streams
// cannot be registered.
ignoreNewStreams bool
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
// streamsEmpty is a chan that will be closed when no more streams are open.
streamsEmptyChan chan struct{}
closeOnce sync.Once
}
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream),
streamsEmpty: make(chan struct{}),
nextStreamID: 1,
activeStreams: activeStreams,
streams: make(map[uint32]*MuxedStream),
streamsEmptyChan: make(chan struct{}),
nextStreamID: 1,
activeStreams: activeStreams,
}
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
if !useClientStreamNumbers {
@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga
return m
}
func (m *activeStreamMap) notifyStreamsEmpty() {
m.closeOnce.Do(func() {
close(m.streamsEmptyChan)
})
}
// Len returns the number of active streams.
func (m *activeStreamMap) Len() int {
m.RLock()
@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) {
delete(m.streams, streamID)
m.activeStreams.Dec()
}
if len(m.streams) == 0 && m.streamsEmpty != nil {
close(m.streamsEmpty)
m.streamsEmpty = nil
if len(m.streams) == 0 {
m.notifyStreamsEmpty()
}
}
// Shutdown blocks new streams from being created. It returns a channel that receives an event
// once the last stream has closed, or nil if a shutdown is in progress.
func (m *activeStreamMap) Shutdown() <-chan struct{} {
// Shutdown blocks new streams from being created.
// It returns `done`, a channel that is closed once the last stream has closed
// and `progress`, whether a shutdown was already in progress
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
m.Lock()
defer m.Unlock()
if m.ignoreNewStreams {
// already shutting down
return nil
return m.streamsEmptyChan, true
}
m.ignoreNewStreams = true
done := make(chan struct{})
if len(m.streams) == 0 {
// nothing to shut down
close(done)
return done
m.notifyStreamsEmpty()
}
m.streamsEmpty = done
return done
return m.streamsEmptyChan, false
}
// AcquireLocalID acquires a new stream ID for a stream you're opening.
@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() {
stream.Close()
}
m.ignoreNewStreams = true
m.notifyStreamsEmpty()
}

View File

@ -0,0 +1,134 @@
package h2mux
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestShutdown(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{streamID: uint32(streamID)}
ok := m.Set(stream)
assert.True(t, ok)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
shutdownChan2, alreadyInProgress2 := m.Shutdown()
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
// Delete all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
m.Delete(uint32(streamID))
}(i)
}
wg.Wait()
}
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
select {
case <-shutdownChan:
default:
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
}
}
type noopBuffer struct {
isClosed bool
}
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Reset() {}
func (t *noopBuffer) Len() int { return 0 }
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
func (t *noopBuffer) Closed() bool { return t.isClosed }
type noopReadyList struct{}
func (_ *noopReadyList) Signal(streamID uint32) {}
func TestAbort(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
var openedStreams sync.Map
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{
streamID: uint32(streamID),
readBuffer: &noopBuffer{},
writeBuffer: &noopBuffer{},
readyList: &noopReadyList{},
}
ok := m.Set(stream)
assert.True(t, ok)
openedStreams.Store(stream.streamID, stream)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
m.Abort()
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
openedStreams.Range(func(key interface{}, value interface{}) bool {
stream := value.(*MuxedStream)
readBuffer := stream.readBuffer.(*noopBuffer)
writeBuffer := stream.writeBuffer.(*noopBuffer)
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
})
select {
case <-shutdownChan:
default:
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
}
// multiple aborts shouldn't cause any issues
m.Abort()
m.Abort()
m.Abort()
}

View File

@ -542,7 +542,10 @@ func (w *h2DictWriter) Write(p []byte) (n int, err error) {
}
func (w *h2DictWriter) Close() error {
return w.comp.Close()
if w.comp != nil {
return w.comp.Close()
}
return nil
}
// From http2/hpack

View File

@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
}
// Shutdown is called to initiate the "happy path" of muxer termination.
func (m *Muxer) Shutdown() {
// It blocks new streams from being created.
// It returns a channel that is closed when the last stream has been closed.
func (m *Muxer) Shutdown() <-chan struct{} {
m.explicitShutdown.Fuse(true)
m.muxReader.Shutdown()
return m.muxReader.Shutdown()
}
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
@ -390,7 +392,7 @@ func isConnectionClosedError(err error) bool {
// Called by proxy server and tunnel
func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
stream := m.NewStream(headers)
if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream, body}); err != nil {
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
return nil, err
}
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
@ -401,7 +403,7 @@ func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
stream := m.NewStream(RPCHeaders())
if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream: stream, body: nil}); err != nil {
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
return nil, err
}
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {

View File

@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
OriginConn: origin,
EdgeMuxConfig: MuxerConfig{
@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
EdgeConn: edge,
doneC: make(chan struct{}),
@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
Name: "origin",
CompressionQuality: quality,
Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
OriginConn: origin,
EdgeMuxConfig: MuxerConfig{
@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
Name: "edge",
CompressionQuality: quality,
Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
EdgeConn: edge,
doneC: make(chan struct{}),

View File

@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
Closed() bool
}
// MuxedStreamDataSignaller is a write-only *ReadyList
type MuxedStreamDataSignaller interface {
// Non-blocking: call this when data is ready to be sent for the given stream ID.
Signal(ID uint32)
}
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
type MuxedStream struct {
streamID uint32
@ -55,8 +61,8 @@ type MuxedStream struct {
// This is the amount of bytes that are in the peer's receive window
// (how much data we can send from this stream).
sendWindow uint32
// Reference to the muxer's readyList; signal this for stream data to be sent.
readyList *ReadyList
// The muxer's readyList
readyList MuxedStreamDataSignaller
// The headers that should be sent, and a flag so we only send them once.
headersSent bool
writeHeaders []Header
@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool {
return th != ""
}
func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream {
func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
return &MuxedStream{
responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(),

View File

@ -51,10 +51,12 @@ type MuxReader struct {
dictionaries h2Dictionaries
}
func (r *MuxReader) Shutdown() {
done := r.streams.Shutdown()
if done == nil {
return
// Shutdown blocks new streams from being created.
// It returns a channel that is closed once the last stream has closed.
func (r *MuxReader) Shutdown() <-chan struct{} {
done, alreadyInProgress := r.streams.Shutdown()
if alreadyInProgress {
return done
}
r.sendGoAway(http2.ErrCodeNo)
go func() {
@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
<-done
r.r.Close()
}()
return done
}
func (r *MuxReader) run(parentLogger *log.Entry) error {

View File

@ -54,6 +54,13 @@ type MuxedStreamRequest struct {
body io.Reader
}
func NewMuxedStreamRequest(stream *MuxedStream, body io.Reader) MuxedStreamRequest {
return MuxedStreamRequest{
stream: stream,
body: body,
}
}
func (r *MuxedStreamRequest) flushBody() {
io.Copy(r.stream, r.body)
r.stream.CloseWrite()

View File

@ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
}
return b.BaseTime
}
// Retries returns the number of retries consumed so far.
func (b *BackoffHandler) Retries() int {
return int(b.retries)
}

View File

@ -2,16 +2,20 @@ package origin
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal"
"github.com/google/uuid"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@ -21,11 +25,23 @@ const (
resolveTTL = time.Hour
// Interval between registering new tunnels
registrationInterval = time.Second
subsystemRefreshAuth = "refresh_auth"
// Maximum exponent for 'Authenticate' exponential backoff
refreshAuthMaxBackoff = 10
// Waiting time before retrying a failed 'Authenticate' connection
refreshAuthRetryDuration = time.Second * 10
)
var (
errJWTUnset = errors.New("JWT unset")
errEventDigestUnset = errors.New("event digest unset")
)
type Supervisor struct {
config *TunnelConfig
edgeIPs []*net.TCPAddr
cloudflaredUUID uuid.UUID
config *TunnelConfig
edgeIPs []*net.TCPAddr
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
nextUnusedEdgeIP int
lastResolve time.Time
@ -38,6 +54,12 @@ type Supervisor struct {
nextConnectedSignal chan struct{}
logger *logrus.Entry
jwtLock *sync.RWMutex
jwt []byte
eventDigestLock *sync.RWMutex
eventDigest []byte
}
type resolveResult struct {
@ -50,18 +72,21 @@ type tunnelError struct {
err error
}
func NewSupervisor(config *TunnelConfig) *Supervisor {
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
return &Supervisor{
cloudflaredUUID: u,
config: config,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Logger.WithField("subsystem", "supervisor"),
jwtLock: &sync.RWMutex{},
eventDigestLock: &sync.RWMutex{},
}
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal, u); err != nil {
if err := s.initialize(ctx, connectedSignal); err != nil {
return err
}
var tunnelsWaiting []int
@ -69,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
var backoffTimer <-chan time.Time
tunnelsActive := s.config.HAConnections
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
var refreshAuthBackoffTimer <-chan time.Time
if s.config.UseReconnectToken {
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
}
for {
select {
// Context cancelled
@ -104,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
case <-backoffTimer:
backoffTimer = nil
for _, index := range tunnelsWaiting {
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
}
tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil
// Time to call Authenticate
case <-refreshAuthBackoffTimer:
newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
logger.WithError(err).Error("Authentication failed")
// Permanent failure. Leave the `select` without setting the
// channel to be non-null, so we'll never hit this case of the `select` again.
continue
}
refreshAuthBackoffTimer = newTimer
// Tunnel successfully connected
case <-s.nextConnectedSignal:
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
@ -128,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
}
}
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.logger
edgeIPs, err := s.resolveEdgeIPs()
@ -145,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
s.lastResolve = time.Now()
// check entitlement and version too old error before attempting to register more tunnels
s.nextUnusedEdgeIP = s.config.HAConnections
go s.startFirstTunnel(ctx, connectedSignal, u)
go s.startFirstTunnel(ctx, connectedSignal)
select {
case <-ctx.Done():
<-s.tunnelErrors
// Error can't be nil. A nil error signals that initialization succeed
return fmt.Errorf("context was canceled")
return ctx.Err()
case tunnelError := <-s.tunnelErrors:
return tunnelError.err
case <-connectedSignal.Wait():
@ -158,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch, u)
go s.startTunnel(ctx, i, ch)
time.Sleep(registrationInterval)
}
return nil
@ -166,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
defer func() {
s.tunnelErrors <- tunnelError{index: 0, err: err}
}()
@ -183,19 +224,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
// try the next address if it was a dialError(network problem) or
// dupConnRegisterTunnelError
case dialError, dupConnRegisterTunnelError:
case connection.DialError, dupConnRegisterTunnelError:
s.replaceEdgeIP(0)
default:
return
}
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
}
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
s.tunnelErrors <- tunnelError{index: index, err: err}
}
@ -253,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
s.nextUnusedEdgeIP++
}
func (s *Supervisor) ReconnectToken() ([]byte, error) {
s.jwtLock.RLock()
defer s.jwtLock.RUnlock()
if s.jwt == nil {
return nil, errJWTUnset
}
return s.jwt, nil
}
func (s *Supervisor) SetReconnectToken(jwt []byte) {
s.jwtLock.Lock()
defer s.jwtLock.Unlock()
s.jwt = jwt
}
func (s *Supervisor) EventDigest() ([]byte, error) {
s.eventDigestLock.RLock()
defer s.eventDigestLock.RUnlock()
if s.eventDigest == nil {
return nil, errEventDigestUnset
}
return s.eventDigest, nil
}
func (s *Supervisor) SetEventDigest(eventDigest []byte) {
s.eventDigestLock.Lock()
defer s.eventDigestLock.Unlock()
s.eventDigest = eventDigest
}
func (s *Supervisor) refreshAuth(
ctx context.Context,
backoff *BackoffHandler,
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth)
authOutcome, err := authenticate(ctx, backoff.Retries())
if err != nil {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
logger.WithError(err).Warnf("Retrying in %v", duration)
return backoff.BackoffTimer(), nil
}
return nil, err
}
// clear backoff timer
backoff.SetGracePeriod()
switch outcome := authOutcome.(type) {
case tunnelpogs.AuthSuccess:
s.SetReconnectToken(outcome.JWT())
return timeAfter(outcome.RefreshAfter()), nil
case tunnelpogs.AuthUnknown:
return timeAfter(outcome.RefreshAfter()), nil
case tunnelpogs.AuthFail:
return nil, outcome
default:
return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome)
}
}
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
if err != nil {
return nil, err
}
defer edgeConn.Close()
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
// This callback is invoked by h2mux when the edge initiates a stream.
return nil // noop
})
muxerConfig := s.config.muxerConfig(handler)
muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth)
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
if err != nil {
return nil, err
}
go muxer.Serve(ctx)
defer func() {
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
<-muxer.Shutdown()
}()
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout)
if err != nil {
return nil, err
}
defer tunnelServer.Close()
const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
authResponse, err := tunnelServer.Authenticate(
ctx,
s.config.OriginCert,
s.config.Hostname,
registrationOptions,
)
if err != nil {
return nil, err
}
return authResponse.Outcome(), nil
}

128
origin/supervisor_test.go Normal file
View File

@ -0,0 +1,128 @@
package origin
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
func TestRefreshAuthBackoff(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return nil, fmt.Errorf("authentication failure")
}
// authentication failures should consume the backoff
for i := uint(0); i < backoff.MaxRetries; i++ {
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, (1<<i)*time.Second, wait)
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.Error(t, err)
assert.Nil(t, retryChan)
// now we actually make contact with the remote server
_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
})
// The backoff timer should have been reset. To confirm this, make timeNow
// return a value after the backoff timer's grace period
timeNow = func() time.Time {
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
return time.Now().Add(expectedGracePeriod * 2)
}
_, ok := backoff.GetBackoffDuration(context.Background())
assert.True(t, ok)
}
func TestRefreshAuthSuccess(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, 19*time.Hour, wait)
token, err := s.ReconnectToken()
assert.NoError(t, err)
assert.Equal(t, []byte("jwt"), token)
}
func TestRefreshAuthUnknown(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, 19*time.Hour, wait)
token, err := s.ReconnectToken()
assert.Equal(t, errJWTUnset, err)
assert.Nil(t, token)
}
func TestRefreshAuthFail(t *testing.T) {
logger := logrus.New()
logger.Level = logrus.ErrorLevel
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
backoff := &BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
}
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
assert.Error(t, err)
assert.Nil(t, retryChan)
token, err := s.ReconnectToken()
assert.Equal(t, errJWTUnset, err)
assert.Nil(t, token)
}

View File

@ -14,7 +14,14 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/streamhandler"
@ -22,19 +29,12 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflared/websocket"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
_ "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
rpc "zombiezen.com/go/capnproto2/rpc"
)
const (
dialTimeout = 15 * time.Second
openStreamTimeout = 30 * time.Second
muxerTimeout = 5 * time.Second
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
TagHeaderNamePrefix = "Cf-Warp-Tag-"
DuplicateConnectionError = "EDUPCONN"
@ -73,14 +73,9 @@ type TunnelConfig struct {
WSGI bool
// OriginUrl may not be used if a user specifies a unix socket.
OriginUrl string
}
type dialError struct {
cause error
}
func (e dialError) Error() string {
return e.cause.Error()
// feature-flag to use new edge reconnect tokens
UseReconnectToken bool
}
type dupConnRegisterTunnelError struct{}
@ -119,6 +114,18 @@ func (e clientRegisterTunnelError) Error() string {
return e.cause.Error()
}
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
return h2mux.MuxerConfig{
Timeout: muxerTimeout,
Handler: handler,
IsClient: true,
HeartbeatInterval: c.HeartbeatInterval,
MaxHeartbeats: c.MaxHeartbeats,
Logger: c.TransportLogger.WithFields(log.Fields{}),
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
}
}
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
policy := tunnelrpc.ExistingTunnelPolicy_balance
if c.HAConnections <= 1 && c.LBPool == "" {
@ -141,7 +148,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
}
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
return NewSupervisor(config).Run(ctx, connectedSignal, cloudflaredID)
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
}
func ServeTunnelLoop(ctx context.Context,
@ -213,11 +220,11 @@ func ServeTunnel(
tags["ha"] = connectionTag
// Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID)
if err != nil {
errLog := logger.WithError(err)
switch err.(type) {
case dialError:
case connection.DialError:
errLog.Error("Unable to dial edge")
case h2mux.MuxerHandshakeError:
errLog.Error("Handshake failed with edge server")
@ -295,16 +302,6 @@ func ServeTunnel(
return nil, true
}
func IsRPCStreamResponse(headers []h2mux.Header) bool {
if len(headers) != 1 {
return false
}
if headers[0].Name != ":status" || headers[0].Value != "200" {
return false
}
return true
}
func RegisterTunnel(
ctx context.Context,
muxer *h2mux.Muxer,
@ -315,43 +312,31 @@ func RegisterTunnel(
uuid uuid.UUID,
) error {
config.TransportLogger.Debug("initiating RPC stream to register")
stream, err := openStream(ctx, muxer)
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
defer tunnelServer.Close()
// Request server info without blocking tunnel registration; must use capnp library directly.
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
registration, err := ts.RegisterTunnel(
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
registration := tunnelServer.RegisterTunnel(
ctx,
config.OriginCert,
config.Hostname,
config.RegistrationOptions(connectionID, originLocalIP, uuid),
)
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
if err != nil {
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// RegisterTunnel RPC failure
return newClientRegisterTunnelError(err, config.Metrics.regFail)
}
for _, logLine := range registration.LogLines {
logger.Info(logLine)
return processRegisterTunnelError(registrationErr, config.Metrics)
}
if regErr := processRegisterTunnelError(registration.Err, registration.PermanentFailure, config.Metrics); regErr != nil {
return regErr
for _, logLine := range registration.LogLines {
logger.Info(logLine)
}
if registration.TunnelID != "" {
@ -374,57 +359,34 @@ func RegisterTunnel(
config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
config.Metrics.regSuccess.Inc()
return nil
}
func processRegisterTunnelError(err string, permanentFailure bool, metrics *TunnelMetrics) error {
if err == "" {
metrics.regSuccess.Inc()
return nil
}
metrics.regFail.WithLabelValues(err).Inc()
if err == DuplicateConnectionError {
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
return dupConnRegisterTunnelError{}
}
metrics.regFail.WithLabelValues("server_error").Inc()
return serverRegisterTunnelError{
cause: fmt.Errorf("Server error: %s", err),
permanent: permanentFailure,
cause: fmt.Errorf("Server error: %s", err.Error()),
permanent: err.IsPermanent(),
}
}
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
logger.Debug("initiating RPC stream to unregister")
ctx := context.Background()
stream, err := openStream(ctx, muxer)
ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
if err != nil {
// RPC stream open error
return err
}
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
return err
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
// gracePeriod is encoded in int64 using capnproto
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
}
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
}
func LogServerInfo(
promise tunnelrpc.ServerInfo_Promise,
connectionID uint8,
@ -469,12 +431,12 @@ type TunnelHandler struct {
noChunkedEncoding bool
}
var dialer = net.Dialer{DualStack: true}
var dialer = net.Dialer{}
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewTunnelHandler(ctx context.Context,
config *TunnelConfig,
addr string,
addr *net.TCPAddr,
connectionID uint8,
) (*TunnelHandler, string, error) {
originURL, err := validation.ValidateUrl(config.OriginUrl)
@ -495,37 +457,18 @@ func NewTunnelHandler(ctx context.Context,
if h.httpClient == nil {
h.httpClient = http.DefaultTransport
}
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
dialCancel()
edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
if err != nil {
return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")}
return nil, "", err
}
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
err = edgeConn.Handshake()
if err != nil {
return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")}
}
// clear the deadline on the conn; h2mux has its own timeouts
edgeConn.SetDeadline(time.Time{})
// Establish a muxed connection with the edge
// Client mux handshake with agent server
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: 5 * time.Second,
Handler: h,
IsClient: true,
HeartbeatInterval: config.HeartbeatInterval,
MaxHeartbeats: config.MaxHeartbeats,
Logger: config.TransportLogger.WithFields(log.Fields{}),
CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
}, h.metrics.activeStreams)
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
if err != nil {
return h, "", errors.New("TLS handshake error")
return nil, "", errors.Wrap(err, "Handshake with edge error")
}
return h, edgeConn.LocalAddr().String(), err
return h, edgeConn.LocalAddr().String(), nil
}
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {

View File

@ -35,26 +35,49 @@ func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, erro
return req, nil
}
// H2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request
// object. This includes conversion of the pseudo-headers into their closest
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
for _, header := range h2 {
switch header.Name {
case ":method":
h1.Method = header.Value
case ":scheme":
// noop - use the preexisting scheme from h1.URL
case ":authority":
// Otherwise the host header will be based on the origin URL
h1.Host = header.Value
case ":path":
u, err := url.Parse(header.Value)
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
// However, this HTTP/1 Request object belongs to the Go standard library,
// whose URL package makes some opinionated decisions about the encoding of
// URL characters: see the docs of https://godoc.org/net/url#URL,
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
// which is always used when computing url.URL.String(), whether we'd like it or not.
//
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
// is treated differently when HTTP_PROXY is set!
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
//
// This means we are subject to the behavior of net/url's function `shouldEscape`
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
if header.Value == "*" {
h1.URL.Path = "*"
continue
}
// Due to the behavior of validation.ValidateUrl, h1.URL may
// already have a partial value, with or without a trailing slash.
base := h1.URL.String()
base = strings.TrimRight(base, "/")
// But we know :path begins with '/', because we handled '*' above - see RFC7540
url, err := url.Parse(base + header.Value)
if err != nil {
return fmt.Errorf("unparseable path")
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
}
resolved := h1.URL.ResolveReference(u)
// prevent escaping base URL
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
return fmt.Errorf("invalid path %s", header.Value)
}
h1.URL = resolved
h1.URL = url
case "content-length":
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
if err != nil {

View File

@ -0,0 +1,441 @@
package streamhandler
import (
"fmt"
"math/rand"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"
"testing"
"testing/quick"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: "Mock header 1",
Value: "Mock value 1",
},
h2mux.Header{
Name: "Mock header 2",
Value: "Mock value 2",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header 1": []string{"Mock value 1"},
"Mock header 2": []string{"Mock value 2"},
}, request.Header)
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{},
request,
)
assert.Equal(t, http.Header{}, request.Header)
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: "//bad_path/",
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: "/?query=mock%20value",
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: "/mock%20path",
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
type testCase struct {
path string
want string
}
testCases := []testCase{
{
path: "",
want: "",
},
{
path: "/",
want: "/",
},
{
path: "//",
want: "//",
},
{
path: "/test",
want: "/test",
},
{
path: "//test",
want: "//test",
},
{
// https://github.com/cloudflare/cloudflared/issues/81
path: "//test/",
want: "//test/",
},
{
path: "/%2Ftest",
want: "/%2Ftest",
},
{
path: "//%20test",
want: "//%20test",
},
{
// https://github.com/cloudflare/cloudflared/issues/124
path: "/test?get=somthing%20a",
want: "/test?get=somthing%20a",
},
{
path: "/%20",
want: "/%20",
},
{
// stdlib's EscapedPath() will always percent-encode ' '
path: "/ ",
want: "/%20",
},
{
path: "/ a ",
want: "/%20a%20",
},
{
path: "/a%20b",
want: "/a%20b",
},
{
path: "/foo/bar;param?query#frag",
want: "/foo/bar;param?query#frag",
},
{
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
path: "/a␠b",
want: "/a%E2%90%A0b",
},
{
path: "/a-umlaut-ä",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%C3%A4",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%c3%a4",
want: "/a-umlaut-%c3%a4",
},
{
// here the second '#' is treated as part of the fragment
path: "/a#b#c",
want: "/a#b%23c",
},
{
path: "/a#b␠c",
want: "/a#b%E2%90%A0c",
},
{
path: "/a#b%20c",
want: "/a#b%20c",
},
{
path: "/a#b c",
want: "/a#b%20c",
},
{
// stdlib's EscapedPath() will always percent-encode '\'
path: "/\\",
want: "/%5C",
},
{
path: "/a\\",
want: "/a%5C",
},
{
path: "/a,b.c.",
want: "/a,b.c.",
},
{
path: "/.",
want: "/.",
},
{
// stdlib's EscapedPath() will always percent-encode '`'
path: "/a`",
want: "/a%60",
},
{
path: "/a[0]",
want: "/a[0]",
},
{
path: "/?a[0]=5 &b[]=",
want: "/?a[0]=5 &b[]=",
},
{
path: "/?a=%22b%20%22",
want: "/?a=%22b%20%22",
},
}
for index, testCase := range testCases {
requestURL := "https://example.com"
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{
h2mux.Header{
Name: ":path",
Value: testCase.path,
},
h2mux.Header{
Name: "Mock header",
Value: "Mock value",
},
},
request,
)
assert.NoError(t, headersConversionErr)
assert.Equal(t,
http.Header{
"Mock header": []string{"Mock value"},
},
request.Header)
assert.Equal(t,
"https://example.com"+testCase.want,
request.URL.String(),
"Failed URL index: %v %#v", index, testCase)
}
}
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
config := &quick.Config{
Values: func(args []reflect.Value, rand *rand.Rand) {
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
},
}
type testOrigin struct {
url string
expectedScheme string
expectedBasePath string
}
testOrigins := []testOrigin{
{
url: "http://origin.hostname.example.com:8080",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/api",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "http://origin.hostname.example.com:8080/api/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "https://origin.hostname.example.com:8080/api",
expectedScheme: "https",
expectedBasePath: "https://origin.hostname.example.com:8080/api",
},
}
// use multiple schemes to demonstrate that the URL is based on the
// origin's scheme, not the :scheme header
for _, testScheme := range []string{"http", "https"} {
for _, testOrigin := range testOrigins {
assertion := func(testPath string) bool {
const expectedMethod = "POST"
const expectedHostname = "request.hostname.example.com"
h2 := []h2mux.Header{
h2mux.Header{Name: ":method", Value: expectedMethod},
h2mux.Header{Name: ":scheme", Value: testScheme},
h2mux.Header{Name: ":authority", Value: expectedHostname},
h2mux.Header{Name: ":path", Value: testPath},
}
h1, err := http.NewRequest("GET", testOrigin.url, nil)
require.NoError(t, err)
err = H2RequestHeadersToH1Request(h2, h1)
return assert.NoError(t, err) &&
assert.Equal(t, expectedMethod, h1.Method) &&
assert.Equal(t, expectedHostname, h1.Host) &&
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
}
err := quick.Check(assertion, config)
assert.NoError(t, err)
}
}
}
func randomASCIIPrintableChar(rand *rand.Rand) int {
// smallest printable ASCII char is 32, largest is 126
const startPrintable = 32
const endPrintable = 127
return startPrintable + rand.Intn(endPrintable-startPrintable)
}
// randomASCIIText generates an ASCII string, some of whose characters may be
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
// between 1 and `maxLength`.
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
length := minLength + rand.Intn(maxLength)
result := ""
for i := 0; i < length; i++ {
c := randomASCIIPrintableChar(rand)
// 1/4 chance of using percent encoding when not necessary
if c == '%' || rand.Intn(4) == 0 {
result += fmt.Sprintf("%%%02X", c)
} else {
result += string(c)
}
}
return result
}
// Calls `randomASCIIText` and ensures the result is a valid URL path,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
regexp, err := regexp.Compile("[^/;,]*")
require.NoError(t, err)
return "/" + regexp.ReplaceAllStringFunc(text, url.PathEscape)
}
// Calls `randomASCIIText` and ensures the result is a valid URL query,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Query(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
return "?" + strings.ReplaceAll(text, "#", "%23")
}
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
url, err := url.Parse("#" + text)
require.NoError(t, err)
return url.String()
}
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
result := randomHTTP1Path(t, rand, 1, 64)
if rand.Intn(2) == 1 {
result += randomHTTP1Query(t, rand, 1, 32)
}
if rand.Intn(2) == 1 {
result += randomHTTP1Fragment(t, rand, 1, 16)
}
return result
}

View File

@ -16,6 +16,7 @@ import (
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
@ -28,6 +29,27 @@ type Supervisor struct {
useConfigResultChan chan<- *pogs.UseConfigurationResult
state *state
logger *logrus.Entry
metrics metrics
}
type metrics struct {
configVersion prometheus.Gauge
}
func newMetrics() metrics {
configVersion := prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "supervisor",
Subsystem: "supervisor",
Name: "config_version",
Help: "Latest configuration version received from Cloudflare",
},
)
prometheus.MustRegister(
configVersion,
)
return metrics{
configVersion: configVersion,
}
}
func NewSupervisor(
@ -70,6 +92,7 @@ func NewSupervisor(
useConfigResultChan: useConfigResultChan,
state: newState(defaultClientConfig),
logger: logger.WithField("subsystem", "supervisor"),
metrics: newMetrics(),
}, nil
}
@ -131,6 +154,7 @@ func (s *Supervisor) notifySubsystemsNewConfig(newConfig *pogs.ClientConfig) *po
Success: true,
}
}
s.metrics.configVersion.Set(float64(newConfig.Version))
s.state.updateConfig(newConfig)
var tunnelHostnames []h2mux.TunnelHostname

View File

@ -26,28 +26,28 @@ mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw==
-----END CERTIFICATE-----
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
-----BEGIN CERTIFICATE-----
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
Fu6q54beR89jDc+oABmOgg==
MIIEADCCAuigAwIBAgIID+rOSdTGfGcwDQYJKoZIhvcNAQELBQAwgYsxCzAJBgNV
BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91
ZEZsYXJlIE9yaWdpbiBTU0wgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQH
Ew1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMB4XDTE5MDgyMzIx
MDgwMFoXDTI5MDgxNTE3MDAwMFowgYsxCzAJBgNVBAYTAlVTMRkwFwYDVQQKExBD
bG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91ZEZsYXJlIE9yaWdpbiBTU0wg
Q2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRMw
EQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
AQEAwEiVZ/UoQpHmFsHvk5isBxRehukP8DG9JhFev3WZtG76WoTthvLJFRKFCHXm
V6Z5/66Z4S09mgsUuFwvJzMnE6Ej6yIsYNCb9r9QORa8BdhrkNn6kdTly3mdnykb
OomnwbUfLlExVgNdlP0XoRoeMwbQ4598foiHblO2B/LKuNfJzAMfS7oZe34b+vLB
yrP/1bgCSLdc1AxQc1AC0EsQQhgcyTJNgnG4va1c7ogPlwKyhbDyZ4e59N5lbYPJ
SmXI/cAe3jXj1FBLJZkwnoDKe0v13xeF+nF32smSH0qB7aJX2tBMW4TWtFPmzs5I
lwrFSySWAdwYdgxw180yKU0dvwIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAQYwEgYD
VR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUJOhTV118NECHqeuU27rhFnj8KaQw
HwYDVR0jBBgwFoAUJOhTV118NECHqeuU27rhFnj8KaQwDQYJKoZIhvcNAQELBQAD
ggEBAHwOf9Ur1l0Ar5vFE6PNrZWrDfQIMyEfdgSKofCdTckbqXNTiXdgbHs+TWoQ
wAB0pfJDAHJDXOTCWRyTeXOseeOi5Btj5CnEuw3P0oXqdqevM1/+uWp0CM35zgZ8
VD4aITxity0djzE6Qnx3Syzz+ZkoBgTnNum7d9A66/V636x4vTeqbZFBr9erJzgz
hhurjcoacvRNhnjtDRM0dPeiCJ50CP3wEYuvUzDHUaowOsnLCjQIkWbR7Ni6KEIk
MOz2U0OBSif3FTkhCgZWQKOOLo1P42jHC3ssUZAtVNXrCk3fw9/E15k8NPkBazZ6
0iykLhH1trywrKRMVw67F44IE8Y=
-----END CERTIFICATE-----
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
-----BEGIN CERTIFICATE-----

View File

@ -0,0 +1,132 @@
package pogs
import (
"errors"
"time"
)
// AuthenticateResponse is the serialized response from the Authenticate RPC.
// It's a 1:1 representation of the capnp message, so it's not very useful for programmers.
// Instead, you should call the `Outcome()` method to get a programmer-friendly sum type, with one
// case for each possible outcome.
type AuthenticateResponse struct {
PermanentErr string
RetryableErr string
Jwt []byte
HoursUntilRefresh uint8
}
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
func (ar AuthenticateResponse) Outcome() AuthOutcome {
// If the user's authentication was unsuccessful, the server will return an error explaining why.
// cloudflared should fatal with this error.
if ar.PermanentErr != "" {
return NewAuthFail(errors.New(ar.PermanentErr))
}
// If there was a network error, then cloudflared should retry later,
// because origintunneld couldn't prove whether auth was correct or not.
if ar.RetryableErr != "" {
return NewAuthUnknown(errors.New(ar.RetryableErr), ar.HoursUntilRefresh)
}
// If auth succeeded, return the token and refresh it when instructed.
if len(ar.Jwt) > 0 {
return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh)
}
// Otherwise the state got messed up.
return nil
}
// AuthOutcome is a programmer-friendly sum type denoting the possible outcomes of Authenticate.
//go-sumtype:decl AuthOutcome
type AuthOutcome interface {
isAuthOutcome()
// Serialize into an AuthenticateResponse which can be sent via Capnp
Serialize() AuthenticateResponse
}
// AuthSuccess means the backend successfully authenticated this cloudflared.
type AuthSuccess struct {
jwt []byte
hoursUntilRefresh uint8
}
func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess {
return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh}
}
func (ao AuthSuccess) JWT() []byte {
return ao.jwt
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao AuthSuccess) RefreshAfter() time.Duration {
return hoursToTime(ao.hoursUntilRefresh)
}
// Serialize into an AuthenticateResponse which can be sent via Capnp
func (ao AuthSuccess) Serialize() AuthenticateResponse {
return AuthenticateResponse{
Jwt: ao.jwt,
HoursUntilRefresh: ao.hoursUntilRefresh,
}
}
func (ao AuthSuccess) isAuthOutcome() {}
// AuthFail means this cloudflared has the wrong auth and should exit.
type AuthFail struct {
err error
}
func NewAuthFail(err error) AuthFail {
return AuthFail{err: err}
}
func (ao AuthFail) Error() string {
return ao.err.Error()
}
// Serialize into an AuthenticateResponse which can be sent via Capnp
func (ao AuthFail) Serialize() AuthenticateResponse {
return AuthenticateResponse{
PermanentErr: ao.err.Error(),
}
}
func (ao AuthFail) isAuthOutcome() {}
// AuthUnknown means the backend couldn't finish checking authentication. Try again later.
type AuthUnknown struct {
err error
hoursUntilRefresh uint8
}
func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown {
return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh}
}
func (ao AuthUnknown) Error() string {
return ao.err.Error()
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao AuthUnknown) RefreshAfter() time.Duration {
return hoursToTime(ao.hoursUntilRefresh)
}
// Serialize into an AuthenticateResponse which can be sent via Capnp
func (ao AuthUnknown) Serialize() AuthenticateResponse {
return AuthenticateResponse{
RetryableErr: ao.err.Error(),
HoursUntilRefresh: ao.hoursUntilRefresh,
}
}
func (ao AuthUnknown) isAuthOutcome() {}
func hoursToTime(hours uint8) time.Duration {
return time.Duration(hours) * time.Hour
}

View File

@ -0,0 +1,78 @@
package pogs
import (
"context"
"github.com/cloudflare/cloudflared/tunnelrpc"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/server"
)
func (i TunnelServer_PogsImpl) Authenticate(p tunnelrpc.TunnelServer_authenticate) error {
originCert, err := p.Params.OriginCert()
if err != nil {
return err
}
hostname, err := p.Params.Hostname()
if err != nil {
return err
}
options, err := p.Params.Options()
if err != nil {
return err
}
pogsOptions, err := UnmarshalRegistrationOptions(options)
if err != nil {
return err
}
server.Ack(p.Options)
resp, err := i.impl.Authenticate(p.Ctx, originCert, hostname, pogsOptions)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalAuthenticateResponse(result, resp)
}
func MarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse, p *AuthenticateResponse) error {
return pogs.Insert(tunnelrpc.AuthenticateResponse_TypeID, s.Struct, p)
}
func (c TunnelServer_PogsClient) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.Authenticate(ctx, func(p tunnelrpc.TunnelServer_authenticate_Params) error {
err := p.SetOriginCert(originCert)
if err != nil {
return err
}
err = p.SetHostname(hostname)
if err != nil {
return err
}
registrationOptions, err := p.NewOptions()
if err != nil {
return err
}
err = MarshalRegistrationOptions(registrationOptions, options)
if err != nil {
return err
}
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
}
return UnmarshalAuthenticateResponse(retval)
}
func UnmarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse) (*AuthenticateResponse, error) {
p := new(AuthenticateResponse)
err := pogs.Extract(p, tunnelrpc.AuthenticateResponse_TypeID, s.Struct)
return p, err
}

134
tunnelrpc/pogs/auth_test.go Normal file
View File

@ -0,0 +1,134 @@
package pogs
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/stretchr/testify/assert"
capnp "zombiezen.com/go/capnproto2"
)
// Ensure the AuthOutcome sum is correct
var _ AuthOutcome = &AuthSuccess{}
var _ AuthOutcome = &AuthFail{}
var _ AuthOutcome = &AuthUnknown{}
// Unit tests for AuthenticateResponse.Outcome()
func TestAuthenticateResponseOutcome(t *testing.T) {
type fields struct {
PermanentErr string
RetryableErr string
Jwt []byte
HoursUntilRefresh uint8
}
tests := []struct {
name string
fields fields
want AuthOutcome
}{
{"success",
fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6},
},
{"fail",
fields{PermanentErr: "bad creds"},
AuthFail{err: fmt.Errorf("bad creds")},
},
{"error",
fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6},
},
{"nil (no fields are set)",
fields{},
nil,
},
{"nil (too few fields are set)",
fields{HoursUntilRefresh: 6},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ar := AuthenticateResponse{
PermanentErr: tt.fields.PermanentErr,
RetryableErr: tt.fields.RetryableErr,
Jwt: tt.fields.Jwt,
HoursUntilRefresh: tt.fields.HoursUntilRefresh,
}
got := ar.Outcome()
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want)
}
if got != nil && !reflect.DeepEqual(got.Serialize(), ar) {
t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize())
}
})
}
}
func TestAuthSuccess(t *testing.T) {
input := NewAuthSuccess([]byte("asdf"), 6)
output, ok := input.Serialize().Outcome().(AuthSuccess)
assert.True(t, ok)
assert.Equal(t, input, output)
}
func TestAuthUnknown(t *testing.T) {
input := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6)
output, ok := input.Serialize().Outcome().(AuthUnknown)
assert.True(t, ok)
assert.Equal(t, input, output)
}
func TestAuthFail(t *testing.T) {
input := NewAuthFail(fmt.Errorf("wrong creds"))
output, ok := input.Serialize().Outcome().(AuthFail)
assert.True(t, ok)
assert.Equal(t, input, output)
}
func TestWhenToRefresh(t *testing.T) {
expected := 4 * time.Hour
actual := hoursToTime(4)
if expected != actual {
t.Fatalf("expected %v hours, got %v", expected, actual)
}
}
// Test that serializing and deserializing AuthenticationResponse undo each other.
func TestSerializeAuthenticationResponse(t *testing.T) {
tests := []*AuthenticateResponse{
&AuthenticateResponse{
Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
HoursUntilRefresh: 24,
},
&AuthenticateResponse{
PermanentErr: "bad auth",
},
&AuthenticateResponse{
RetryableErr: "bad connection",
HoursUntilRefresh: 24,
},
}
for i, testCase := range tests {
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
capnpEntity, err := tunnelrpc.NewAuthenticateResponse(seg)
if !assert.NoError(t, err) {
t.Fatal("Couldn't initialize a new message")
}
err = MarshalAuthenticateResponse(capnpEntity, testCase)
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
continue
}
result, err := UnmarshalAuthenticateResponse(capnpEntity)
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
continue
}
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
}
}

View File

@ -197,11 +197,14 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
return nil, err
}
dialContext := (&net.Dialer{
dialer := &net.Dialer{
Timeout: hc.ProxyConnectionTimeout,
KeepAlive: hc.TCPKeepAlive,
DualStack: hc.DialDualStack,
}).DialContext
}
if !hc.DialDualStack {
dialer.FallbackDelay = -1
}
dialContext := dialer.DialContext
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
@ -270,7 +273,6 @@ func (*HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,

View File

@ -0,0 +1,79 @@
package pogs
import (
"context"
"github.com/cloudflare/cloudflared/tunnelrpc"
"zombiezen.com/go/capnproto2/server"
)
func (i TunnelServer_PogsImpl) ReconnectTunnel(p tunnelrpc.TunnelServer_reconnectTunnel) error {
jwt, err := p.Params.Jwt()
if err != nil {
return err
}
eventDigest, err := p.Params.EventDigest()
if err != nil {
return err
}
hostname, err := p.Params.Hostname()
if err != nil {
return err
}
options, err := p.Params.Options()
if err != nil {
return err
}
pogsOptions, err := UnmarshalRegistrationOptions(options)
if err != nil {
return err
}
server.Ack(p.Options)
registration, err := i.impl.ReconnectTunnel(p.Ctx, jwt, eventDigest, hostname, pogsOptions)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalTunnelRegistration(result, registration)
}
func (c TunnelServer_PogsClient) ReconnectTunnel(
ctx context.Context,
jwt,
eventDigest []byte,
hostname string,
options *RegistrationOptions,
) (*TunnelRegistration, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.ReconnectTunnel(ctx, func(p tunnelrpc.TunnelServer_reconnectTunnel_Params) error {
err := p.SetJwt(jwt)
if err != nil {
return err
}
err = p.SetEventDigest(eventDigest)
if err != nil {
return err
}
err = p.SetHostname(hostname)
if err != nil {
return err
}
registrationOptions, err := p.NewOptions()
if err != nil {
return err
}
err = MarshalRegistrationOptions(registrationOptions, options)
if err != nil {
return err
}
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
}
return UnmarshalTunnelRegistration(retval)
}

View File

@ -9,13 +9,16 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
)
const (
defaultRetryAfterSeconds = 15
)
type Authentication struct {
Key string
Email string
@ -33,11 +36,112 @@ func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error
}
type TunnelRegistration struct {
Err string
Url string
LogLines []string
PermanentFailure bool
TunnelID string `capnp:"tunnelID"`
SuccessfulTunnelRegistration
Err string
PermanentFailure bool
RetryAfterSeconds uint16
}
type SuccessfulTunnelRegistration struct {
Url string
LogLines []string
TunnelID string `capnp:"tunnelID"`
EventDigest []byte
}
func NewSuccessfulTunnelRegistration(
url string,
logLines []string,
tunnelID string,
eventDigest []byte,
) *TunnelRegistration {
// Marshal nil will result in an error
if logLines == nil {
logLines = []string{}
}
return &TunnelRegistration{
SuccessfulTunnelRegistration: SuccessfulTunnelRegistration{
Url: url,
LogLines: logLines,
TunnelID: tunnelID,
EventDigest: eventDigest,
},
}
}
// Not calling this function Error() to avoid confusion with implementing error interface
func (tr TunnelRegistration) DeserializeError() TunnelRegistrationError {
if tr.Err != "" {
err := fmt.Errorf(tr.Err)
if tr.PermanentFailure {
return NewPermanentRegistrationError(err)
}
retryAfterSeconds := tr.RetryAfterSeconds
if retryAfterSeconds < defaultRetryAfterSeconds {
retryAfterSeconds = defaultRetryAfterSeconds
}
return NewRetryableRegistrationError(err, retryAfterSeconds)
}
return nil
}
type TunnelRegistrationError interface {
error
Serialize() *TunnelRegistration
IsPermanent() bool
}
type PermanentRegistrationError struct {
err string
}
func NewPermanentRegistrationError(err error) TunnelRegistrationError {
return &PermanentRegistrationError{
err: err.Error(),
}
}
func (pre *PermanentRegistrationError) Error() string {
return pre.err
}
func (pre *PermanentRegistrationError) Serialize() *TunnelRegistration {
return &TunnelRegistration{
Err: pre.err,
PermanentFailure: true,
}
}
func (*PermanentRegistrationError) IsPermanent() bool {
return true
}
type RetryableRegistrationError struct {
err string
retryAfterSeconds uint16
}
func NewRetryableRegistrationError(err error, retryAfterSeconds uint16) TunnelRegistrationError {
return &RetryableRegistrationError{
err: err.Error(),
retryAfterSeconds: retryAfterSeconds,
}
}
func (rre *RetryableRegistrationError) Error() string {
return rre.err
}
func (rre *RetryableRegistrationError) Serialize() *TunnelRegistration {
return &TunnelRegistration{
Err: rre.err,
PermanentFailure: false,
RetryAfterSeconds: rre.retryAfterSeconds,
}
}
func (*RetryableRegistrationError) IsPermanent() bool {
return false
}
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
@ -63,6 +167,7 @@ type RegistrationOptions struct {
RunFromTerminal bool `capnp:"runFromTerminal"`
CompressionQuality uint64 `capnp:"compressionQuality"`
UUID string `capnp:"uuid"`
NumPreviousAttempts uint8
}
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
@ -323,10 +428,12 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar
}
type TunnelServer interface {
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration
GetServerInfo(ctx context.Context) (*ServerInfo, error)
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error)
ReconnectTunnel(ctx context.Context, jwt, eventDigest []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
}
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
@ -355,15 +462,12 @@ func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerT
return err
}
server.Ack(p.Options)
registration, err := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
if err != nil {
return err
}
registration := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
result, err := p.Results.NewResult()
if err != nil {
return err
}
log.Info(registration.TunnelID)
return MarshalTunnelRegistration(result, registration)
}
@ -416,7 +520,7 @@ func (c TunnelServer_PogsClient) Close() error {
return c.Conn.Close()
}
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
err := p.SetOriginCert(originCert)
@ -439,9 +543,13 @@ func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
return UnmarshalTunnelRegistration(retval)
registration, err := UnmarshalTunnelRegistration(retval)
if err != nil {
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
return registration
}
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {

View File

@ -1,6 +1,7 @@
package pogs
import (
"fmt"
"reflect"
"testing"
"time"
@ -11,6 +12,50 @@ import (
capnp "zombiezen.com/go/capnproto2"
)
const (
testURL = "tunnel.example.com"
testTunnelID = "asdfghjkl;"
testRetryAfterSeconds = 19
)
var (
testErr = fmt.Errorf("Invalid credential")
testLogLines = []string{"all", "working"}
testEventDigest = []byte("asdf")
)
// *PermanentRegistrationError implements TunnelRegistrationError
var _ TunnelRegistrationError = (*PermanentRegistrationError)(nil)
// *RetryableRegistrationError implements TunnelRegistrationError
var _ TunnelRegistrationError = (*RetryableRegistrationError)(nil)
func TestTunnelRegistration(t *testing.T) {
testCases := []*TunnelRegistration{
NewSuccessfulTunnelRegistration(testURL, testLogLines, testTunnelID, testEventDigest),
NewSuccessfulTunnelRegistration(testURL, nil, testTunnelID, testEventDigest),
NewPermanentRegistrationError(testErr).Serialize(),
NewRetryableRegistrationError(testErr, testRetryAfterSeconds).Serialize(),
}
for i, testCase := range testCases {
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
capnpEntity, err := tunnelrpc.NewTunnelRegistration(seg)
if !assert.NoError(t, err) {
t.Fatal("Couldn't initialize a new message")
}
err = MarshalTunnelRegistration(capnpEntity, testCase)
if !assert.NoError(t, err, "testCase #%v failed to marshal", i) {
continue
}
result, err := UnmarshalTunnelRegistration(capnpEntity)
if !assert.NoError(t, err, "testCase #%v failed to unmarshal", i) {
continue
}
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
}
}
func TestConnectResult(t *testing.T) {
testCases := []ConnectResult{
&ConnectError{

View File

@ -19,6 +19,10 @@ struct TunnelRegistration {
permanentFailure @3 :Bool;
# Displayed to user
tunnelID @4 :Text;
# How long should this connection wait to retry in seconds, if the error wasn't permanent
retryAfterSeconds @5 :UInt16;
# A unique ID used to reconnect this tunnel.
eventDigest @6 :Data;
}
struct RegistrationOptions {
@ -44,6 +48,8 @@ struct RegistrationOptions {
# cross stream compression setting, 0 - off, 3 - high
compressionQuality @10 :UInt64;
uuid @11 :Text;
# number of previous attempts to send RegisterTunnel/ReconnectTunnel
numPreviousAttempts @12 :UInt8;
}
struct CapnpConnectParameters {
@ -274,11 +280,20 @@ struct FailedConfig {
reason @4 :Text;
}
struct AuthenticateResponse {
permanentErr @0 :Text;
retryableErr @1 :Text;
jwt @2 :Data;
hoursUntilRefresh @3 :UInt8;
}
interface TunnelServer {
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
getServerInfo @1 () -> (result :ServerInfo);
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();
connect @3 (parameters :CapnpConnectParameters) -> (result :ConnectResult);
authenticate @4 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :AuthenticateResponse);
reconnectTunnel @5 (jwt :Data, eventDigest :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
}
interface ClientService {

File diff suppressed because it is too large Load Diff

View File

@ -60,6 +60,12 @@ func ValidateHostname(hostname string) (string, error) {
}
// ValidateUrl returns a validated version of `originUrl` with a scheme prepended (by default http://).
// Note: when originUrl contains a scheme, the path is removed:
// ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
// but when it does not, the path is preserved:
// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
// This is arguably a bug, but changing it might break some cloudflared users.
func ValidateUrl(originUrl string) (string, error) {
if originUrl == "" {
return "", fmt.Errorf("URL should not be empty")
@ -121,6 +127,8 @@ func ValidateUrl(originUrl string) (string, error) {
if err != nil {
return "", fmt.Errorf("URL %s has invalid format", originUrl)
}
// This is why the path is preserved when `originUrl` doesn't have a schema.
// Using `parsedUrl.Port()` here, instead of `port`, would remove the path
return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
}
}
@ -182,10 +190,11 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
_, secondErr := client.Do(secondRequest)
if secondErr == nil { // Worked this time--advise the user to switch protocols
return errors.Errorf(
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
parsedURL.Host,
oldScheme,
parsedURL.Scheme,
initialErr,
parsedURL,
)
}

View File

@ -53,98 +53,65 @@ func TestValidateHostname(t *testing.T) {
}
func TestValidateUrl(t *testing.T) {
type testCase struct {
input string
expectedOutput string
}
testCases := []testCase{
{"http://localhost", "http://localhost"},
{"http://localhost/", "http://localhost"},
{"http://localhost/api", "http://localhost"},
{"http://localhost/api/", "http://localhost"},
{"https://localhost", "https://localhost"},
{"https://localhost/", "https://localhost"},
{"https://localhost/api", "https://localhost"},
{"https://localhost/api/", "https://localhost"},
{"https://localhost:8080", "https://localhost:8080"},
{"https://localhost:8080/", "https://localhost:8080"},
{"https://localhost:8080/api", "https://localhost:8080"},
{"https://localhost:8080/api/", "https://localhost:8080"},
{"localhost", "http://localhost"},
{"localhost/", "http://localhost/"},
{"localhost/api", "http://localhost/api"},
{"localhost/api/", "http://localhost/api/"},
{"localhost:8080", "http://localhost:8080"},
{"localhost:8080/", "http://localhost:8080/"},
{"localhost:8080/api", "http://localhost:8080/api"},
{"localhost:8080/api/", "http://localhost:8080/api/"},
{"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
{"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
{"127.0.0.1:8080", "http://127.0.0.1:8080"},
{"127.0.0.1", "http://127.0.0.1"},
{"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
{"[::1]:8080", "http://[::1]:8080"},
{"http://[::1]", "http://[::1]"},
{"http://[::1]:8080", "http://[::1]:8080"},
{"[::1]", "http://[::1]"},
{"https://example.com", "https://example.com"},
{"example.com", "http://example.com"},
{"http://hello.example.com", "http://hello.example.com"},
{"hello.example.com", "http://hello.example.com"},
{"hello.example.com:8080", "http://hello.example.com:8080"},
{"https://hello.example.com:8080", "https://hello.example.com:8080"},
{"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
{"bücher.example.com", "http://xn--bcher-kva.example.com"},
{"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
{"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
}
for i, testCase := range testCases {
validUrl, err := ValidateUrl(testCase.input)
assert.NoError(t, err, "test case %v", i)
assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i)
}
validUrl, err := ValidateUrl("")
assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
assert.Empty(t, validUrl)
validUrl, err = ValidateUrl("https://localhost:8080")
assert.Nil(t, err)
assert.Equal(t, "https://localhost:8080", validUrl)
validUrl, err = ValidateUrl("localhost:8080")
assert.Nil(t, err)
assert.Equal(t, "http://localhost:8080", validUrl)
validUrl, err = ValidateUrl("http://localhost")
assert.Nil(t, err)
assert.Equal(t, "http://localhost", validUrl)
validUrl, err = ValidateUrl("http://127.0.0.1:8080")
assert.Nil(t, err)
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
validUrl, err = ValidateUrl("127.0.0.1:8080")
assert.Nil(t, err)
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
validUrl, err = ValidateUrl("127.0.0.1")
assert.Nil(t, err)
assert.Equal(t, "http://127.0.0.1", validUrl)
validUrl, err = ValidateUrl("https://127.0.0.1:8080")
assert.Nil(t, err)
assert.Equal(t, "https://127.0.0.1:8080", validUrl)
validUrl, err = ValidateUrl("[::1]:8080")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]:8080", validUrl)
validUrl, err = ValidateUrl("http://[::1]")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]", validUrl)
validUrl, err = ValidateUrl("http://[::1]:8080")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]:8080", validUrl)
validUrl, err = ValidateUrl("[::1]")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]", validUrl)
validUrl, err = ValidateUrl("https://example.com")
assert.Nil(t, err)
assert.Equal(t, "https://example.com", validUrl)
validUrl, err = ValidateUrl("example.com")
assert.Nil(t, err)
assert.Equal(t, "http://example.com", validUrl)
validUrl, err = ValidateUrl("http://hello.example.com")
assert.Nil(t, err)
assert.Equal(t, "http://hello.example.com", validUrl)
validUrl, err = ValidateUrl("hello.example.com")
assert.Nil(t, err)
assert.Equal(t, "http://hello.example.com", validUrl)
validUrl, err = ValidateUrl("hello.example.com:8080")
assert.Nil(t, err)
assert.Equal(t, "http://hello.example.com:8080", validUrl)
validUrl, err = ValidateUrl("https://hello.example.com:8080")
assert.Nil(t, err)
assert.Equal(t, "https://hello.example.com:8080", validUrl)
validUrl, err = ValidateUrl("https://bücher.example.com")
assert.Nil(t, err)
assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl)
validUrl, err = ValidateUrl("bücher.example.com")
assert.Nil(t, err)
assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl)
validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com")
assert.Nil(t, err)
assert.Equal(t, "https://hello.example.com", validUrl)
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
assert.Empty(t, validUrl)
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
assert.Nil(t, err)
assert.Equal(t, "https://hello.example.com:8080", validUrl)
}
func TestToggleProtocol(t *testing.T) {