TUN-2606: add DialEdge helpers
This commit is contained in:
		
							parent
							
								
									92736b2677
								
							
						
					
					
						commit
						8f4fd70783
					
				|  | @ -2,7 +2,6 @@ package connection | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"net" |  | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflared/h2mux" | 	"github.com/cloudflare/cloudflared/h2mux" | ||||||
|  | @ -20,20 +19,12 @@ const ( | ||||||
| 	openStreamTimeout = 30 * time.Second | 	openStreamTimeout = 30 * time.Second | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type dialError struct { |  | ||||||
| 	cause error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (e dialError) Error() string { |  | ||||||
| 	return e.cause.Error() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type Connection struct { | type Connection struct { | ||||||
| 	id    uuid.UUID | 	id    uuid.UUID | ||||||
| 	muxer *h2mux.Muxer | 	muxer *h2mux.Muxer | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newConnection(muxer *h2mux.Muxer, edgeIP *net.TCPAddr) (*Connection, error) { | func newConnection(muxer *h2mux.Muxer) (*Connection, error) { | ||||||
| 	id, err := uuid.NewRandom() | 	id, err := uuid.NewRandom() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  |  | ||||||
|  | @ -0,0 +1,54 @@ | ||||||
|  | package connection | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"net" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/pkg/errors" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // DialEdge makes a TLS connection to a Cloudflare edge node
 | ||||||
|  | func DialEdge( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	timeout time.Duration, | ||||||
|  | 	tlsConfig *tls.Config, | ||||||
|  | 	edgeTCPAddr *net.TCPAddr, | ||||||
|  | ) (net.Conn, error) { | ||||||
|  | 	// Inherit from parent context so we can cancel (Ctrl-C) while dialing
 | ||||||
|  | 	dialCtx, dialCancel := context.WithTimeout(ctx, timeout) | ||||||
|  | 	defer dialCancel() | ||||||
|  | 
 | ||||||
|  | 	dialer := net.Dialer{} | ||||||
|  | 	edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, newDialError(err, "DialContext error") | ||||||
|  | 	} | ||||||
|  | 	tlsEdgeConn := tls.Client(edgeConn, tlsConfig) | ||||||
|  | 	tlsEdgeConn.SetDeadline(time.Now().Add(timeout)) | ||||||
|  | 
 | ||||||
|  | 	if err = tlsEdgeConn.Handshake(); err != nil { | ||||||
|  | 		return nil, newDialError(err, "Handshake with edge error") | ||||||
|  | 	} | ||||||
|  | 	// clear the deadline on the conn; h2mux has its own timeouts
 | ||||||
|  | 	tlsEdgeConn.SetDeadline(time.Time{}) | ||||||
|  | 	return tlsEdgeConn, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DialError is an error returned from DialEdge
 | ||||||
|  | type DialError struct { | ||||||
|  | 	cause error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newDialError(err error, message string) error { | ||||||
|  | 	return DialError{cause: errors.Wrap(err, message)} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (e DialError) Error() string { | ||||||
|  | 	return e.cause.Error() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (e DialError) Cause() error { | ||||||
|  | 	return e.cause | ||||||
|  | } | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -128,12 +127,12 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError { | func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError { | ||||||
| 	edgeIP := em.serviceDiscoverer.Addr() | 	edgeTCPAddr := em.serviceDiscoverer.Addr() | ||||||
| 	edgeConn, err := em.dialEdge(ctx, edgeIP) | 	configurable := em.state.getConfigurable() | ||||||
|  | 	edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return retryConnection(fmt.Sprintf("dial edge error: %v", err)) | 		return retryConnection(fmt.Sprintf("dial edge error: %v", err)) | ||||||
| 	} | 	} | ||||||
| 	configurable := em.state.getConfigurable() |  | ||||||
| 	// Establish a muxed connection with the edge
 | 	// Establish a muxed connection with the edge
 | ||||||
| 	// Client mux handshake with agent server
 | 	// Client mux handshake with agent server
 | ||||||
| 	muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ | 	muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ | ||||||
|  | @ -148,7 +147,7 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError { | ||||||
| 		retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err)) | 		retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	h2muxConn, err := newConnection(muxer, edgeIP) | 	h2muxConn, err := newConnection(muxer) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err)) | 		return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err)) | ||||||
| 	} | 	} | ||||||
|  | @ -196,28 +195,6 @@ func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) { | ||||||
| 	em.state.closeConnection(conn) | 	em.state.closeConnection(conn) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (em *EdgeManager) dialEdge(ctx context.Context, edgeIP *net.TCPAddr) (*tls.Conn, error) { |  | ||||||
| 	timeout := em.state.getConfigurable().Timeout |  | ||||||
| 	// Inherit from parent context so we can cancel (Ctrl-C) while dialing
 |  | ||||||
| 	dialCtx, dialCancel := context.WithTimeout(ctx, timeout) |  | ||||||
| 	defer dialCancel() |  | ||||||
| 
 |  | ||||||
| 	dialer := net.Dialer{} |  | ||||||
| 	edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, dialError{cause: errors.Wrap(err, "DialContext error")} |  | ||||||
| 	} |  | ||||||
| 	tlsEdgeConn := tls.Client(edgeConn, em.tlsConfig) |  | ||||||
| 	tlsEdgeConn.SetDeadline(time.Now().Add(timeout)) |  | ||||||
| 
 |  | ||||||
| 	if err = tlsEdgeConn.Handshake(); err != nil { |  | ||||||
| 		return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")} |  | ||||||
| 	} |  | ||||||
| 	// clear the deadline on the conn; h2mux has its own timeouts
 |  | ||||||
| 	tlsEdgeConn.SetDeadline(time.Time{}) |  | ||||||
| 	return tlsEdgeConn, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (em *EdgeManager) noRetryMessage() string { | func (em *EdgeManager) noRetryMessage() string { | ||||||
| 	messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" + | 	messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" + | ||||||
| 		"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." + | 		"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." + | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										1
									
								
								go.mod
								
								
								
								
							|  | @ -57,6 +57,7 @@ require ( | ||||||
| 	golang.org/x/sync v0.0.0-20190423024810-112230192c58 | 	golang.org/x/sync v0.0.0-20190423024810-112230192c58 | ||||||
| 	golang.org/x/sys v0.0.0-20191008105621-543471e840be | 	golang.org/x/sys v0.0.0-20191008105621-543471e840be | ||||||
| 	golang.org/x/text v0.3.2 // indirect | 	golang.org/x/text v0.3.2 // indirect | ||||||
|  | 	google.golang.org/appengine v1.4.0 // indirect | ||||||
| 	google.golang.org/genproto v0.0.0-20191007204434-a023cd5227bd // indirect | 	google.golang.org/genproto v0.0.0-20191007204434-a023cd5227bd // indirect | ||||||
| 	google.golang.org/grpc v1.24.0 // indirect | 	google.golang.org/grpc v1.24.0 // indirect | ||||||
| 	gopkg.in/coreos/go-oidc.v2 v2.1.0 | 	gopkg.in/coreos/go-oidc.v2 v2.1.0 | ||||||
|  |  | ||||||
|  | @ -183,7 +183,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign | ||||||
| 			return | 			return | ||||||
| 		// try the next address if it was a dialError(network problem) or
 | 		// try the next address if it was a dialError(network problem) or
 | ||||||
| 		// dupConnRegisterTunnelError
 | 		// dupConnRegisterTunnelError
 | ||||||
| 		case dialError, dupConnRegisterTunnelError: | 		case connection.DialError, dupConnRegisterTunnelError: | ||||||
| 			s.replaceEdgeIP(0) | 			s.replaceEdgeIP(0) | ||||||
| 		default: | 		default: | ||||||
| 			return | 			return | ||||||
|  |  | ||||||
|  | @ -15,6 +15,7 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" | 	"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" | ||||||
|  | 	"github.com/cloudflare/cloudflared/connection" | ||||||
| 	"github.com/cloudflare/cloudflared/h2mux" | 	"github.com/cloudflare/cloudflared/h2mux" | ||||||
| 	"github.com/cloudflare/cloudflared/signal" | 	"github.com/cloudflare/cloudflared/signal" | ||||||
| 	"github.com/cloudflare/cloudflared/streamhandler" | 	"github.com/cloudflare/cloudflared/streamhandler" | ||||||
|  | @ -27,7 +28,6 @@ import ( | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/pkg/errors" | 	"github.com/pkg/errors" | ||||||
| 	"github.com/prometheus/client_golang/prometheus" | 	"github.com/prometheus/client_golang/prometheus" | ||||||
| 	_ "github.com/prometheus/client_golang/prometheus" |  | ||||||
| 	log "github.com/sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"golang.org/x/sync/errgroup" | 	"golang.org/x/sync/errgroup" | ||||||
| 	rpc "zombiezen.com/go/capnproto2/rpc" | 	rpc "zombiezen.com/go/capnproto2/rpc" | ||||||
|  | @ -76,14 +76,6 @@ type TunnelConfig struct { | ||||||
| 	OriginUrl string | 	OriginUrl string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type dialError struct { |  | ||||||
| 	cause error |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (e dialError) Error() string { |  | ||||||
| 	return e.cause.Error() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type dupConnRegisterTunnelError struct{} | type dupConnRegisterTunnelError struct{} | ||||||
| 
 | 
 | ||||||
| func (e dupConnRegisterTunnelError) Error() string { | func (e dupConnRegisterTunnelError) Error() string { | ||||||
|  | @ -214,11 +206,11 @@ func ServeTunnel( | ||||||
| 	tags["ha"] = connectionTag | 	tags["ha"] = connectionTag | ||||||
| 
 | 
 | ||||||
| 	// Returns error from parsing the origin URL or handshake errors
 | 	// Returns error from parsing the origin URL or handshake errors
 | ||||||
| 	handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID) | 	handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		errLog := logger.WithError(err) | 		errLog := logger.WithError(err) | ||||||
| 		switch err.(type) { | 		switch err.(type) { | ||||||
| 		case dialError: | 		case connection.DialError: | ||||||
| 			errLog.Error("Unable to dial edge") | 			errLog.Error("Unable to dial edge") | ||||||
| 		case h2mux.MuxerHandshakeError: | 		case h2mux.MuxerHandshakeError: | ||||||
| 			errLog.Error("Handshake failed with edge server") | 			errLog.Error("Handshake failed with edge server") | ||||||
|  | @ -470,7 +462,7 @@ var dialer = net.Dialer{} | ||||||
| // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
 | // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
 | ||||||
| func NewTunnelHandler(ctx context.Context, | func NewTunnelHandler(ctx context.Context, | ||||||
| 	config *TunnelConfig, | 	config *TunnelConfig, | ||||||
| 	addr string, | 	addr *net.TCPAddr, | ||||||
| 	connectionID uint8, | 	connectionID uint8, | ||||||
| ) (*TunnelHandler, string, error) { | ) (*TunnelHandler, string, error) { | ||||||
| 	originURL, err := validation.ValidateUrl(config.OriginUrl) | 	originURL, err := validation.ValidateUrl(config.OriginUrl) | ||||||
|  | @ -491,22 +483,11 @@ func NewTunnelHandler(ctx context.Context, | ||||||
| 	if h.httpClient == nil { | 	if h.httpClient == nil { | ||||||
| 		h.httpClient = http.DefaultTransport | 		h.httpClient = http.DefaultTransport | ||||||
| 	} | 	} | ||||||
| 	// Inherit from parent context so we can cancel (Ctrl-C) while dialing
 | 
 | ||||||
| 	dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout) | 	edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) | ||||||
| 	// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
 |  | ||||||
| 	plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr) |  | ||||||
| 	dialCancel() |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")} | 		return nil, "", err | ||||||
| 	} | 	} | ||||||
| 	edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig) |  | ||||||
| 	edgeConn.SetDeadline(time.Now().Add(dialTimeout)) |  | ||||||
| 	err = edgeConn.Handshake() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")} |  | ||||||
| 	} |  | ||||||
| 	// clear the deadline on the conn; h2mux has its own timeouts
 |  | ||||||
| 	edgeConn.SetDeadline(time.Time{}) |  | ||||||
| 	// Establish a muxed connection with the edge
 | 	// Establish a muxed connection with the edge
 | ||||||
| 	// Client mux handshake with agent server
 | 	// Client mux handshake with agent server
 | ||||||
| 	h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ | 	h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ | ||||||
|  | @ -519,9 +500,9 @@ func NewTunnelHandler(ctx context.Context, | ||||||
| 		CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality), | 		CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality), | ||||||
| 	}, h.metrics.activeStreams) | 	}, h.metrics.activeStreams) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return h, "", errors.New("TLS handshake error") | 		return nil, "", errors.Wrap(err, "Handshake with edge error") | ||||||
| 	} | 	} | ||||||
| 	return h, edgeConn.LocalAddr().String(), err | 	return h, edgeConn.LocalAddr().String(), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { | func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue