Merge branch 'master' into master
This commit is contained in:
commit
9042025902
|
@ -1,3 +1,19 @@
|
||||||
|
2019.11.3
|
||||||
|
- 2019-11-20 TUN-2562: Update Cloudflare Origin CA RSA root
|
||||||
|
|
||||||
|
2019.11.2
|
||||||
|
- 2019-11-18 TUN-2567: AuthOutcome can be turned back into AuthResponse
|
||||||
|
- 2019-11-18 TUN-2563: Exposes config_version metrics
|
||||||
|
|
||||||
|
2019.11.1
|
||||||
|
- 2019-11-12 Add db-connect, a SQL over HTTPS server
|
||||||
|
- 2019-11-12 TUN-2053: Add a /healthcheck endpoint to the metrics server
|
||||||
|
- 2019-11-13 TUN-2178: public API to create new h2mux.MuxedStreamRequest
|
||||||
|
- 2019-11-13 TUN-2490: respect original representation of HTTP request path
|
||||||
|
- 2019-11-18 TUN-2547: TunnelRPC definitions for Authenticate flow
|
||||||
|
- 2019-11-18 TUN-2551: TunnelRPC definitions for ReconnectTunnel flow
|
||||||
|
- 2019-11-05 TUN-2506: Expose active streams metrics
|
||||||
|
|
||||||
2019.11.0
|
2019.11.0
|
||||||
- 2019-11-04 TUN-2502: Switch to go modules
|
- 2019-11-04 TUN-2502: Switch to go modules
|
||||||
- 2019-11-04 TUN-2500: Don't send client registration errors to Sentry
|
- 2019-11-04 TUN-2500: Don't send client registration errors to Sentry
|
||||||
|
|
|
@ -977,6 +977,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
EnvVars: []string{"TUNNEL_INTENT"},
|
EnvVars: []string{"TUNNEL_INTENT"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
|
Name: "use-reconnect-token",
|
||||||
|
Usage: "Test reestablishing connections with the new 'reconnect token' flow.",
|
||||||
|
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
Name: "dial-edge-timeout",
|
Name: "dial-edge-timeout",
|
||||||
Usage: "Maximum wait time to set up a connection with the edge",
|
Usage: "Maximum wait time to set up a connection with the edge",
|
||||||
|
@ -1044,7 +1050,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
Usage: "Absolute path of directory to save SSH host keys in",
|
Usage: "Absolute path of directory to save SSH host keys in",
|
||||||
EnvVars: []string{"HOST_KEY_PATH"},
|
EnvVars: []string{"HOST_KEY_PATH"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
|
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -203,11 +203,14 @@ func prepareTunnelConfig(
|
||||||
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")},
|
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")},
|
||||||
}
|
}
|
||||||
|
|
||||||
dialContext := (&net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: c.Duration("proxy-connect-timeout"),
|
Timeout: c.Duration("proxy-connect-timeout"),
|
||||||
KeepAlive: c.Duration("proxy-tcp-keepalive"),
|
KeepAlive: c.Duration("proxy-tcp-keepalive"),
|
||||||
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
|
}
|
||||||
}).DialContext
|
if c.Bool("proxy-no-happy-eyeballs") {
|
||||||
|
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
|
||||||
|
}
|
||||||
|
dialContext := dialer.DialContext
|
||||||
|
|
||||||
if c.IsSet("unix-socket") {
|
if c.IsSet("unix-socket") {
|
||||||
unixSocket, err := config.ValidateUnixSocket(c)
|
unixSocket, err := config.ValidateUnixSocket(c)
|
||||||
|
@ -272,6 +275,7 @@ func prepareTunnelConfig(
|
||||||
TlsConfig: toEdgeTLSConfig,
|
TlsConfig: toEdgeTLSConfig,
|
||||||
TransportLogger: transportLogger,
|
TransportLogger: transportLogger,
|
||||||
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
|
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
|
||||||
|
UseReconnectToken: c.Bool("use-reconnect-token"),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,38 +2,26 @@ package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
openStreamTimeout = 30 * time.Second
|
openStreamTimeout = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialError struct {
|
|
||||||
cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e dialError) Error() string {
|
|
||||||
return e.cause.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Connection struct {
|
type Connection struct {
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
muxer *h2mux.Muxer
|
muxer *h2mux.Muxer
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnection(muxer *h2mux.Muxer, edgeIP *net.TCPAddr) (*Connection, error) {
|
func newConnection(muxer *h2mux.Muxer) (*Connection, error) {
|
||||||
id, err := uuid.NewRandom()
|
id, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -50,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect is used to establish connections with cloudflare's edge network
|
// Connect is used to establish connections with cloudflare's edge network
|
||||||
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (pogs.ConnectResult, error) {
|
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) {
|
||||||
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
|
tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
rpcConn, err := c.newRPConn(openStreamCtx, logger)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot create new RPC connection")
|
return nil, errors.Wrap(err, "cannot create new RPC connection")
|
||||||
}
|
}
|
||||||
defer rpcConn.Close()
|
defer tsClient.Close()
|
||||||
|
|
||||||
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
|
|
||||||
|
|
||||||
return tsClient.Connect(ctx, parameters)
|
return tsClient.Connect(ctx, parameters)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) Shutdown() {
|
func (c *Connection) Shutdown() {
|
||||||
c.muxer.Shutdown()
|
c.muxer.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
|
|
||||||
stream, err := c.muxer.OpenRPCStream(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rpc.NewConn(
|
|
||||||
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
|
|
||||||
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
|
|
||||||
), nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DialEdge makes a TLS connection to a Cloudflare edge node
|
||||||
|
func DialEdge(
|
||||||
|
ctx context.Context,
|
||||||
|
timeout time.Duration,
|
||||||
|
tlsConfig *tls.Config,
|
||||||
|
edgeTCPAddr *net.TCPAddr,
|
||||||
|
) (net.Conn, error) {
|
||||||
|
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||||
|
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer dialCancel()
|
||||||
|
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, newDialError(err, "DialContext error")
|
||||||
|
}
|
||||||
|
tlsEdgeConn := tls.Client(edgeConn, tlsConfig)
|
||||||
|
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
|
||||||
|
|
||||||
|
if err = tlsEdgeConn.Handshake(); err != nil {
|
||||||
|
return nil, newDialError(err, "Handshake with edge error")
|
||||||
|
}
|
||||||
|
// clear the deadline on the conn; h2mux has its own timeouts
|
||||||
|
tlsEdgeConn.SetDeadline(time.Time{})
|
||||||
|
return tlsEdgeConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialError is an error returned from DialEdge
|
||||||
|
type DialError struct {
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDialError(err error, message string) error {
|
||||||
|
return DialError{cause: errors.Wrap(err, message)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e DialError) Error() string {
|
||||||
|
return e.cause.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e DialError) Cause() error {
|
||||||
|
return e.cause
|
||||||
|
}
|
|
@ -4,19 +4,18 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/streamhandler"
|
"github.com/cloudflare/cloudflared/streamhandler"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -59,12 +58,12 @@ func newMetrics(namespace, subsystem string) *metrics {
|
||||||
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
|
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
|
||||||
type EdgeManagerConfigurable struct {
|
type EdgeManagerConfigurable struct {
|
||||||
TunnelHostnames []h2mux.TunnelHostname
|
TunnelHostnames []h2mux.TunnelHostname
|
||||||
*pogs.EdgeConnectionConfig
|
*tunnelpogs.EdgeConnectionConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloudflaredConfig struct {
|
type CloudflaredConfig struct {
|
||||||
CloudflaredID uuid.UUID
|
CloudflaredID uuid.UUID
|
||||||
Tags []pogs.Tag
|
Tags []tunnelpogs.Tag
|
||||||
BuildInfo *buildinfo.BuildInfo
|
BuildInfo *buildinfo.BuildInfo
|
||||||
IntentLabel string
|
IntentLabel string
|
||||||
}
|
}
|
||||||
|
@ -127,13 +126,13 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab
|
||||||
em.state.updateConfigurable(newConfigurable)
|
em.state.updateConfigurable(newConfigurable)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
|
func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError {
|
||||||
edgeIP := em.serviceDiscoverer.Addr()
|
edgeTCPAddr := em.serviceDiscoverer.Addr()
|
||||||
edgeConn, err := em.dialEdge(ctx, edgeIP)
|
configurable := em.state.getConfigurable()
|
||||||
|
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return retryConnection(fmt.Sprintf("dial edge error: %v", err))
|
return retryConnection(fmt.Sprintf("dial edge error: %v", err))
|
||||||
}
|
}
|
||||||
configurable := em.state.getConfigurable()
|
|
||||||
// Establish a muxed connection with the edge
|
// Establish a muxed connection with the edge
|
||||||
// Client mux handshake with agent server
|
// Client mux handshake with agent server
|
||||||
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||||
|
@ -148,14 +147,14 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
|
||||||
retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err))
|
retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
h2muxConn, err := newConnection(muxer, edgeIP)
|
h2muxConn, err := newConnection(muxer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err))
|
return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
go em.serveConn(ctx, h2muxConn)
|
go em.serveConn(ctx, h2muxConn)
|
||||||
|
|
||||||
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
|
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
|
||||||
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
|
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
|
||||||
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
|
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
|
||||||
NumPreviousAttempts: 0,
|
NumPreviousAttempts: 0,
|
||||||
|
@ -196,28 +195,6 @@ func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
|
||||||
em.state.closeConnection(conn)
|
em.state.closeConnection(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (em *EdgeManager) dialEdge(ctx context.Context, edgeIP *net.TCPAddr) (*tls.Conn, error) {
|
|
||||||
timeout := em.state.getConfigurable().Timeout
|
|
||||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
|
||||||
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
|
|
||||||
defer dialCancel()
|
|
||||||
|
|
||||||
dialer := net.Dialer{DualStack: true}
|
|
||||||
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
|
|
||||||
}
|
|
||||||
tlsEdgeConn := tls.Client(edgeConn, em.tlsConfig)
|
|
||||||
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
|
|
||||||
|
|
||||||
if err = tlsEdgeConn.Handshake(); err != nil {
|
|
||||||
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
|
||||||
}
|
|
||||||
// clear the deadline on the conn; h2mux has its own timeouts
|
|
||||||
tlsEdgeConn.SetDeadline(time.Time{})
|
|
||||||
return tlsEdgeConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (em *EdgeManager) noRetryMessage() string {
|
func (em *EdgeManager) noRetryMessage() string {
|
||||||
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
|
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
|
||||||
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
|
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
|
||||||
|
@ -308,8 +285,8 @@ func (ems *edgeManagerState) getUserCredential() []byte {
|
||||||
return ems.userCredential
|
return ems.userCredential
|
||||||
}
|
}
|
||||||
|
|
||||||
func retryConnection(cause string) *pogs.ConnectError {
|
func retryConnection(cause string) *tunnelpogs.ConnectError {
|
||||||
return &pogs.ConnectError{
|
return &tunnelpogs.ConnectError{
|
||||||
Cause: cause,
|
Cause: cause,
|
||||||
RetryAfter: defaultRetryAfter,
|
RetryAfter: defaultRetryAfter,
|
||||||
ShouldRetry: true,
|
ShouldRetry: true,
|
||||||
|
|
|
@ -4,15 +4,15 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/google/uuid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/streamhandler"
|
"github.com/cloudflare/cloudflared/streamhandler"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRPCClient creates and returns a new RPC client, which will communicate
|
||||||
|
// using a stream on the given muxer
|
||||||
|
func NewRPCClient(
|
||||||
|
ctx context.Context,
|
||||||
|
muxer *h2mux.Muxer,
|
||||||
|
logger *logrus.Entry,
|
||||||
|
openStreamTimeout time.Duration,
|
||||||
|
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
|
||||||
|
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||||
|
defer openStreamCancel()
|
||||||
|
stream, err := muxer.OpenRPCStream(openStreamCtx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isRPCStreamResponse(stream.Headers) {
|
||||||
|
stream.Close()
|
||||||
|
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := rpc.NewConn(
|
||||||
|
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
||||||
|
tunnelrpc.ConnLog(logger),
|
||||||
|
)
|
||||||
|
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRPCStreamResponse(headers []h2mux.Header) bool {
|
||||||
|
return len(headers) == 1 &&
|
||||||
|
headers[0].Name == ":status" &&
|
||||||
|
headers[0].Value == "200"
|
||||||
|
}
|
|
@ -13,26 +13,28 @@ type activeStreamMap struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
// streams tracks open streams.
|
// streams tracks open streams.
|
||||||
streams map[uint32]*MuxedStream
|
streams map[uint32]*MuxedStream
|
||||||
// streamsEmpty is a chan that should be closed when no more streams are open.
|
|
||||||
streamsEmpty chan struct{}
|
|
||||||
// nextStreamID is the next ID to use on our side of the connection.
|
// nextStreamID is the next ID to use on our side of the connection.
|
||||||
// This is odd for clients, even for servers.
|
// This is odd for clients, even for servers.
|
||||||
nextStreamID uint32
|
nextStreamID uint32
|
||||||
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
|
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
|
||||||
maxPeerStreamID uint32
|
maxPeerStreamID uint32
|
||||||
|
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
||||||
|
activeStreams prometheus.Gauge
|
||||||
|
|
||||||
// ignoreNewStreams is true when the connection is being shut down. New streams
|
// ignoreNewStreams is true when the connection is being shut down. New streams
|
||||||
// cannot be registered.
|
// cannot be registered.
|
||||||
ignoreNewStreams bool
|
ignoreNewStreams bool
|
||||||
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
// streamsEmpty is a chan that will be closed when no more streams are open.
|
||||||
activeStreams prometheus.Gauge
|
streamsEmptyChan chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
|
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
|
||||||
m := &activeStreamMap{
|
m := &activeStreamMap{
|
||||||
streams: make(map[uint32]*MuxedStream),
|
streams: make(map[uint32]*MuxedStream),
|
||||||
streamsEmpty: make(chan struct{}),
|
streamsEmptyChan: make(chan struct{}),
|
||||||
nextStreamID: 1,
|
nextStreamID: 1,
|
||||||
activeStreams: activeStreams,
|
activeStreams: activeStreams,
|
||||||
}
|
}
|
||||||
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
|
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
|
||||||
if !useClientStreamNumbers {
|
if !useClientStreamNumbers {
|
||||||
|
@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *activeStreamMap) notifyStreamsEmpty() {
|
||||||
|
m.closeOnce.Do(func() {
|
||||||
|
close(m.streamsEmptyChan)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Len returns the number of active streams.
|
// Len returns the number of active streams.
|
||||||
func (m *activeStreamMap) Len() int {
|
func (m *activeStreamMap) Len() int {
|
||||||
m.RLock()
|
m.RLock()
|
||||||
|
@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) {
|
||||||
delete(m.streams, streamID)
|
delete(m.streams, streamID)
|
||||||
m.activeStreams.Dec()
|
m.activeStreams.Dec()
|
||||||
}
|
}
|
||||||
if len(m.streams) == 0 && m.streamsEmpty != nil {
|
if len(m.streams) == 0 {
|
||||||
close(m.streamsEmpty)
|
m.notifyStreamsEmpty()
|
||||||
m.streamsEmpty = nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown blocks new streams from being created. It returns a channel that receives an event
|
// Shutdown blocks new streams from being created.
|
||||||
// once the last stream has closed, or nil if a shutdown is in progress.
|
// It returns `done`, a channel that is closed once the last stream has closed
|
||||||
func (m *activeStreamMap) Shutdown() <-chan struct{} {
|
// and `progress`, whether a shutdown was already in progress
|
||||||
|
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
if m.ignoreNewStreams {
|
if m.ignoreNewStreams {
|
||||||
// already shutting down
|
// already shutting down
|
||||||
return nil
|
return m.streamsEmptyChan, true
|
||||||
}
|
}
|
||||||
m.ignoreNewStreams = true
|
m.ignoreNewStreams = true
|
||||||
done := make(chan struct{})
|
|
||||||
if len(m.streams) == 0 {
|
if len(m.streams) == 0 {
|
||||||
// nothing to shut down
|
// nothing to shut down
|
||||||
close(done)
|
m.notifyStreamsEmpty()
|
||||||
return done
|
|
||||||
}
|
}
|
||||||
m.streamsEmpty = done
|
return m.streamsEmptyChan, false
|
||||||
return done
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcquireLocalID acquires a new stream ID for a stream you're opening.
|
// AcquireLocalID acquires a new stream ID for a stream you're opening.
|
||||||
|
@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() {
|
||||||
stream.Close()
|
stream.Close()
|
||||||
}
|
}
|
||||||
m.ignoreNewStreams = true
|
m.ignoreNewStreams = true
|
||||||
|
m.notifyStreamsEmpty()
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,134 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShutdown(t *testing.T) {
|
||||||
|
const numStreams = 1000
|
||||||
|
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
||||||
|
|
||||||
|
// Add all the streams
|
||||||
|
{
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numStreams)
|
||||||
|
for i := 0; i < numStreams; i++ {
|
||||||
|
go func(streamID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
stream := &MuxedStream{streamID: uint32(streamID)}
|
||||||
|
ok := m.Set(stream)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||||
|
|
||||||
|
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||||
|
select {
|
||||||
|
case <-shutdownChan:
|
||||||
|
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.False(t, alreadyInProgress)
|
||||||
|
|
||||||
|
shutdownChan2, alreadyInProgress2 := m.Shutdown()
|
||||||
|
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
|
||||||
|
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
|
||||||
|
|
||||||
|
// Delete all the streams
|
||||||
|
{
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numStreams)
|
||||||
|
for i := 0; i < numStreams; i++ {
|
||||||
|
go func(streamID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
m.Delete(uint32(streamID))
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-shutdownChan:
|
||||||
|
default:
|
||||||
|
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopBuffer struct {
|
||||||
|
isClosed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
|
||||||
|
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
|
||||||
|
func (t *noopBuffer) Reset() {}
|
||||||
|
func (t *noopBuffer) Len() int { return 0 }
|
||||||
|
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
|
||||||
|
func (t *noopBuffer) Closed() bool { return t.isClosed }
|
||||||
|
|
||||||
|
type noopReadyList struct{}
|
||||||
|
|
||||||
|
func (_ *noopReadyList) Signal(streamID uint32) {}
|
||||||
|
|
||||||
|
func TestAbort(t *testing.T) {
|
||||||
|
const numStreams = 1000
|
||||||
|
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
||||||
|
|
||||||
|
var openedStreams sync.Map
|
||||||
|
|
||||||
|
// Add all the streams
|
||||||
|
{
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numStreams)
|
||||||
|
for i := 0; i < numStreams; i++ {
|
||||||
|
go func(streamID int) {
|
||||||
|
defer wg.Done()
|
||||||
|
stream := &MuxedStream{
|
||||||
|
streamID: uint32(streamID),
|
||||||
|
readBuffer: &noopBuffer{},
|
||||||
|
writeBuffer: &noopBuffer{},
|
||||||
|
readyList: &noopReadyList{},
|
||||||
|
}
|
||||||
|
ok := m.Set(stream)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
openedStreams.Store(stream.streamID, stream)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||||
|
|
||||||
|
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||||
|
select {
|
||||||
|
case <-shutdownChan:
|
||||||
|
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
assert.False(t, alreadyInProgress)
|
||||||
|
|
||||||
|
m.Abort()
|
||||||
|
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
|
||||||
|
openedStreams.Range(func(key interface{}, value interface{}) bool {
|
||||||
|
stream := value.(*MuxedStream)
|
||||||
|
readBuffer := stream.readBuffer.(*noopBuffer)
|
||||||
|
writeBuffer := stream.writeBuffer.(*noopBuffer)
|
||||||
|
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-shutdownChan:
|
||||||
|
default:
|
||||||
|
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// multiple aborts shouldn't cause any issues
|
||||||
|
m.Abort()
|
||||||
|
m.Abort()
|
||||||
|
m.Abort()
|
||||||
|
}
|
|
@ -542,7 +542,10 @@ func (w *h2DictWriter) Write(p []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *h2DictWriter) Close() error {
|
func (w *h2DictWriter) Close() error {
|
||||||
return w.comp.Close()
|
if w.comp != nil {
|
||||||
|
return w.comp.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// From http2/hpack
|
// From http2/hpack
|
||||||
|
|
|
@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown is called to initiate the "happy path" of muxer termination.
|
// Shutdown is called to initiate the "happy path" of muxer termination.
|
||||||
func (m *Muxer) Shutdown() {
|
// It blocks new streams from being created.
|
||||||
|
// It returns a channel that is closed when the last stream has been closed.
|
||||||
|
func (m *Muxer) Shutdown() <-chan struct{} {
|
||||||
m.explicitShutdown.Fuse(true)
|
m.explicitShutdown.Fuse(true)
|
||||||
m.muxReader.Shutdown()
|
return m.muxReader.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
|
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
|
||||||
|
@ -390,7 +392,7 @@ func isConnectionClosedError(err error) bool {
|
||||||
// Called by proxy server and tunnel
|
// Called by proxy server and tunnel
|
||||||
func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
|
func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
|
||||||
stream := m.NewStream(headers)
|
stream := m.NewStream(headers)
|
||||||
if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream, body}); err != nil {
|
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
||||||
|
@ -401,7 +403,7 @@ func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader
|
||||||
|
|
||||||
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
|
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
|
||||||
stream := m.NewStream(RPCHeaders())
|
stream := m.NewStream(RPCHeaders())
|
||||||
if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream: stream, body: nil}); err != nil {
|
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
||||||
|
|
|
@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
|
||||||
DefaultWindowSize: (1 << 8) - 1,
|
DefaultWindowSize: (1 << 8) - 1,
|
||||||
MaxWindowSize: (1 << 15) - 1,
|
MaxWindowSize: (1 << 15) - 1,
|
||||||
StreamWriteBufferMaxLen: 1024,
|
StreamWriteBufferMaxLen: 1024,
|
||||||
|
HeartbeatInterval: defaultTimeout,
|
||||||
|
MaxHeartbeats: defaultRetries,
|
||||||
},
|
},
|
||||||
OriginConn: origin,
|
OriginConn: origin,
|
||||||
EdgeMuxConfig: MuxerConfig{
|
EdgeMuxConfig: MuxerConfig{
|
||||||
|
@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
|
||||||
DefaultWindowSize: (1 << 8) - 1,
|
DefaultWindowSize: (1 << 8) - 1,
|
||||||
MaxWindowSize: (1 << 15) - 1,
|
MaxWindowSize: (1 << 15) - 1,
|
||||||
StreamWriteBufferMaxLen: 1024,
|
StreamWriteBufferMaxLen: 1024,
|
||||||
|
HeartbeatInterval: defaultTimeout,
|
||||||
|
MaxHeartbeats: defaultRetries,
|
||||||
},
|
},
|
||||||
EdgeConn: edge,
|
EdgeConn: edge,
|
||||||
doneC: make(chan struct{}),
|
doneC: make(chan struct{}),
|
||||||
|
@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
|
||||||
Name: "origin",
|
Name: "origin",
|
||||||
CompressionQuality: quality,
|
CompressionQuality: quality,
|
||||||
Logger: log.NewEntry(log.New()),
|
Logger: log.NewEntry(log.New()),
|
||||||
|
HeartbeatInterval: defaultTimeout,
|
||||||
|
MaxHeartbeats: defaultRetries,
|
||||||
},
|
},
|
||||||
OriginConn: origin,
|
OriginConn: origin,
|
||||||
EdgeMuxConfig: MuxerConfig{
|
EdgeMuxConfig: MuxerConfig{
|
||||||
|
@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
|
||||||
Name: "edge",
|
Name: "edge",
|
||||||
CompressionQuality: quality,
|
CompressionQuality: quality,
|
||||||
Logger: log.NewEntry(log.New()),
|
Logger: log.NewEntry(log.New()),
|
||||||
|
HeartbeatInterval: defaultTimeout,
|
||||||
|
MaxHeartbeats: defaultRetries,
|
||||||
},
|
},
|
||||||
EdgeConn: edge,
|
EdgeConn: edge,
|
||||||
doneC: make(chan struct{}),
|
doneC: make(chan struct{}),
|
||||||
|
|
|
@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
|
||||||
Closed() bool
|
Closed() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MuxedStreamDataSignaller is a write-only *ReadyList
|
||||||
|
type MuxedStreamDataSignaller interface {
|
||||||
|
// Non-blocking: call this when data is ready to be sent for the given stream ID.
|
||||||
|
Signal(ID uint32)
|
||||||
|
}
|
||||||
|
|
||||||
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
|
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
|
||||||
type MuxedStream struct {
|
type MuxedStream struct {
|
||||||
streamID uint32
|
streamID uint32
|
||||||
|
@ -55,8 +61,8 @@ type MuxedStream struct {
|
||||||
// This is the amount of bytes that are in the peer's receive window
|
// This is the amount of bytes that are in the peer's receive window
|
||||||
// (how much data we can send from this stream).
|
// (how much data we can send from this stream).
|
||||||
sendWindow uint32
|
sendWindow uint32
|
||||||
// Reference to the muxer's readyList; signal this for stream data to be sent.
|
// The muxer's readyList
|
||||||
readyList *ReadyList
|
readyList MuxedStreamDataSignaller
|
||||||
// The headers that should be sent, and a flag so we only send them once.
|
// The headers that should be sent, and a flag so we only send them once.
|
||||||
headersSent bool
|
headersSent bool
|
||||||
writeHeaders []Header
|
writeHeaders []Header
|
||||||
|
@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool {
|
||||||
return th != ""
|
return th != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream {
|
func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
|
||||||
return &MuxedStream{
|
return &MuxedStream{
|
||||||
responseHeadersReceived: make(chan struct{}),
|
responseHeadersReceived: make(chan struct{}),
|
||||||
readBuffer: NewSharedBuffer(),
|
readBuffer: NewSharedBuffer(),
|
||||||
|
|
|
@ -51,10 +51,12 @@ type MuxReader struct {
|
||||||
dictionaries h2Dictionaries
|
dictionaries h2Dictionaries
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MuxReader) Shutdown() {
|
// Shutdown blocks new streams from being created.
|
||||||
done := r.streams.Shutdown()
|
// It returns a channel that is closed once the last stream has closed.
|
||||||
if done == nil {
|
func (r *MuxReader) Shutdown() <-chan struct{} {
|
||||||
return
|
done, alreadyInProgress := r.streams.Shutdown()
|
||||||
|
if alreadyInProgress {
|
||||||
|
return done
|
||||||
}
|
}
|
||||||
r.sendGoAway(http2.ErrCodeNo)
|
r.sendGoAway(http2.ErrCodeNo)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
|
||||||
<-done
|
<-done
|
||||||
r.r.Close()
|
r.r.Close()
|
||||||
}()
|
}()
|
||||||
|
return done
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
||||||
|
|
|
@ -54,6 +54,13 @@ type MuxedStreamRequest struct {
|
||||||
body io.Reader
|
body io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewMuxedStreamRequest(stream *MuxedStream, body io.Reader) MuxedStreamRequest {
|
||||||
|
return MuxedStreamRequest{
|
||||||
|
stream: stream,
|
||||||
|
body: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *MuxedStreamRequest) flushBody() {
|
func (r *MuxedStreamRequest) flushBody() {
|
||||||
io.Copy(r.stream, r.body)
|
io.Copy(r.stream, r.body)
|
||||||
r.stream.CloseWrite()
|
r.stream.CloseWrite()
|
||||||
|
|
|
@ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
|
||||||
}
|
}
|
||||||
return b.BaseTime
|
return b.BaseTime
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Retries returns the number of retries consumed so far.
|
||||||
|
func (b *BackoffHandler) Retries() int {
|
||||||
|
return int(b.retries)
|
||||||
|
}
|
||||||
|
|
|
@ -2,16 +2,20 @@ package origin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -21,11 +25,23 @@ const (
|
||||||
resolveTTL = time.Hour
|
resolveTTL = time.Hour
|
||||||
// Interval between registering new tunnels
|
// Interval between registering new tunnels
|
||||||
registrationInterval = time.Second
|
registrationInterval = time.Second
|
||||||
|
|
||||||
|
subsystemRefreshAuth = "refresh_auth"
|
||||||
|
// Maximum exponent for 'Authenticate' exponential backoff
|
||||||
|
refreshAuthMaxBackoff = 10
|
||||||
|
// Waiting time before retrying a failed 'Authenticate' connection
|
||||||
|
refreshAuthRetryDuration = time.Second * 10
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errJWTUnset = errors.New("JWT unset")
|
||||||
|
errEventDigestUnset = errors.New("event digest unset")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Supervisor struct {
|
type Supervisor struct {
|
||||||
config *TunnelConfig
|
cloudflaredUUID uuid.UUID
|
||||||
edgeIPs []*net.TCPAddr
|
config *TunnelConfig
|
||||||
|
edgeIPs []*net.TCPAddr
|
||||||
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
|
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
|
||||||
nextUnusedEdgeIP int
|
nextUnusedEdgeIP int
|
||||||
lastResolve time.Time
|
lastResolve time.Time
|
||||||
|
@ -38,6 +54,12 @@ type Supervisor struct {
|
||||||
nextConnectedSignal chan struct{}
|
nextConnectedSignal chan struct{}
|
||||||
|
|
||||||
logger *logrus.Entry
|
logger *logrus.Entry
|
||||||
|
|
||||||
|
jwtLock *sync.RWMutex
|
||||||
|
jwt []byte
|
||||||
|
|
||||||
|
eventDigestLock *sync.RWMutex
|
||||||
|
eventDigest []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type resolveResult struct {
|
type resolveResult struct {
|
||||||
|
@ -50,18 +72,21 @@ type tunnelError struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSupervisor(config *TunnelConfig) *Supervisor {
|
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
|
||||||
return &Supervisor{
|
return &Supervisor{
|
||||||
|
cloudflaredUUID: u,
|
||||||
config: config,
|
config: config,
|
||||||
tunnelErrors: make(chan tunnelError),
|
tunnelErrors: make(chan tunnelError),
|
||||||
tunnelsConnecting: map[int]chan struct{}{},
|
tunnelsConnecting: map[int]chan struct{}{},
|
||||||
logger: config.Logger.WithField("subsystem", "supervisor"),
|
logger: config.Logger.WithField("subsystem", "supervisor"),
|
||||||
|
jwtLock: &sync.RWMutex{},
|
||||||
|
eventDigestLock: &sync.RWMutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||||
logger := s.config.Logger
|
logger := s.config.Logger
|
||||||
if err := s.initialize(ctx, connectedSignal, u); err != nil {
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var tunnelsWaiting []int
|
var tunnelsWaiting []int
|
||||||
|
@ -69,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
||||||
var backoffTimer <-chan time.Time
|
var backoffTimer <-chan time.Time
|
||||||
tunnelsActive := s.config.HAConnections
|
tunnelsActive := s.config.HAConnections
|
||||||
|
|
||||||
|
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
||||||
|
var refreshAuthBackoffTimer <-chan time.Time
|
||||||
|
if s.config.UseReconnectToken {
|
||||||
|
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// Context cancelled
|
// Context cancelled
|
||||||
|
@ -104,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
||||||
case <-backoffTimer:
|
case <-backoffTimer:
|
||||||
backoffTimer = nil
|
backoffTimer = nil
|
||||||
for _, index := range tunnelsWaiting {
|
for _, index := range tunnelsWaiting {
|
||||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
|
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||||
}
|
}
|
||||||
tunnelsActive += len(tunnelsWaiting)
|
tunnelsActive += len(tunnelsWaiting)
|
||||||
tunnelsWaiting = nil
|
tunnelsWaiting = nil
|
||||||
|
// Time to call Authenticate
|
||||||
|
case <-refreshAuthBackoffTimer:
|
||||||
|
newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Error("Authentication failed")
|
||||||
|
// Permanent failure. Leave the `select` without setting the
|
||||||
|
// channel to be non-null, so we'll never hit this case of the `select` again.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
refreshAuthBackoffTimer = newTimer
|
||||||
// Tunnel successfully connected
|
// Tunnel successfully connected
|
||||||
case <-s.nextConnectedSignal:
|
case <-s.nextConnectedSignal:
|
||||||
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
||||||
|
@ -128,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||||
logger := s.logger
|
logger := s.logger
|
||||||
|
|
||||||
edgeIPs, err := s.resolveEdgeIPs()
|
edgeIPs, err := s.resolveEdgeIPs()
|
||||||
|
@ -145,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
s.lastResolve = time.Now()
|
s.lastResolve = time.Now()
|
||||||
// check entitlement and version too old error before attempting to register more tunnels
|
// check entitlement and version too old error before attempting to register more tunnels
|
||||||
s.nextUnusedEdgeIP = s.config.HAConnections
|
s.nextUnusedEdgeIP = s.config.HAConnections
|
||||||
go s.startFirstTunnel(ctx, connectedSignal, u)
|
go s.startFirstTunnel(ctx, connectedSignal)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
<-s.tunnelErrors
|
<-s.tunnelErrors
|
||||||
// Error can't be nil. A nil error signals that initialization succeed
|
// Error can't be nil. A nil error signals that initialization succeed
|
||||||
return fmt.Errorf("context was canceled")
|
return ctx.Err()
|
||||||
case tunnelError := <-s.tunnelErrors:
|
case tunnelError := <-s.tunnelErrors:
|
||||||
return tunnelError.err
|
return tunnelError.err
|
||||||
case <-connectedSignal.Wait():
|
case <-connectedSignal.Wait():
|
||||||
|
@ -158,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
// At least one successful connection, so start the rest
|
// At least one successful connection, so start the rest
|
||||||
for i := 1; i < s.config.HAConnections; i++ {
|
for i := 1; i < s.config.HAConnections; i++ {
|
||||||
ch := signal.New(make(chan struct{}))
|
ch := signal.New(make(chan struct{}))
|
||||||
go s.startTunnel(ctx, i, ch, u)
|
go s.startTunnel(ctx, i, ch)
|
||||||
time.Sleep(registrationInterval)
|
time.Sleep(registrationInterval)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -166,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
|
|
||||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
|
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
|
||||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
|
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||||
defer func() {
|
defer func() {
|
||||||
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
||||||
}()
|
}()
|
||||||
|
@ -183,19 +224,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
return
|
return
|
||||||
// try the next address if it was a dialError(network problem) or
|
// try the next address if it was a dialError(network problem) or
|
||||||
// dupConnRegisterTunnelError
|
// dupConnRegisterTunnelError
|
||||||
case dialError, dupConnRegisterTunnelError:
|
case connection.DialError, dupConnRegisterTunnelError:
|
||||||
s.replaceEdgeIP(0)
|
s.replaceEdgeIP(0)
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
|
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors.
|
// s.tunnelErrors.
|
||||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
|
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
|
||||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
|
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
|
||||||
s.tunnelErrors <- tunnelError{index: index, err: err}
|
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -253,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
|
||||||
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
||||||
s.nextUnusedEdgeIP++
|
s.nextUnusedEdgeIP++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) ReconnectToken() ([]byte, error) {
|
||||||
|
s.jwtLock.RLock()
|
||||||
|
defer s.jwtLock.RUnlock()
|
||||||
|
if s.jwt == nil {
|
||||||
|
return nil, errJWTUnset
|
||||||
|
}
|
||||||
|
return s.jwt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) SetReconnectToken(jwt []byte) {
|
||||||
|
s.jwtLock.Lock()
|
||||||
|
defer s.jwtLock.Unlock()
|
||||||
|
s.jwt = jwt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) EventDigest() ([]byte, error) {
|
||||||
|
s.eventDigestLock.RLock()
|
||||||
|
defer s.eventDigestLock.RUnlock()
|
||||||
|
if s.eventDigest == nil {
|
||||||
|
return nil, errEventDigestUnset
|
||||||
|
}
|
||||||
|
return s.eventDigest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) SetEventDigest(eventDigest []byte) {
|
||||||
|
s.eventDigestLock.Lock()
|
||||||
|
defer s.eventDigestLock.Unlock()
|
||||||
|
s.eventDigest = eventDigest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) refreshAuth(
|
||||||
|
ctx context.Context,
|
||||||
|
backoff *BackoffHandler,
|
||||||
|
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
|
||||||
|
) (retryTimer <-chan time.Time, err error) {
|
||||||
|
logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth)
|
||||||
|
authOutcome, err := authenticate(ctx, backoff.Retries())
|
||||||
|
if err != nil {
|
||||||
|
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||||
|
logger.WithError(err).Warnf("Retrying in %v", duration)
|
||||||
|
return backoff.BackoffTimer(), nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// clear backoff timer
|
||||||
|
backoff.SetGracePeriod()
|
||||||
|
|
||||||
|
switch outcome := authOutcome.(type) {
|
||||||
|
case tunnelpogs.AuthSuccess:
|
||||||
|
s.SetReconnectToken(outcome.JWT())
|
||||||
|
return timeAfter(outcome.RefreshAfter()), nil
|
||||||
|
case tunnelpogs.AuthUnknown:
|
||||||
|
return timeAfter(outcome.RefreshAfter()), nil
|
||||||
|
case tunnelpogs.AuthFail:
|
||||||
|
return nil, outcome
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
|
||||||
|
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer edgeConn.Close()
|
||||||
|
|
||||||
|
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
|
||||||
|
// This callback is invoked by h2mux when the edge initiates a stream.
|
||||||
|
return nil // noop
|
||||||
|
})
|
||||||
|
muxerConfig := s.config.muxerConfig(handler)
|
||||||
|
muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth)
|
||||||
|
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go muxer.Serve(ctx)
|
||||||
|
defer func() {
|
||||||
|
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
|
||||||
|
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
|
||||||
|
<-muxer.Shutdown()
|
||||||
|
}()
|
||||||
|
|
||||||
|
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tunnelServer.Close()
|
||||||
|
|
||||||
|
const arbitraryConnectionID = uint8(0)
|
||||||
|
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
|
||||||
|
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
||||||
|
authResponse, err := tunnelServer.Authenticate(
|
||||||
|
ctx,
|
||||||
|
s.config.OriginCert,
|
||||||
|
s.config.Hostname,
|
||||||
|
registrationOptions,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return authResponse.Outcome(), nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRefreshAuthBackoff(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
var wait time.Duration
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
wait = d
|
||||||
|
return time.After(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return nil, fmt.Errorf("authentication failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
// authentication failures should consume the backoff
|
||||||
|
for i := uint(0); i < backoff.MaxRetries; i++ {
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, retryChan)
|
||||||
|
assert.Equal(t, (1<<i)*time.Second, wait)
|
||||||
|
}
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, retryChan)
|
||||||
|
|
||||||
|
// now we actually make contact with the remote server
|
||||||
|
_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// The backoff timer should have been reset. To confirm this, make timeNow
|
||||||
|
// return a value after the backoff timer's grace period
|
||||||
|
timeNow = func() time.Time {
|
||||||
|
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
|
||||||
|
return time.Now().Add(expectedGracePeriod * 2)
|
||||||
|
}
|
||||||
|
_, ok := backoff.GetBackoffDuration(context.Background())
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAuthSuccess(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
var wait time.Duration
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
wait = d
|
||||||
|
return time.After(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, retryChan)
|
||||||
|
assert.Equal(t, 19*time.Hour, wait)
|
||||||
|
|
||||||
|
token, err := s.ReconnectToken()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("jwt"), token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAuthUnknown(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
var wait time.Duration
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
wait = d
|
||||||
|
return time.After(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, retryChan)
|
||||||
|
assert.Equal(t, 19*time.Hour, wait)
|
||||||
|
|
||||||
|
token, err := s.ReconnectToken()
|
||||||
|
assert.Equal(t, errJWTUnset, err)
|
||||||
|
assert.Nil(t, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAuthFail(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, retryChan)
|
||||||
|
|
||||||
|
token, err := s.ReconnectToken()
|
||||||
|
assert.Equal(t, errJWTUnset, err)
|
||||||
|
assert.Nil(t, token)
|
||||||
|
}
|
159
origin/tunnel.go
159
origin/tunnel.go
|
@ -14,7 +14,14 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
"github.com/cloudflare/cloudflared/streamhandler"
|
"github.com/cloudflare/cloudflared/streamhandler"
|
||||||
|
@ -22,19 +29,12 @@ import (
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/validation"
|
"github.com/cloudflare/cloudflared/validation"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
_ "github.com/prometheus/client_golang/prometheus"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dialTimeout = 15 * time.Second
|
dialTimeout = 15 * time.Second
|
||||||
openStreamTimeout = 30 * time.Second
|
openStreamTimeout = 30 * time.Second
|
||||||
|
muxerTimeout = 5 * time.Second
|
||||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||||
DuplicateConnectionError = "EDUPCONN"
|
DuplicateConnectionError = "EDUPCONN"
|
||||||
|
@ -73,14 +73,9 @@ type TunnelConfig struct {
|
||||||
WSGI bool
|
WSGI bool
|
||||||
// OriginUrl may not be used if a user specifies a unix socket.
|
// OriginUrl may not be used if a user specifies a unix socket.
|
||||||
OriginUrl string
|
OriginUrl string
|
||||||
}
|
|
||||||
|
|
||||||
type dialError struct {
|
// feature-flag to use new edge reconnect tokens
|
||||||
cause error
|
UseReconnectToken bool
|
||||||
}
|
|
||||||
|
|
||||||
func (e dialError) Error() string {
|
|
||||||
return e.cause.Error()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type dupConnRegisterTunnelError struct{}
|
type dupConnRegisterTunnelError struct{}
|
||||||
|
@ -119,6 +114,18 @@ func (e clientRegisterTunnelError) Error() string {
|
||||||
return e.cause.Error()
|
return e.cause.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
|
||||||
|
return h2mux.MuxerConfig{
|
||||||
|
Timeout: muxerTimeout,
|
||||||
|
Handler: handler,
|
||||||
|
IsClient: true,
|
||||||
|
HeartbeatInterval: c.HeartbeatInterval,
|
||||||
|
MaxHeartbeats: c.MaxHeartbeats,
|
||||||
|
Logger: c.TransportLogger.WithFields(log.Fields{}),
|
||||||
|
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
||||||
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
||||||
if c.HAConnections <= 1 && c.LBPool == "" {
|
if c.HAConnections <= 1 && c.LBPool == "" {
|
||||||
|
@ -141,7 +148,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
|
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
|
||||||
return NewSupervisor(config).Run(ctx, connectedSignal, cloudflaredID)
|
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeTunnelLoop(ctx context.Context,
|
func ServeTunnelLoop(ctx context.Context,
|
||||||
|
@ -213,11 +220,11 @@ func ServeTunnel(
|
||||||
tags["ha"] = connectionTag
|
tags["ha"] = connectionTag
|
||||||
|
|
||||||
// Returns error from parsing the origin URL or handshake errors
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
|
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errLog := logger.WithError(err)
|
errLog := logger.WithError(err)
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case dialError:
|
case connection.DialError:
|
||||||
errLog.Error("Unable to dial edge")
|
errLog.Error("Unable to dial edge")
|
||||||
case h2mux.MuxerHandshakeError:
|
case h2mux.MuxerHandshakeError:
|
||||||
errLog.Error("Handshake failed with edge server")
|
errLog.Error("Handshake failed with edge server")
|
||||||
|
@ -295,16 +302,6 @@ func ServeTunnel(
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
|
||||||
if len(headers) != 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if headers[0].Name != ":status" || headers[0].Value != "200" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterTunnel(
|
func RegisterTunnel(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
muxer *h2mux.Muxer,
|
muxer *h2mux.Muxer,
|
||||||
|
@ -315,43 +312,31 @@ func RegisterTunnel(
|
||||||
uuid uuid.UUID,
|
uuid uuid.UUID,
|
||||||
) error {
|
) error {
|
||||||
config.TransportLogger.Debug("initiating RPC stream to register")
|
config.TransportLogger.Debug("initiating RPC stream to register")
|
||||||
stream, err := openStream(ctx, muxer)
|
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// RPC stream open error
|
// RPC stream open error
|
||||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
||||||
}
|
}
|
||||||
if !IsRPCStreamResponse(stream.Headers) {
|
defer tunnelServer.Close()
|
||||||
// stream response error
|
|
||||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
|
||||||
}
|
|
||||||
conn := rpc.NewConn(
|
|
||||||
tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
|
|
||||||
tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")),
|
|
||||||
)
|
|
||||||
defer conn.Close()
|
|
||||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
|
||||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||||
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
|
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||||
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
registration, err := ts.RegisterTunnel(
|
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
|
||||||
|
registration := tunnelServer.RegisterTunnel(
|
||||||
ctx,
|
ctx,
|
||||||
config.OriginCert,
|
config.OriginCert,
|
||||||
config.Hostname,
|
config.Hostname,
|
||||||
config.RegistrationOptions(connectionID, originLocalIP, uuid),
|
config.RegistrationOptions(connectionID, originLocalIP, uuid),
|
||||||
)
|
)
|
||||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
|
|
||||||
if err != nil {
|
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
||||||
// RegisterTunnel RPC failure
|
// RegisterTunnel RPC failure
|
||||||
return newClientRegisterTunnelError(err, config.Metrics.regFail)
|
return processRegisterTunnelError(registrationErr, config.Metrics)
|
||||||
}
|
|
||||||
for _, logLine := range registration.LogLines {
|
|
||||||
logger.Info(logLine)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if regErr := processRegisterTunnelError(registration.Err, registration.PermanentFailure, config.Metrics); regErr != nil {
|
for _, logLine := range registration.LogLines {
|
||||||
return regErr
|
logger.Info(logLine)
|
||||||
}
|
}
|
||||||
|
|
||||||
if registration.TunnelID != "" {
|
if registration.TunnelID != "" {
|
||||||
|
@ -374,57 +359,34 @@ func RegisterTunnel(
|
||||||
config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
|
config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
|
||||||
|
|
||||||
logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
|
logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
|
||||||
|
config.Metrics.regSuccess.Inc()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRegisterTunnelError(err string, permanentFailure bool, metrics *TunnelMetrics) error {
|
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
|
||||||
if err == "" {
|
if err.Error() == DuplicateConnectionError {
|
||||||
metrics.regSuccess.Inc()
|
metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
metrics.regFail.WithLabelValues(err).Inc()
|
|
||||||
if err == DuplicateConnectionError {
|
|
||||||
return dupConnRegisterTunnelError{}
|
return dupConnRegisterTunnelError{}
|
||||||
}
|
}
|
||||||
|
metrics.regFail.WithLabelValues("server_error").Inc()
|
||||||
return serverRegisterTunnelError{
|
return serverRegisterTunnelError{
|
||||||
cause: fmt.Errorf("Server error: %s", err),
|
cause: fmt.Errorf("Server error: %s", err.Error()),
|
||||||
permanent: permanentFailure,
|
permanent: err.IsPermanent(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
|
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
|
||||||
logger.Debug("initiating RPC stream to unregister")
|
logger.Debug("initiating RPC stream to unregister")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
stream, err := openStream(ctx, muxer)
|
ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// RPC stream open error
|
// RPC stream open error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !IsRPCStreamResponse(stream.Headers) {
|
|
||||||
// stream response error
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn := rpc.NewConn(
|
|
||||||
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
|
|
||||||
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
|
||||||
)
|
|
||||||
defer conn.Close()
|
|
||||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
|
||||||
// gracePeriod is encoded in int64 using capnproto
|
// gracePeriod is encoded in int64 using capnproto
|
||||||
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
|
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
|
||||||
}
|
}
|
||||||
|
|
||||||
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
|
|
||||||
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
|
|
||||||
defer cancel()
|
|
||||||
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
|
|
||||||
{Name: ":method", Value: "RPC"},
|
|
||||||
{Name: ":scheme", Value: "capnp"},
|
|
||||||
{Name: ":path", Value: "*"},
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogServerInfo(
|
func LogServerInfo(
|
||||||
promise tunnelrpc.ServerInfo_Promise,
|
promise tunnelrpc.ServerInfo_Promise,
|
||||||
connectionID uint8,
|
connectionID uint8,
|
||||||
|
@ -469,12 +431,12 @@ type TunnelHandler struct {
|
||||||
noChunkedEncoding bool
|
noChunkedEncoding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var dialer = net.Dialer{DualStack: true}
|
var dialer = net.Dialer{}
|
||||||
|
|
||||||
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
||||||
func NewTunnelHandler(ctx context.Context,
|
func NewTunnelHandler(ctx context.Context,
|
||||||
config *TunnelConfig,
|
config *TunnelConfig,
|
||||||
addr string,
|
addr *net.TCPAddr,
|
||||||
connectionID uint8,
|
connectionID uint8,
|
||||||
) (*TunnelHandler, string, error) {
|
) (*TunnelHandler, string, error) {
|
||||||
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
||||||
|
@ -495,37 +457,18 @@ func NewTunnelHandler(ctx context.Context,
|
||||||
if h.httpClient == nil {
|
if h.httpClient == nil {
|
||||||
h.httpClient = http.DefaultTransport
|
h.httpClient = http.DefaultTransport
|
||||||
}
|
}
|
||||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
|
||||||
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
|
||||||
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
|
|
||||||
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
|
|
||||||
dialCancel()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")}
|
return nil, "", err
|
||||||
}
|
}
|
||||||
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
|
|
||||||
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
|
||||||
err = edgeConn.Handshake()
|
|
||||||
if err != nil {
|
|
||||||
return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
|
||||||
}
|
|
||||||
// clear the deadline on the conn; h2mux has its own timeouts
|
|
||||||
edgeConn.SetDeadline(time.Time{})
|
|
||||||
// Establish a muxed connection with the edge
|
// Establish a muxed connection with the edge
|
||||||
// Client mux handshake with agent server
|
// Client mux handshake with agent server
|
||||||
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
Handler: h,
|
|
||||||
IsClient: true,
|
|
||||||
HeartbeatInterval: config.HeartbeatInterval,
|
|
||||||
MaxHeartbeats: config.MaxHeartbeats,
|
|
||||||
Logger: config.TransportLogger.WithFields(log.Fields{}),
|
|
||||||
CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
|
|
||||||
}, h.metrics.activeStreams)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return h, "", errors.New("TLS handshake error")
|
return nil, "", errors.Wrap(err, "Handshake with edge error")
|
||||||
}
|
}
|
||||||
return h, edgeConn.LocalAddr().String(), err
|
return h, edgeConn.LocalAddr().String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
||||||
|
|
|
@ -35,26 +35,49 @@ func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, erro
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// H2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request
|
||||||
|
// object. This includes conversion of the pseudo-headers into their closest
|
||||||
|
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
|
||||||
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||||
for _, header := range h2 {
|
for _, header := range h2 {
|
||||||
switch header.Name {
|
switch header.Name {
|
||||||
case ":method":
|
case ":method":
|
||||||
h1.Method = header.Value
|
h1.Method = header.Value
|
||||||
case ":scheme":
|
case ":scheme":
|
||||||
|
// noop - use the preexisting scheme from h1.URL
|
||||||
case ":authority":
|
case ":authority":
|
||||||
// Otherwise the host header will be based on the origin URL
|
// Otherwise the host header will be based on the origin URL
|
||||||
h1.Host = header.Value
|
h1.Host = header.Value
|
||||||
case ":path":
|
case ":path":
|
||||||
u, err := url.Parse(header.Value)
|
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
|
||||||
|
// However, this HTTP/1 Request object belongs to the Go standard library,
|
||||||
|
// whose URL package makes some opinionated decisions about the encoding of
|
||||||
|
// URL characters: see the docs of https://godoc.org/net/url#URL,
|
||||||
|
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
|
||||||
|
// which is always used when computing url.URL.String(), whether we'd like it or not.
|
||||||
|
//
|
||||||
|
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
|
||||||
|
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
|
||||||
|
// is treated differently when HTTP_PROXY is set!
|
||||||
|
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
|
||||||
|
//
|
||||||
|
// This means we are subject to the behavior of net/url's function `shouldEscape`
|
||||||
|
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
|
||||||
|
|
||||||
|
if header.Value == "*" {
|
||||||
|
h1.URL.Path = "*"
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Due to the behavior of validation.ValidateUrl, h1.URL may
|
||||||
|
// already have a partial value, with or without a trailing slash.
|
||||||
|
base := h1.URL.String()
|
||||||
|
base = strings.TrimRight(base, "/")
|
||||||
|
// But we know :path begins with '/', because we handled '*' above - see RFC7540
|
||||||
|
url, err := url.Parse(base + header.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unparseable path")
|
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
|
||||||
}
|
}
|
||||||
resolved := h1.URL.ResolveReference(u)
|
h1.URL = url
|
||||||
// prevent escaping base URL
|
|
||||||
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
|
|
||||||
return fmt.Errorf("invalid path %s", header.Value)
|
|
||||||
}
|
|
||||||
h1.URL = resolved
|
|
||||||
case "content-length":
|
case "content-length":
|
||||||
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
|
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -0,0 +1,441 @@
|
||||||
|
package streamhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"testing/quick"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
|
||||||
|
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
headersConversionErr := H2RequestHeadersToH1Request(
|
||||||
|
[]h2mux.Header{
|
||||||
|
h2mux.Header{
|
||||||
|
Name: "Mock header 1",
|
||||||
|
Value: "Mock value 1",
|
||||||
|
},
|
||||||
|
h2mux.Header{
|
||||||
|
Name: "Mock header 2",
|
||||||
|
Value: "Mock value 2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, http.Header{
|
||||||
|
"Mock header 1": []string{"Mock value 1"},
|
||||||
|
"Mock header 2": []string{"Mock value 2"},
|
||||||
|
}, request.Header)
|
||||||
|
|
||||||
|
assert.NoError(t, headersConversionErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
|
||||||
|
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
headersConversionErr := H2RequestHeadersToH1Request(
|
||||||
|
[]h2mux.Header{},
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, http.Header{}, request.Header)
|
||||||
|
|
||||||
|
assert.NoError(t, headersConversionErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
|
||||||
|
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
headersConversionErr := H2RequestHeadersToH1Request(
|
||||||
|
[]h2mux.Header{
|
||||||
|
h2mux.Header{
|
||||||
|
Name: ":path",
|
||||||
|
Value: "//bad_path/",
|
||||||
|
},
|
||||||
|
h2mux.Header{
|
||||||
|
Name: "Mock header",
|
||||||
|
Value: "Mock value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, http.Header{
|
||||||
|
"Mock header": []string{"Mock value"},
|
||||||
|
}, request.Header)
|
||||||
|
|
||||||
|
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
|
||||||
|
|
||||||
|
assert.NoError(t, headersConversionErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
|
||||||
|
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
headersConversionErr := H2RequestHeadersToH1Request(
|
||||||
|
[]h2mux.Header{
|
||||||
|
h2mux.Header{
|
||||||
|
Name: ":path",
|
||||||
|
Value: "/?query=mock%20value",
|
||||||
|
},
|
||||||
|
h2mux.Header{
|
||||||
|
Name: "Mock header",
|
||||||
|
Value: "Mock value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, http.Header{
|
||||||
|
"Mock header": []string{"Mock value"},
|
||||||
|
}, request.Header)
|
||||||
|
|
||||||
|
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
|
||||||
|
|
||||||
|
assert.NoError(t, headersConversionErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
|
||||||
|
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
headersConversionErr := H2RequestHeadersToH1Request(
|
||||||
|
[]h2mux.Header{
|
||||||
|
h2mux.Header{
|
||||||
|
Name: ":path",
|
||||||
|
Value: "/mock%20path",
|
||||||
|
},
|
||||||
|
h2mux.Header{
|
||||||
|
Name: "Mock header",
|
||||||
|
Value: "Mock value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Equal(t, http.Header{
|
||||||
|
"Mock header": []string{"Mock value"},
|
||||||
|
}, request.Header)
|
||||||
|
|
||||||
|
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
|
||||||
|
|
||||||
|
assert.NoError(t, headersConversionErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
path string
|
||||||
|
want string
|
||||||
|
}
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
path: "",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/",
|
||||||
|
want: "/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "//",
|
||||||
|
want: "//",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/test",
|
||||||
|
want: "/test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "//test",
|
||||||
|
want: "//test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// https://github.com/cloudflare/cloudflared/issues/81
|
||||||
|
path: "//test/",
|
||||||
|
want: "//test/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/%2Ftest",
|
||||||
|
want: "/%2Ftest",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "//%20test",
|
||||||
|
want: "//%20test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// https://github.com/cloudflare/cloudflared/issues/124
|
||||||
|
path: "/test?get=somthing%20a",
|
||||||
|
want: "/test?get=somthing%20a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/%20",
|
||||||
|
want: "/%20",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// stdlib's EscapedPath() will always percent-encode ' '
|
||||||
|
path: "/ ",
|
||||||
|
want: "/%20",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/ a ",
|
||||||
|
want: "/%20a%20",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a%20b",
|
||||||
|
want: "/a%20b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/foo/bar;param?query#frag",
|
||||||
|
want: "/foo/bar;param?query#frag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
|
||||||
|
path: "/a␠b",
|
||||||
|
want: "/a%E2%90%A0b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a-umlaut-ä",
|
||||||
|
want: "/a-umlaut-%C3%A4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a-umlaut-%C3%A4",
|
||||||
|
want: "/a-umlaut-%C3%A4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a-umlaut-%c3%a4",
|
||||||
|
want: "/a-umlaut-%c3%a4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// here the second '#' is treated as part of the fragment
|
||||||
|
path: "/a#b#c",
|
||||||
|
want: "/a#b%23c",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a#b␠c",
|
||||||
|
want: "/a#b%E2%90%A0c",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a#b%20c",
|
||||||
|
want: "/a#b%20c",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a#b c",
|
||||||
|
want: "/a#b%20c",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// stdlib's EscapedPath() will always percent-encode '\'
|
||||||
|
path: "/\\",
|
||||||
|
want: "/%5C",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a\\",
|
||||||
|
want: "/a%5C",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a,b.c.",
|
||||||
|
want: "/a,b.c.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/.",
|
||||||
|
want: "/.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// stdlib's EscapedPath() will always percent-encode '`'
|
||||||
|
path: "/a`",
|
||||||
|
want: "/a%60",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/a[0]",
|
||||||
|
want: "/a[0]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/?a[0]=5 &b[]=",
|
||||||
|
want: "/?a[0]=5 &b[]=",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: "/?a=%22b%20%22",
|
||||||
|
want: "/?a=%22b%20%22",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, testCase := range testCases {
|
||||||
|
requestURL := "https://example.com"
|
||||||
|
|
||||||
|
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
headersConversionErr := H2RequestHeadersToH1Request(
|
||||||
|
[]h2mux.Header{
|
||||||
|
h2mux.Header{
|
||||||
|
Name: ":path",
|
||||||
|
Value: testCase.path,
|
||||||
|
},
|
||||||
|
h2mux.Header{
|
||||||
|
Name: "Mock header",
|
||||||
|
Value: "Mock value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
assert.NoError(t, headersConversionErr)
|
||||||
|
|
||||||
|
assert.Equal(t,
|
||||||
|
http.Header{
|
||||||
|
"Mock header": []string{"Mock value"},
|
||||||
|
},
|
||||||
|
request.Header)
|
||||||
|
|
||||||
|
assert.Equal(t,
|
||||||
|
"https://example.com"+testCase.want,
|
||||||
|
request.URL.String(),
|
||||||
|
"Failed URL index: %v %#v", index, testCase)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
|
||||||
|
config := &quick.Config{
|
||||||
|
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||||
|
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type testOrigin struct {
|
||||||
|
url string
|
||||||
|
|
||||||
|
expectedScheme string
|
||||||
|
expectedBasePath string
|
||||||
|
}
|
||||||
|
testOrigins := []testOrigin{
|
||||||
|
{
|
||||||
|
url: "http://origin.hostname.example.com:8080",
|
||||||
|
expectedScheme: "http",
|
||||||
|
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: "http://origin.hostname.example.com:8080/",
|
||||||
|
expectedScheme: "http",
|
||||||
|
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: "http://origin.hostname.example.com:8080/api",
|
||||||
|
expectedScheme: "http",
|
||||||
|
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: "http://origin.hostname.example.com:8080/api/",
|
||||||
|
expectedScheme: "http",
|
||||||
|
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: "https://origin.hostname.example.com:8080/api",
|
||||||
|
expectedScheme: "https",
|
||||||
|
expectedBasePath: "https://origin.hostname.example.com:8080/api",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// use multiple schemes to demonstrate that the URL is based on the
|
||||||
|
// origin's scheme, not the :scheme header
|
||||||
|
for _, testScheme := range []string{"http", "https"} {
|
||||||
|
for _, testOrigin := range testOrigins {
|
||||||
|
assertion := func(testPath string) bool {
|
||||||
|
const expectedMethod = "POST"
|
||||||
|
const expectedHostname = "request.hostname.example.com"
|
||||||
|
|
||||||
|
h2 := []h2mux.Header{
|
||||||
|
h2mux.Header{Name: ":method", Value: expectedMethod},
|
||||||
|
h2mux.Header{Name: ":scheme", Value: testScheme},
|
||||||
|
h2mux.Header{Name: ":authority", Value: expectedHostname},
|
||||||
|
h2mux.Header{Name: ":path", Value: testPath},
|
||||||
|
}
|
||||||
|
h1, err := http.NewRequest("GET", testOrigin.url, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = H2RequestHeadersToH1Request(h2, h1)
|
||||||
|
return assert.NoError(t, err) &&
|
||||||
|
assert.Equal(t, expectedMethod, h1.Method) &&
|
||||||
|
assert.Equal(t, expectedHostname, h1.Host) &&
|
||||||
|
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
|
||||||
|
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
|
||||||
|
}
|
||||||
|
err := quick.Check(assertion, config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomASCIIPrintableChar(rand *rand.Rand) int {
|
||||||
|
// smallest printable ASCII char is 32, largest is 126
|
||||||
|
const startPrintable = 32
|
||||||
|
const endPrintable = 127
|
||||||
|
return startPrintable + rand.Intn(endPrintable-startPrintable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomASCIIText generates an ASCII string, some of whose characters may be
|
||||||
|
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
|
||||||
|
// between 1 and `maxLength`.
|
||||||
|
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
|
||||||
|
length := minLength + rand.Intn(maxLength)
|
||||||
|
result := ""
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
c := randomASCIIPrintableChar(rand)
|
||||||
|
|
||||||
|
// 1/4 chance of using percent encoding when not necessary
|
||||||
|
if c == '%' || rand.Intn(4) == 0 {
|
||||||
|
result += fmt.Sprintf("%%%02X", c)
|
||||||
|
} else {
|
||||||
|
result += string(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calls `randomASCIIText` and ensures the result is a valid URL path,
|
||||||
|
// i.e. one that can pass unchanged through url.URL.String()
|
||||||
|
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||||
|
text := randomASCIIText(rand, minLength, maxLength)
|
||||||
|
regexp, err := regexp.Compile("[^/;,]*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return "/" + regexp.ReplaceAllStringFunc(text, url.PathEscape)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calls `randomASCIIText` and ensures the result is a valid URL query,
|
||||||
|
// i.e. one that can pass unchanged through url.URL.String()
|
||||||
|
func randomHTTP1Query(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||||
|
text := randomASCIIText(rand, minLength, maxLength)
|
||||||
|
return "?" + strings.ReplaceAll(text, "#", "%23")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
|
||||||
|
// i.e. one that can pass unchanged through url.URL.String()
|
||||||
|
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||||
|
text := randomASCIIText(rand, minLength, maxLength)
|
||||||
|
url, err := url.Parse("#" + text)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return url.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
|
||||||
|
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
|
||||||
|
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
|
||||||
|
result := randomHTTP1Path(t, rand, 1, 64)
|
||||||
|
if rand.Intn(2) == 1 {
|
||||||
|
result += randomHTTP1Query(t, rand, 1, 32)
|
||||||
|
}
|
||||||
|
if rand.Intn(2) == 1 {
|
||||||
|
result += randomHTTP1Fragment(t, rand, 1, 16)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/streamhandler"
|
"github.com/cloudflare/cloudflared/streamhandler"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,6 +29,27 @@ type Supervisor struct {
|
||||||
useConfigResultChan chan<- *pogs.UseConfigurationResult
|
useConfigResultChan chan<- *pogs.UseConfigurationResult
|
||||||
state *state
|
state *state
|
||||||
logger *logrus.Entry
|
logger *logrus.Entry
|
||||||
|
metrics metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
type metrics struct {
|
||||||
|
configVersion prometheus.Gauge
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMetrics() metrics {
|
||||||
|
configVersion := prometheus.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Namespace: "supervisor",
|
||||||
|
Subsystem: "supervisor",
|
||||||
|
Name: "config_version",
|
||||||
|
Help: "Latest configuration version received from Cloudflare",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(
|
||||||
|
configVersion,
|
||||||
|
)
|
||||||
|
return metrics{
|
||||||
|
configVersion: configVersion,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSupervisor(
|
func NewSupervisor(
|
||||||
|
@ -70,6 +92,7 @@ func NewSupervisor(
|
||||||
useConfigResultChan: useConfigResultChan,
|
useConfigResultChan: useConfigResultChan,
|
||||||
state: newState(defaultClientConfig),
|
state: newState(defaultClientConfig),
|
||||||
logger: logger.WithField("subsystem", "supervisor"),
|
logger: logger.WithField("subsystem", "supervisor"),
|
||||||
|
metrics: newMetrics(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,6 +154,7 @@ func (s *Supervisor) notifySubsystemsNewConfig(newConfig *pogs.ClientConfig) *po
|
||||||
Success: true,
|
Success: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
s.metrics.configVersion.Set(float64(newConfig.Version))
|
||||||
|
|
||||||
s.state.updateConfig(newConfig)
|
s.state.updateConfig(newConfig)
|
||||||
var tunnelHostnames []h2mux.TunnelHostname
|
var tunnelHostnames []h2mux.TunnelHostname
|
||||||
|
|
|
@ -26,28 +26,28 @@ mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw==
|
||||||
-----END CERTIFICATE-----
|
-----END CERTIFICATE-----
|
||||||
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
|
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
|
||||||
-----BEGIN CERTIFICATE-----
|
-----BEGIN CERTIFICATE-----
|
||||||
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
|
MIIEADCCAuigAwIBAgIID+rOSdTGfGcwDQYJKoZIhvcNAQELBQAwgYsxCzAJBgNV
|
||||||
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
|
BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91
|
||||||
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
|
ZEZsYXJlIE9yaWdpbiBTU0wgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQH
|
||||||
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
|
Ew1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMB4XDTE5MDgyMzIx
|
||||||
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
|
MDgwMFoXDTI5MDgxNTE3MDAwMFowgYsxCzAJBgNVBAYTAlVTMRkwFwYDVQQKExBD
|
||||||
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
|
bG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91ZEZsYXJlIE9yaWdpbiBTU0wg
|
||||||
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
|
Q2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRMw
|
||||||
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
EQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
|
||||||
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
|
AQEAwEiVZ/UoQpHmFsHvk5isBxRehukP8DG9JhFev3WZtG76WoTthvLJFRKFCHXm
|
||||||
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
|
V6Z5/66Z4S09mgsUuFwvJzMnE6Ej6yIsYNCb9r9QORa8BdhrkNn6kdTly3mdnykb
|
||||||
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
|
OomnwbUfLlExVgNdlP0XoRoeMwbQ4598foiHblO2B/LKuNfJzAMfS7oZe34b+vLB
|
||||||
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
|
yrP/1bgCSLdc1AxQc1AC0EsQQhgcyTJNgnG4va1c7ogPlwKyhbDyZ4e59N5lbYPJ
|
||||||
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
|
SmXI/cAe3jXj1FBLJZkwnoDKe0v13xeF+nF32smSH0qB7aJX2tBMW4TWtFPmzs5I
|
||||||
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
|
lwrFSySWAdwYdgxw180yKU0dvwIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAQYwEgYD
|
||||||
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
|
VR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUJOhTV118NECHqeuU27rhFnj8KaQw
|
||||||
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
|
HwYDVR0jBBgwFoAUJOhTV118NECHqeuU27rhFnj8KaQwDQYJKoZIhvcNAQELBQAD
|
||||||
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
|
ggEBAHwOf9Ur1l0Ar5vFE6PNrZWrDfQIMyEfdgSKofCdTckbqXNTiXdgbHs+TWoQ
|
||||||
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
|
wAB0pfJDAHJDXOTCWRyTeXOseeOi5Btj5CnEuw3P0oXqdqevM1/+uWp0CM35zgZ8
|
||||||
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
|
VD4aITxity0djzE6Qnx3Syzz+ZkoBgTnNum7d9A66/V636x4vTeqbZFBr9erJzgz
|
||||||
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
|
hhurjcoacvRNhnjtDRM0dPeiCJ50CP3wEYuvUzDHUaowOsnLCjQIkWbR7Ni6KEIk
|
||||||
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
|
MOz2U0OBSif3FTkhCgZWQKOOLo1P42jHC3ssUZAtVNXrCk3fw9/E15k8NPkBazZ6
|
||||||
Fu6q54beR89jDc+oABmOgg==
|
0iykLhH1trywrKRMVw67F44IE8Y=
|
||||||
-----END CERTIFICATE-----
|
-----END CERTIFICATE-----
|
||||||
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
|
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
|
||||||
-----BEGIN CERTIFICATE-----
|
-----BEGIN CERTIFICATE-----
|
||||||
|
|
|
@ -0,0 +1,132 @@
|
||||||
|
package pogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticateResponse is the serialized response from the Authenticate RPC.
|
||||||
|
// It's a 1:1 representation of the capnp message, so it's not very useful for programmers.
|
||||||
|
// Instead, you should call the `Outcome()` method to get a programmer-friendly sum type, with one
|
||||||
|
// case for each possible outcome.
|
||||||
|
type AuthenticateResponse struct {
|
||||||
|
PermanentErr string
|
||||||
|
RetryableErr string
|
||||||
|
Jwt []byte
|
||||||
|
HoursUntilRefresh uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
|
||||||
|
func (ar AuthenticateResponse) Outcome() AuthOutcome {
|
||||||
|
// If the user's authentication was unsuccessful, the server will return an error explaining why.
|
||||||
|
// cloudflared should fatal with this error.
|
||||||
|
if ar.PermanentErr != "" {
|
||||||
|
return NewAuthFail(errors.New(ar.PermanentErr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there was a network error, then cloudflared should retry later,
|
||||||
|
// because origintunneld couldn't prove whether auth was correct or not.
|
||||||
|
if ar.RetryableErr != "" {
|
||||||
|
return NewAuthUnknown(errors.New(ar.RetryableErr), ar.HoursUntilRefresh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If auth succeeded, return the token and refresh it when instructed.
|
||||||
|
if len(ar.Jwt) > 0 {
|
||||||
|
return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the state got messed up.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthOutcome is a programmer-friendly sum type denoting the possible outcomes of Authenticate.
|
||||||
|
//go-sumtype:decl AuthOutcome
|
||||||
|
type AuthOutcome interface {
|
||||||
|
isAuthOutcome()
|
||||||
|
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||||
|
Serialize() AuthenticateResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthSuccess means the backend successfully authenticated this cloudflared.
|
||||||
|
type AuthSuccess struct {
|
||||||
|
jwt []byte
|
||||||
|
hoursUntilRefresh uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess {
|
||||||
|
return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ao AuthSuccess) JWT() []byte {
|
||||||
|
return ao.jwt
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||||
|
func (ao AuthSuccess) RefreshAfter() time.Duration {
|
||||||
|
return hoursToTime(ao.hoursUntilRefresh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||||
|
func (ao AuthSuccess) Serialize() AuthenticateResponse {
|
||||||
|
return AuthenticateResponse{
|
||||||
|
Jwt: ao.jwt,
|
||||||
|
HoursUntilRefresh: ao.hoursUntilRefresh,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ao AuthSuccess) isAuthOutcome() {}
|
||||||
|
|
||||||
|
// AuthFail means this cloudflared has the wrong auth and should exit.
|
||||||
|
type AuthFail struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthFail(err error) AuthFail {
|
||||||
|
return AuthFail{err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ao AuthFail) Error() string {
|
||||||
|
return ao.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||||
|
func (ao AuthFail) Serialize() AuthenticateResponse {
|
||||||
|
return AuthenticateResponse{
|
||||||
|
PermanentErr: ao.err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ao AuthFail) isAuthOutcome() {}
|
||||||
|
|
||||||
|
// AuthUnknown means the backend couldn't finish checking authentication. Try again later.
|
||||||
|
type AuthUnknown struct {
|
||||||
|
err error
|
||||||
|
hoursUntilRefresh uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown {
|
||||||
|
return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ao AuthUnknown) Error() string {
|
||||||
|
return ao.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||||
|
func (ao AuthUnknown) RefreshAfter() time.Duration {
|
||||||
|
return hoursToTime(ao.hoursUntilRefresh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||||
|
func (ao AuthUnknown) Serialize() AuthenticateResponse {
|
||||||
|
return AuthenticateResponse{
|
||||||
|
RetryableErr: ao.err.Error(),
|
||||||
|
HoursUntilRefresh: ao.hoursUntilRefresh,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ao AuthUnknown) isAuthOutcome() {}
|
||||||
|
|
||||||
|
func hoursToTime(hours uint8) time.Duration {
|
||||||
|
return time.Duration(hours) * time.Hour
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
package pogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
|
||||||
|
"zombiezen.com/go/capnproto2/pogs"
|
||||||
|
"zombiezen.com/go/capnproto2/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (i TunnelServer_PogsImpl) Authenticate(p tunnelrpc.TunnelServer_authenticate) error {
|
||||||
|
originCert, err := p.Params.OriginCert()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hostname, err := p.Params.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
options, err := p.Params.Options()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pogsOptions, err := UnmarshalRegistrationOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
server.Ack(p.Options)
|
||||||
|
resp, err := i.impl.Authenticate(p.Ctx, originCert, hostname, pogsOptions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result, err := p.Results.NewResult()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return MarshalAuthenticateResponse(result, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse, p *AuthenticateResponse) error {
|
||||||
|
return pogs.Insert(tunnelrpc.AuthenticateResponse_TypeID, s.Struct, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c TunnelServer_PogsClient) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
|
||||||
|
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||||
|
promise := client.Authenticate(ctx, func(p tunnelrpc.TunnelServer_authenticate_Params) error {
|
||||||
|
err := p.SetOriginCert(originCert)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = p.SetHostname(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registrationOptions, err := p.NewOptions()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = MarshalRegistrationOptions(registrationOptions, options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
retval, err := promise.Result().Struct()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return UnmarshalAuthenticateResponse(retval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse) (*AuthenticateResponse, error) {
|
||||||
|
p := new(AuthenticateResponse)
|
||||||
|
err := pogs.Extract(p, tunnelrpc.AuthenticateResponse_TypeID, s.Struct)
|
||||||
|
return p, err
|
||||||
|
}
|
|
@ -0,0 +1,134 @@
|
||||||
|
package pogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
capnp "zombiezen.com/go/capnproto2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure the AuthOutcome sum is correct
|
||||||
|
var _ AuthOutcome = &AuthSuccess{}
|
||||||
|
var _ AuthOutcome = &AuthFail{}
|
||||||
|
var _ AuthOutcome = &AuthUnknown{}
|
||||||
|
|
||||||
|
// Unit tests for AuthenticateResponse.Outcome()
|
||||||
|
func TestAuthenticateResponseOutcome(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
PermanentErr string
|
||||||
|
RetryableErr string
|
||||||
|
Jwt []byte
|
||||||
|
HoursUntilRefresh uint8
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want AuthOutcome
|
||||||
|
}{
|
||||||
|
{"success",
|
||||||
|
fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
|
||||||
|
AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6},
|
||||||
|
},
|
||||||
|
{"fail",
|
||||||
|
fields{PermanentErr: "bad creds"},
|
||||||
|
AuthFail{err: fmt.Errorf("bad creds")},
|
||||||
|
},
|
||||||
|
{"error",
|
||||||
|
fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
|
||||||
|
AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6},
|
||||||
|
},
|
||||||
|
{"nil (no fields are set)",
|
||||||
|
fields{},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{"nil (too few fields are set)",
|
||||||
|
fields{HoursUntilRefresh: 6},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ar := AuthenticateResponse{
|
||||||
|
PermanentErr: tt.fields.PermanentErr,
|
||||||
|
RetryableErr: tt.fields.RetryableErr,
|
||||||
|
Jwt: tt.fields.Jwt,
|
||||||
|
HoursUntilRefresh: tt.fields.HoursUntilRefresh,
|
||||||
|
}
|
||||||
|
got := ar.Outcome()
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got != nil && !reflect.DeepEqual(got.Serialize(), ar) {
|
||||||
|
t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthSuccess(t *testing.T) {
|
||||||
|
input := NewAuthSuccess([]byte("asdf"), 6)
|
||||||
|
output, ok := input.Serialize().Outcome().(AuthSuccess)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, input, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthUnknown(t *testing.T) {
|
||||||
|
input := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6)
|
||||||
|
output, ok := input.Serialize().Outcome().(AuthUnknown)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, input, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthFail(t *testing.T) {
|
||||||
|
input := NewAuthFail(fmt.Errorf("wrong creds"))
|
||||||
|
output, ok := input.Serialize().Outcome().(AuthFail)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, input, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWhenToRefresh(t *testing.T) {
|
||||||
|
expected := 4 * time.Hour
|
||||||
|
actual := hoursToTime(4)
|
||||||
|
if expected != actual {
|
||||||
|
t.Fatalf("expected %v hours, got %v", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that serializing and deserializing AuthenticationResponse undo each other.
|
||||||
|
func TestSerializeAuthenticationResponse(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []*AuthenticateResponse{
|
||||||
|
&AuthenticateResponse{
|
||||||
|
Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
|
||||||
|
HoursUntilRefresh: 24,
|
||||||
|
},
|
||||||
|
&AuthenticateResponse{
|
||||||
|
PermanentErr: "bad auth",
|
||||||
|
},
|
||||||
|
&AuthenticateResponse{
|
||||||
|
RetryableErr: "bad connection",
|
||||||
|
HoursUntilRefresh: 24,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, testCase := range tests {
|
||||||
|
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||||
|
capnpEntity, err := tunnelrpc.NewAuthenticateResponse(seg)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatal("Couldn't initialize a new message")
|
||||||
|
}
|
||||||
|
err = MarshalAuthenticateResponse(capnpEntity, testCase)
|
||||||
|
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result, err := UnmarshalAuthenticateResponse(capnpEntity)
|
||||||
|
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||||
|
}
|
||||||
|
}
|
|
@ -197,11 +197,14 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dialContext := (&net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: hc.ProxyConnectionTimeout,
|
Timeout: hc.ProxyConnectionTimeout,
|
||||||
KeepAlive: hc.TCPKeepAlive,
|
KeepAlive: hc.TCPKeepAlive,
|
||||||
DualStack: hc.DialDualStack,
|
}
|
||||||
}).DialContext
|
if !hc.DialDualStack {
|
||||||
|
dialer.FallbackDelay = -1
|
||||||
|
}
|
||||||
|
dialContext := dialer.DialContext
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
DialContext: dialContext,
|
DialContext: dialContext,
|
||||||
|
@ -270,7 +273,6 @@ func (*HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
|
||||||
DialContext: (&net.Dialer{
|
DialContext: (&net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
KeepAlive: 30 * time.Second,
|
KeepAlive: 30 * time.Second,
|
||||||
DualStack: true,
|
|
||||||
}).DialContext,
|
}).DialContext,
|
||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
RootCAs: rootCAs,
|
RootCAs: rootCAs,
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
package pogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
"zombiezen.com/go/capnproto2/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (i TunnelServer_PogsImpl) ReconnectTunnel(p tunnelrpc.TunnelServer_reconnectTunnel) error {
|
||||||
|
jwt, err := p.Params.Jwt()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
eventDigest, err := p.Params.EventDigest()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hostname, err := p.Params.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
options, err := p.Params.Options()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pogsOptions, err := UnmarshalRegistrationOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
server.Ack(p.Options)
|
||||||
|
registration, err := i.impl.ReconnectTunnel(p.Ctx, jwt, eventDigest, hostname, pogsOptions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result, err := p.Results.NewResult()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return MarshalTunnelRegistration(result, registration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c TunnelServer_PogsClient) ReconnectTunnel(
|
||||||
|
ctx context.Context,
|
||||||
|
jwt,
|
||||||
|
eventDigest []byte,
|
||||||
|
hostname string,
|
||||||
|
options *RegistrationOptions,
|
||||||
|
) (*TunnelRegistration, error) {
|
||||||
|
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||||
|
promise := client.ReconnectTunnel(ctx, func(p tunnelrpc.TunnelServer_reconnectTunnel_Params) error {
|
||||||
|
err := p.SetJwt(jwt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = p.SetEventDigest(eventDigest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = p.SetHostname(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registrationOptions, err := p.NewOptions()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = MarshalRegistrationOptions(registrationOptions, options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
retval, err := promise.Result().Struct()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return UnmarshalTunnelRegistration(retval)
|
||||||
|
}
|
|
@ -9,13 +9,16 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
capnp "zombiezen.com/go/capnproto2"
|
capnp "zombiezen.com/go/capnproto2"
|
||||||
"zombiezen.com/go/capnproto2/pogs"
|
"zombiezen.com/go/capnproto2/pogs"
|
||||||
"zombiezen.com/go/capnproto2/rpc"
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
"zombiezen.com/go/capnproto2/server"
|
"zombiezen.com/go/capnproto2/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultRetryAfterSeconds = 15
|
||||||
|
)
|
||||||
|
|
||||||
type Authentication struct {
|
type Authentication struct {
|
||||||
Key string
|
Key string
|
||||||
Email string
|
Email string
|
||||||
|
@ -33,11 +36,112 @@ func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TunnelRegistration struct {
|
type TunnelRegistration struct {
|
||||||
Err string
|
SuccessfulTunnelRegistration
|
||||||
Url string
|
Err string
|
||||||
LogLines []string
|
PermanentFailure bool
|
||||||
PermanentFailure bool
|
RetryAfterSeconds uint16
|
||||||
TunnelID string `capnp:"tunnelID"`
|
}
|
||||||
|
|
||||||
|
type SuccessfulTunnelRegistration struct {
|
||||||
|
Url string
|
||||||
|
LogLines []string
|
||||||
|
TunnelID string `capnp:"tunnelID"`
|
||||||
|
EventDigest []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSuccessfulTunnelRegistration(
|
||||||
|
url string,
|
||||||
|
logLines []string,
|
||||||
|
tunnelID string,
|
||||||
|
eventDigest []byte,
|
||||||
|
) *TunnelRegistration {
|
||||||
|
// Marshal nil will result in an error
|
||||||
|
if logLines == nil {
|
||||||
|
logLines = []string{}
|
||||||
|
}
|
||||||
|
return &TunnelRegistration{
|
||||||
|
SuccessfulTunnelRegistration: SuccessfulTunnelRegistration{
|
||||||
|
Url: url,
|
||||||
|
LogLines: logLines,
|
||||||
|
TunnelID: tunnelID,
|
||||||
|
EventDigest: eventDigest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not calling this function Error() to avoid confusion with implementing error interface
|
||||||
|
func (tr TunnelRegistration) DeserializeError() TunnelRegistrationError {
|
||||||
|
if tr.Err != "" {
|
||||||
|
err := fmt.Errorf(tr.Err)
|
||||||
|
if tr.PermanentFailure {
|
||||||
|
return NewPermanentRegistrationError(err)
|
||||||
|
}
|
||||||
|
retryAfterSeconds := tr.RetryAfterSeconds
|
||||||
|
if retryAfterSeconds < defaultRetryAfterSeconds {
|
||||||
|
retryAfterSeconds = defaultRetryAfterSeconds
|
||||||
|
}
|
||||||
|
return NewRetryableRegistrationError(err, retryAfterSeconds)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelRegistrationError interface {
|
||||||
|
error
|
||||||
|
Serialize() *TunnelRegistration
|
||||||
|
IsPermanent() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type PermanentRegistrationError struct {
|
||||||
|
err string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPermanentRegistrationError(err error) TunnelRegistrationError {
|
||||||
|
return &PermanentRegistrationError{
|
||||||
|
err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pre *PermanentRegistrationError) Error() string {
|
||||||
|
return pre.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pre *PermanentRegistrationError) Serialize() *TunnelRegistration {
|
||||||
|
return &TunnelRegistration{
|
||||||
|
Err: pre.err,
|
||||||
|
PermanentFailure: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*PermanentRegistrationError) IsPermanent() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type RetryableRegistrationError struct {
|
||||||
|
err string
|
||||||
|
retryAfterSeconds uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRetryableRegistrationError(err error, retryAfterSeconds uint16) TunnelRegistrationError {
|
||||||
|
return &RetryableRegistrationError{
|
||||||
|
err: err.Error(),
|
||||||
|
retryAfterSeconds: retryAfterSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rre *RetryableRegistrationError) Error() string {
|
||||||
|
return rre.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rre *RetryableRegistrationError) Serialize() *TunnelRegistration {
|
||||||
|
return &TunnelRegistration{
|
||||||
|
Err: rre.err,
|
||||||
|
PermanentFailure: false,
|
||||||
|
RetryAfterSeconds: rre.retryAfterSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RetryableRegistrationError) IsPermanent() bool {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
|
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
|
||||||
|
@ -63,6 +167,7 @@ type RegistrationOptions struct {
|
||||||
RunFromTerminal bool `capnp:"runFromTerminal"`
|
RunFromTerminal bool `capnp:"runFromTerminal"`
|
||||||
CompressionQuality uint64 `capnp:"compressionQuality"`
|
CompressionQuality uint64 `capnp:"compressionQuality"`
|
||||||
UUID string `capnp:"uuid"`
|
UUID string `capnp:"uuid"`
|
||||||
|
NumPreviousAttempts uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
||||||
|
@ -323,10 +428,12 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar
|
||||||
}
|
}
|
||||||
|
|
||||||
type TunnelServer interface {
|
type TunnelServer interface {
|
||||||
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
|
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration
|
||||||
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
||||||
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
|
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
|
||||||
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
|
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
|
||||||
|
Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error)
|
||||||
|
ReconnectTunnel(ctx context.Context, jwt, eventDigest []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
||||||
|
@ -355,15 +462,12 @@ func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerT
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
server.Ack(p.Options)
|
server.Ack(p.Options)
|
||||||
registration, err := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
|
registration := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
result, err := p.Results.NewResult()
|
result, err := p.Results.NewResult()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Info(registration.TunnelID)
|
|
||||||
return MarshalTunnelRegistration(result, registration)
|
return MarshalTunnelRegistration(result, registration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -416,7 +520,7 @@ func (c TunnelServer_PogsClient) Close() error {
|
||||||
return c.Conn.Close()
|
return c.Conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
|
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
|
||||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||||
promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
|
promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
|
||||||
err := p.SetOriginCert(originCert)
|
err := p.SetOriginCert(originCert)
|
||||||
|
@ -439,9 +543,13 @@ func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert
|
||||||
})
|
})
|
||||||
retval, err := promise.Result().Struct()
|
retval, err := promise.Result().Struct()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
|
||||||
}
|
}
|
||||||
return UnmarshalTunnelRegistration(retval)
|
registration, err := UnmarshalTunnelRegistration(retval)
|
||||||
|
if err != nil {
|
||||||
|
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
|
||||||
|
}
|
||||||
|
return registration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
|
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package pogs
|
package pogs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -11,6 +12,50 @@ import (
|
||||||
capnp "zombiezen.com/go/capnproto2"
|
capnp "zombiezen.com/go/capnproto2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testURL = "tunnel.example.com"
|
||||||
|
testTunnelID = "asdfghjkl;"
|
||||||
|
testRetryAfterSeconds = 19
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testErr = fmt.Errorf("Invalid credential")
|
||||||
|
testLogLines = []string{"all", "working"}
|
||||||
|
testEventDigest = []byte("asdf")
|
||||||
|
)
|
||||||
|
|
||||||
|
// *PermanentRegistrationError implements TunnelRegistrationError
|
||||||
|
var _ TunnelRegistrationError = (*PermanentRegistrationError)(nil)
|
||||||
|
|
||||||
|
// *RetryableRegistrationError implements TunnelRegistrationError
|
||||||
|
var _ TunnelRegistrationError = (*RetryableRegistrationError)(nil)
|
||||||
|
|
||||||
|
func TestTunnelRegistration(t *testing.T) {
|
||||||
|
testCases := []*TunnelRegistration{
|
||||||
|
NewSuccessfulTunnelRegistration(testURL, testLogLines, testTunnelID, testEventDigest),
|
||||||
|
NewSuccessfulTunnelRegistration(testURL, nil, testTunnelID, testEventDigest),
|
||||||
|
NewPermanentRegistrationError(testErr).Serialize(),
|
||||||
|
NewRetryableRegistrationError(testErr, testRetryAfterSeconds).Serialize(),
|
||||||
|
}
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||||
|
capnpEntity, err := tunnelrpc.NewTunnelRegistration(seg)
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
t.Fatal("Couldn't initialize a new message")
|
||||||
|
}
|
||||||
|
err = MarshalTunnelRegistration(capnpEntity, testCase)
|
||||||
|
if !assert.NoError(t, err, "testCase #%v failed to marshal", i) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result, err := UnmarshalTunnelRegistration(capnpEntity)
|
||||||
|
if !assert.NoError(t, err, "testCase #%v failed to unmarshal", i) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnectResult(t *testing.T) {
|
func TestConnectResult(t *testing.T) {
|
||||||
testCases := []ConnectResult{
|
testCases := []ConnectResult{
|
||||||
&ConnectError{
|
&ConnectError{
|
||||||
|
|
|
@ -19,6 +19,10 @@ struct TunnelRegistration {
|
||||||
permanentFailure @3 :Bool;
|
permanentFailure @3 :Bool;
|
||||||
# Displayed to user
|
# Displayed to user
|
||||||
tunnelID @4 :Text;
|
tunnelID @4 :Text;
|
||||||
|
# How long should this connection wait to retry in seconds, if the error wasn't permanent
|
||||||
|
retryAfterSeconds @5 :UInt16;
|
||||||
|
# A unique ID used to reconnect this tunnel.
|
||||||
|
eventDigest @6 :Data;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RegistrationOptions {
|
struct RegistrationOptions {
|
||||||
|
@ -44,6 +48,8 @@ struct RegistrationOptions {
|
||||||
# cross stream compression setting, 0 - off, 3 - high
|
# cross stream compression setting, 0 - off, 3 - high
|
||||||
compressionQuality @10 :UInt64;
|
compressionQuality @10 :UInt64;
|
||||||
uuid @11 :Text;
|
uuid @11 :Text;
|
||||||
|
# number of previous attempts to send RegisterTunnel/ReconnectTunnel
|
||||||
|
numPreviousAttempts @12 :UInt8;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CapnpConnectParameters {
|
struct CapnpConnectParameters {
|
||||||
|
@ -274,11 +280,20 @@ struct FailedConfig {
|
||||||
reason @4 :Text;
|
reason @4 :Text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct AuthenticateResponse {
|
||||||
|
permanentErr @0 :Text;
|
||||||
|
retryableErr @1 :Text;
|
||||||
|
jwt @2 :Data;
|
||||||
|
hoursUntilRefresh @3 :UInt8;
|
||||||
|
}
|
||||||
|
|
||||||
interface TunnelServer {
|
interface TunnelServer {
|
||||||
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||||
getServerInfo @1 () -> (result :ServerInfo);
|
getServerInfo @1 () -> (result :ServerInfo);
|
||||||
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();
|
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();
|
||||||
connect @3 (parameters :CapnpConnectParameters) -> (result :ConnectResult);
|
connect @3 (parameters :CapnpConnectParameters) -> (result :ConnectResult);
|
||||||
|
authenticate @4 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :AuthenticateResponse);
|
||||||
|
reconnectTunnel @5 (jwt :Data, eventDigest :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ClientService {
|
interface ClientService {
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -60,6 +60,12 @@ func ValidateHostname(hostname string) (string, error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateUrl returns a validated version of `originUrl` with a scheme prepended (by default http://).
|
||||||
|
// Note: when originUrl contains a scheme, the path is removed:
|
||||||
|
// ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
|
||||||
|
// but when it does not, the path is preserved:
|
||||||
|
// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
|
||||||
|
// This is arguably a bug, but changing it might break some cloudflared users.
|
||||||
func ValidateUrl(originUrl string) (string, error) {
|
func ValidateUrl(originUrl string) (string, error) {
|
||||||
if originUrl == "" {
|
if originUrl == "" {
|
||||||
return "", fmt.Errorf("URL should not be empty")
|
return "", fmt.Errorf("URL should not be empty")
|
||||||
|
@ -121,6 +127,8 @@ func ValidateUrl(originUrl string) (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||||
}
|
}
|
||||||
|
// This is why the path is preserved when `originUrl` doesn't have a schema.
|
||||||
|
// Using `parsedUrl.Port()` here, instead of `port`, would remove the path
|
||||||
return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
|
return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -182,10 +190,11 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
|
||||||
_, secondErr := client.Do(secondRequest)
|
_, secondErr := client.Do(secondRequest)
|
||||||
if secondErr == nil { // Worked this time--advise the user to switch protocols
|
if secondErr == nil { // Worked this time--advise the user to switch protocols
|
||||||
return errors.Errorf(
|
return errors.Errorf(
|
||||||
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
|
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
|
||||||
parsedURL.Host,
|
parsedURL.Host,
|
||||||
oldScheme,
|
oldScheme,
|
||||||
parsedURL.Scheme,
|
parsedURL.Scheme,
|
||||||
|
initialErr,
|
||||||
parsedURL,
|
parsedURL,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,98 +53,65 @@ func TestValidateHostname(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateUrl(t *testing.T) {
|
func TestValidateUrl(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
input string
|
||||||
|
expectedOutput string
|
||||||
|
}
|
||||||
|
testCases := []testCase{
|
||||||
|
{"http://localhost", "http://localhost"},
|
||||||
|
{"http://localhost/", "http://localhost"},
|
||||||
|
{"http://localhost/api", "http://localhost"},
|
||||||
|
{"http://localhost/api/", "http://localhost"},
|
||||||
|
{"https://localhost", "https://localhost"},
|
||||||
|
{"https://localhost/", "https://localhost"},
|
||||||
|
{"https://localhost/api", "https://localhost"},
|
||||||
|
{"https://localhost/api/", "https://localhost"},
|
||||||
|
{"https://localhost:8080", "https://localhost:8080"},
|
||||||
|
{"https://localhost:8080/", "https://localhost:8080"},
|
||||||
|
{"https://localhost:8080/api", "https://localhost:8080"},
|
||||||
|
{"https://localhost:8080/api/", "https://localhost:8080"},
|
||||||
|
{"localhost", "http://localhost"},
|
||||||
|
{"localhost/", "http://localhost/"},
|
||||||
|
{"localhost/api", "http://localhost/api"},
|
||||||
|
{"localhost/api/", "http://localhost/api/"},
|
||||||
|
{"localhost:8080", "http://localhost:8080"},
|
||||||
|
{"localhost:8080/", "http://localhost:8080/"},
|
||||||
|
{"localhost:8080/api", "http://localhost:8080/api"},
|
||||||
|
{"localhost:8080/api/", "http://localhost:8080/api/"},
|
||||||
|
{"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"},
|
||||||
|
{"http://127.0.0.1:8080", "http://127.0.0.1:8080"},
|
||||||
|
{"127.0.0.1:8080", "http://127.0.0.1:8080"},
|
||||||
|
{"127.0.0.1", "http://127.0.0.1"},
|
||||||
|
{"https://127.0.0.1:8080", "https://127.0.0.1:8080"},
|
||||||
|
{"[::1]:8080", "http://[::1]:8080"},
|
||||||
|
{"http://[::1]", "http://[::1]"},
|
||||||
|
{"http://[::1]:8080", "http://[::1]:8080"},
|
||||||
|
{"[::1]", "http://[::1]"},
|
||||||
|
{"https://example.com", "https://example.com"},
|
||||||
|
{"example.com", "http://example.com"},
|
||||||
|
{"http://hello.example.com", "http://hello.example.com"},
|
||||||
|
{"hello.example.com", "http://hello.example.com"},
|
||||||
|
{"hello.example.com:8080", "http://hello.example.com:8080"},
|
||||||
|
{"https://hello.example.com:8080", "https://hello.example.com:8080"},
|
||||||
|
{"https://bücher.example.com", "https://xn--bcher-kva.example.com"},
|
||||||
|
{"bücher.example.com", "http://xn--bcher-kva.example.com"},
|
||||||
|
{"https%3A%2F%2Fhello.example.com", "https://hello.example.com"},
|
||||||
|
{"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"},
|
||||||
|
}
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
validUrl, err := ValidateUrl(testCase.input)
|
||||||
|
assert.NoError(t, err, "test case %v", i)
|
||||||
|
assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i)
|
||||||
|
}
|
||||||
|
|
||||||
validUrl, err := ValidateUrl("")
|
validUrl, err := ValidateUrl("")
|
||||||
assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
|
assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
|
||||||
assert.Empty(t, validUrl)
|
assert.Empty(t, validUrl)
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https://localhost:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "https://localhost:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("localhost:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://localhost:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("http://localhost")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://localhost", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("http://127.0.0.1:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("127.0.0.1:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("127.0.0.1")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://127.0.0.1", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https://127.0.0.1:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "https://127.0.0.1:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("[::1]:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://[::1]:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("http://[::1]")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://[::1]", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("http://[::1]:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://[::1]:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("[::1]")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://[::1]", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https://example.com")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "https://example.com", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("example.com")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://example.com", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("http://hello.example.com")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://hello.example.com", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("hello.example.com")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://hello.example.com", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("hello.example.com:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://hello.example.com:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https://hello.example.com:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https://bücher.example.com")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("bücher.example.com")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "https://hello.example.com", validUrl)
|
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
|
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
|
||||||
assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
|
assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
|
||||||
assert.Empty(t, validUrl)
|
assert.Empty(t, validUrl)
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToggleProtocol(t *testing.T) {
|
func TestToggleProtocol(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue