TUN-3400: Use Go HTTP2 library as transport to connect with the edge
This commit is contained in:
		
							parent
							
								
									d7498b0c03
								
							
						
					
					
						commit
						8d7b2575ba
					
				| 
						 | 
					@ -156,7 +156,8 @@ func prepareTunnelConfig(
 | 
				
			||||||
	transportLogger logger.Service,
 | 
						transportLogger logger.Service,
 | 
				
			||||||
	namedTunnel *origin.NamedTunnelConfig,
 | 
						namedTunnel *origin.NamedTunnelConfig,
 | 
				
			||||||
) (*origin.TunnelConfig, error) {
 | 
					) (*origin.TunnelConfig, error) {
 | 
				
			||||||
	compatibilityMode := namedTunnel == nil
 | 
						isNamedTunnel := namedTunnel != nil
 | 
				
			||||||
 | 
						compatibilityMode := !isNamedTunnel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hostname, err := validation.ValidateHostname(c.String("hostname"))
 | 
						hostname, err := validation.ValidateHostname(c.String("hostname"))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -219,7 +220,7 @@ func prepareTunnelConfig(
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c)
 | 
						toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, isNamedTunnel)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Errorf("unable to create TLS config to connect with edge: %s", err)
 | 
							logger.Errorf("unable to create TLS config to connect with edge: %s", err)
 | 
				
			||||||
		return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
 | 
							return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -9,7 +9,7 @@ import (
 | 
				
			||||||
	"github.com/pkg/errors"
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// DialEdge makes a TLS connection to a Cloudflare edge node
 | 
					// DialEdgeWithH2Mux makes a TLS connection to a Cloudflare edge node
 | 
				
			||||||
func DialEdge(
 | 
					func DialEdge(
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	timeout time.Duration,
 | 
						timeout time.Duration,
 | 
				
			||||||
| 
						 | 
					@ -25,6 +25,7 @@ func DialEdge(
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, newDialError(err, "DialContext error")
 | 
							return nil, newDialError(err, "DialContext error")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tlsEdgeConn := tls.Client(edgeConn, tlsConfig)
 | 
						tlsEdgeConn := tls.Client(edgeConn, tlsConfig)
 | 
				
			||||||
	tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
 | 
						tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,14 @@
 | 
				
			||||||
 | 
					package origin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called
 | 
				
			||||||
 | 
					type persistentConn struct {
 | 
				
			||||||
 | 
						net.Conn
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (pc *persistentConn) Close() error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,160 @@
 | 
				
			||||||
 | 
					package origin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/h2mux"
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/logger"
 | 
				
			||||||
 | 
						"golang.org/x/net/http2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type cfdServer struct {
 | 
				
			||||||
 | 
						httpServer      *http2.Server
 | 
				
			||||||
 | 
						originClient    http.RoundTripper
 | 
				
			||||||
 | 
						logger          logger.Service
 | 
				
			||||||
 | 
						originURL       *url.URL
 | 
				
			||||||
 | 
						connectionIndex string
 | 
				
			||||||
 | 
						config          *TunnelConfig
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *cfdServer) serve(ctx context.Context, conn net.Conn) {
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							<-ctx.Done()
 | 
				
			||||||
 | 
							conn.Close()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
 | 
				
			||||||
 | 
							Context: ctx,
 | 
				
			||||||
 | 
							Handler: c,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *cfdServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
						c.config.Metrics.incrementRequests(c.connectionIndex)
 | 
				
			||||||
 | 
						defer c.config.Metrics.decrementConcurrentRequests(c.connectionIndex)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cfRay := findCfRayHeader(r)
 | 
				
			||||||
 | 
						lbProbe := isLBProbeRequest(r)
 | 
				
			||||||
 | 
						c.logRequest(r, cfRay, lbProbe)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r.URL = c.originURL
 | 
				
			||||||
 | 
						// TODO: TUN-3406 support websocket, event stream and WSGI servers.
 | 
				
			||||||
 | 
						var resp *http.Response
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						if strings.ToLower(r.Header.Get("Cf-Int-Argo-Tunnel-Upgrade")) == "websocket" {
 | 
				
			||||||
 | 
							resp, err = serveWebsocket(newWebsocketBody(w, r, c.logger), r, c.config.HTTPHostHeader, c.config.ClientTlsConfig)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							resp, err = c.originClient.RoundTrip(r)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.writeErrorResponse(w, err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer resp.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						w.WriteHeader(resp.StatusCode)
 | 
				
			||||||
 | 
						_, err = io.Copy(w, resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							c.logger.Errorf("Copy response error, err: %v", err)
 | 
				
			||||||
 | 
							w.WriteHeader(http.StatusBadGateway)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *cfdServer) writeErrorResponse(w http.ResponseWriter, err error) {
 | 
				
			||||||
 | 
						c.logger.Errorf("HTTP request error: %s", err)
 | 
				
			||||||
 | 
						c.config.Metrics.incrementResponses(c.connectionIndex, "502")
 | 
				
			||||||
 | 
						jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							panic(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader))
 | 
				
			||||||
 | 
						w.WriteHeader(http.StatusBadGateway)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *cfdServer) logRequest(r *http.Request, cfRay string, lbProbe bool) {
 | 
				
			||||||
 | 
						logger := c.logger
 | 
				
			||||||
 | 
						if cfRay != "" {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
 | 
				
			||||||
 | 
						} else if lbProbe {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						logger.Infof("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if contentLen := r.ContentLength; contentLen == -1 {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *cfdServer) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
 | 
				
			||||||
 | 
						c.config.Metrics.incrementResponses(c.connectionIndex, "200")
 | 
				
			||||||
 | 
						logger := c.logger
 | 
				
			||||||
 | 
						if cfRay != "" {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
 | 
				
			||||||
 | 
						} else if lbProbe {
 | 
				
			||||||
 | 
							logger.Debugf("Response to Load Balancer health check %s", r.Status)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							logger.Infof("%s", r.Status)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if contentLen := r.ContentLength; contentLen == -1 {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type WebsocketResp interface {
 | 
				
			||||||
 | 
						WriteRespHeaders(*http.Response) error
 | 
				
			||||||
 | 
						io.ReadWriter
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type http2WebsocketResp struct {
 | 
				
			||||||
 | 
						pr *io.PipeReader
 | 
				
			||||||
 | 
						w  http.ResponseWriter
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newWebsocketBody(w http.ResponseWriter, r *http.Request, logger logger.Service) *http2WebsocketResp {
 | 
				
			||||||
 | 
						pr, pw := io.Pipe()
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							n, err := io.Copy(pw, r.Body)
 | 
				
			||||||
 | 
							logger.Errorf("websocket body copy ended, err: %v, bytes: %d", err, n)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						return &http2WebsocketResp{pr: pr, w: w}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (wr *http2WebsocketResp) WriteRespHeaders(resp *http.Response) error {
 | 
				
			||||||
 | 
						dest := wr.w.Header()
 | 
				
			||||||
 | 
						for name, values := range resp.Header {
 | 
				
			||||||
 | 
							for _, v := range values {
 | 
				
			||||||
 | 
								dest.Add(name, v)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (wr *http2WebsocketResp) Read(p []byte) (n int, err error) {
 | 
				
			||||||
 | 
						return wr.pr.Read(p)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (wr *http2WebsocketResp) Write(p []byte) (n int, err error) {
 | 
				
			||||||
 | 
						return wr.w.Write(p)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type h2muxWebsocketResp struct {
 | 
				
			||||||
 | 
						*h2mux.MuxedStream
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error {
 | 
				
			||||||
 | 
						return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										120
									
								
								origin/tunnel.go
								
								
								
								
							
							
						
						
									
										120
									
								
								origin/tunnel.go
								
								
								
								
							| 
						 | 
					@ -17,7 +17,9 @@ import (
 | 
				
			||||||
	"github.com/google/uuid"
 | 
						"github.com/google/uuid"
 | 
				
			||||||
	"github.com/pkg/errors"
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
	"github.com/prometheus/client_golang/prometheus"
 | 
						"github.com/prometheus/client_golang/prometheus"
 | 
				
			||||||
 | 
						"golang.org/x/net/http2"
 | 
				
			||||||
	"golang.org/x/sync/errgroup"
 | 
						"golang.org/x/sync/errgroup"
 | 
				
			||||||
 | 
						"zombiezen.com/go/capnproto2/rpc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/buffer"
 | 
						"github.com/cloudflare/cloudflared/buffer"
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
 | 
						"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
 | 
				
			||||||
| 
						 | 
					@ -30,6 +32,7 @@ import (
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
						"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
						"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
						tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/validation"
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/websocket"
 | 
						"github.com/cloudflare/cloudflared/websocket"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -304,7 +307,14 @@ func ServeTunnel(
 | 
				
			||||||
	connectionTag := uint8ToString(connectionIndex)
 | 
						connectionTag := uint8ToString(connectionIndex)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol {
 | 
						if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol {
 | 
				
			||||||
		return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh)
 | 
							tlsConn, err := RegisterConnection(ctx, config, connectionIndex, uint8(backoff.retries), addr)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logger.Errorf("Register connectio error: %+v", err)
 | 
				
			||||||
 | 
								return err, true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							connectedFuse.Fuse(true)
 | 
				
			||||||
 | 
							backoff.SetGracePeriod()
 | 
				
			||||||
 | 
							return serveNamedTunnel(ctx, config, tlsConn, connectionIndex, reconnectCh)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Returns error from parsing the origin URL or handshake errors
 | 
						// Returns error from parsing the origin URL or handshake errors
 | 
				
			||||||
| 
						 | 
					@ -332,10 +342,6 @@ func ServeTunnel(
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if config.NamedTunnel != nil {
 | 
					 | 
				
			||||||
			return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr, uint8(backoff.retries))
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if config.UseReconnectToken && connectedFuse.Value() {
 | 
							if config.UseReconnectToken && connectedFuse.Value() {
 | 
				
			||||||
			err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
 | 
								err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
 | 
				
			||||||
			if err == nil {
 | 
								if err == nil {
 | 
				
			||||||
| 
						 | 
					@ -426,7 +432,55 @@ func ServeTunnel(
 | 
				
			||||||
	return nil, true
 | 
						return nil, true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func RegisterConnection(
 | 
					func serveNamedTunnel(
 | 
				
			||||||
 | 
						ctx context.Context,
 | 
				
			||||||
 | 
						config *TunnelConfig,
 | 
				
			||||||
 | 
						tlsConn net.Conn,
 | 
				
			||||||
 | 
						connectionIndex uint8,
 | 
				
			||||||
 | 
						reconnectCh chan ReconnectSignal,
 | 
				
			||||||
 | 
					) (err error, recoverable bool) {
 | 
				
			||||||
 | 
						originURLStr, err := validation.ValidateUrl(config.OriginUrl)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("unable to parse origin URL %#v", config.OriginUrl), false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						originURL, err := url.Parse(originURLStr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("unable to parse origin URL %#v", originURLStr), false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						originClient := config.HTTPTransport
 | 
				
			||||||
 | 
						if originClient == nil {
 | 
				
			||||||
 | 
							originClient = http.DefaultTransport
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						errGroup, serveCtx := errgroup.WithContext(ctx)
 | 
				
			||||||
 | 
						errGroup.Go(func() error {
 | 
				
			||||||
 | 
							cfdServer := &cfdServer{
 | 
				
			||||||
 | 
								httpServer:      &http2.Server{},
 | 
				
			||||||
 | 
								originClient:    originClient,
 | 
				
			||||||
 | 
								logger:          config.Logger,
 | 
				
			||||||
 | 
								originURL:       originURL,
 | 
				
			||||||
 | 
								connectionIndex: uint8ToString(connectionIndex),
 | 
				
			||||||
 | 
								config:          config,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							cfdServer.serve(serveCtx, tlsConn)
 | 
				
			||||||
 | 
							return fmt.Errorf("Connection with edge closed")
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						errGroup.Go(func() error {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case reconnect := <-reconnectCh:
 | 
				
			||||||
 | 
								return &reconnect
 | 
				
			||||||
 | 
							case <-serveCtx.Done():
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = errGroup.Wait()
 | 
				
			||||||
 | 
						return err, true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func RegisterConnectionWithH2Mux(
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	muxer *h2mux.Muxer,
 | 
						muxer *h2mux.Muxer,
 | 
				
			||||||
	config *TunnelConfig,
 | 
						config *TunnelConfig,
 | 
				
			||||||
| 
						 | 
					@ -470,6 +524,52 @@ func RegisterConnection(
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func RegisterConnection(
 | 
				
			||||||
 | 
						ctx context.Context,
 | 
				
			||||||
 | 
						config *TunnelConfig,
 | 
				
			||||||
 | 
						connectionID uint8,
 | 
				
			||||||
 | 
						numPreviousAttempts uint8,
 | 
				
			||||||
 | 
						addr *net.TCPAddr,
 | 
				
			||||||
 | 
					) (net.Conn, error) {
 | 
				
			||||||
 | 
						originCert, err := tls.X509KeyPair(config.OriginCert, config.OriginCert)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tlsConfig := config.TlsConfig
 | 
				
			||||||
 | 
						tlsConfig.Certificates = []tls.Certificate{originCert}
 | 
				
			||||||
 | 
						tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rpcTransport := tunnelrpc.NewTransportLogger(config.Logger, rpc.StreamTransport(&persistentConn{tlsServerConn}))
 | 
				
			||||||
 | 
						rpcConn := rpc.NewConn(
 | 
				
			||||||
 | 
							rpcTransport,
 | 
				
			||||||
 | 
							tunnelrpc.ConnLog(config.Logger),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx), Conn: rpcConn}
 | 
				
			||||||
 | 
						connDetail, err := rpcClient.RegisterConnection(
 | 
				
			||||||
 | 
							ctx,
 | 
				
			||||||
 | 
							config.NamedTunnel.Auth,
 | 
				
			||||||
 | 
							config.NamedTunnel.ID,
 | 
				
			||||||
 | 
							connectionID,
 | 
				
			||||||
 | 
							config.ConnectionOptions(tlsServerConn.LocalAddr().String(), numPreviousAttempts),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						config.Logger.Infof("Connection %d registered with %s using ID %s", connectionID, connDetail.Location, connDetail.UUID)
 | 
				
			||||||
 | 
						rpcTransport.Close()
 | 
				
			||||||
 | 
						// Closing the client will also close the connection
 | 
				
			||||||
 | 
						rpcClient.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						flushMessage := make([]byte, 8)
 | 
				
			||||||
 | 
						buf := make([]byte, len(flushMessage))
 | 
				
			||||||
 | 
						tlsServerConn.Write(buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return tlsServerConn, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
 | 
					func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
 | 
				
			||||||
	if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
 | 
						if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
 | 
				
			||||||
		return &serverRegisterTunnelError{
 | 
							return &serverRegisterTunnelError{
 | 
				
			||||||
| 
						 | 
					@ -698,7 +798,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 | 
				
			||||||
	var resp *http.Response
 | 
						var resp *http.Response
 | 
				
			||||||
	var respErr error
 | 
						var respErr error
 | 
				
			||||||
	if websocket.IsWebSocketUpgrade(req) {
 | 
						if websocket.IsWebSocketUpgrade(req) {
 | 
				
			||||||
		resp, respErr = h.serveWebsocket(stream, req, rule)
 | 
							resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		resp, respErr = h.serveHTTP(stream, req, rule)
 | 
							resp, respErr = h.serveHTTP(stream, req, rule)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -725,7 +825,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request,
 | 
				
			||||||
	return req, rule, nil
 | 
						return req, rule, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
 | 
					func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
 | 
				
			||||||
	if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
 | 
						if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
 | 
				
			||||||
		req.Header.Set("Host", hostHeader)
 | 
							req.Header.Set("Host", hostHeader)
 | 
				
			||||||
		req.Host = hostHeader
 | 
							req.Host = hostHeader
 | 
				
			||||||
| 
						 | 
					@ -740,13 +840,13 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer conn.Close()
 | 
						defer conn.Close()
 | 
				
			||||||
	err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response))
 | 
						err = wsResp.WriteRespHeaders(response)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, errors.Wrap(err, "Error writing response header")
 | 
							return nil, errors.Wrap(err, "Error writing response header")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// Copy to/from stream to the undelying connection. Use the underlying
 | 
						// Copy to/from stream to the undelying connection. Use the underlying
 | 
				
			||||||
	// connection because cloudflared doesn't operate on the message themselves
 | 
						// connection because cloudflared doesn't operate on the message themselves
 | 
				
			||||||
	websocket.Stream(conn.UnderlyingConn(), stream)
 | 
						websocket.Stream(conn.UnderlyingConn(), wsResp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return response, nil
 | 
						return response, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,7 +18,10 @@ const (
 | 
				
			||||||
	OriginCAPoolFlag = "origin-ca-pool"
 | 
						OriginCAPoolFlag = "origin-ca-pool"
 | 
				
			||||||
	CaCertFlag       = "cacert"
 | 
						CaCertFlag       = "cacert"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	edgeTLSServerName = "cftunnel.com"
 | 
						// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
 | 
				
			||||||
 | 
						edgeH2muxTLSServerName = "cftunnel.com"
 | 
				
			||||||
 | 
						// edgeH2TLSServerName is the server name to establish http2 connection with edge
 | 
				
			||||||
 | 
						edgeH2TLSServerName = "h2.cftunnel.com"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CertReloader can load and reload a TLS certificate from a particular filepath.
 | 
					// CertReloader can load and reload a TLS certificate from a particular filepath.
 | 
				
			||||||
| 
						 | 
					@ -120,13 +123,17 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
 | 
				
			||||||
	return certPool, nil
 | 
						return certPool, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) {
 | 
					func CreateTunnelConfig(c *cli.Context, isNamedTunnel bool) (*tls.Config, error) {
 | 
				
			||||||
	var rootCAs []string
 | 
						var rootCAs []string
 | 
				
			||||||
	if c.String(CaCertFlag) != "" {
 | 
						if c.String(CaCertFlag) != "" {
 | 
				
			||||||
		rootCAs = append(rootCAs, c.String(CaCertFlag))
 | 
							rootCAs = append(rootCAs, c.String(CaCertFlag))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: edgeTLSServerName}
 | 
						serverName := edgeH2muxTLSServerName
 | 
				
			||||||
 | 
						if isNamedTunnel {
 | 
				
			||||||
 | 
							serverName = edgeH2TLSServerName
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
 | 
				
			||||||
	tlsConfig, err := GetConfig(userConfig)
 | 
						tlsConfig, err := GetConfig(userConfig)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -176,8 +176,12 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	initialRequest.Host = hostname
 | 
						initialRequest.Host = hostname
 | 
				
			||||||
	_, initialErr := client.Do(initialRequest)
 | 
						resp, initialErr := client.Do(initialRequest)
 | 
				
			||||||
	if initialErr != nil {
 | 
						if initialErr == nil {
 | 
				
			||||||
 | 
							resp.Body.Close()
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
 | 
						// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
 | 
				
			||||||
	oldScheme := parsedURL.Scheme
 | 
						oldScheme := parsedURL.Scheme
 | 
				
			||||||
	parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
 | 
						parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
 | 
				
			||||||
| 
						 | 
					@ -187,8 +191,9 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	secondRequest.Host = hostname
 | 
						secondRequest.Host = hostname
 | 
				
			||||||
		_, secondErr := client.Do(secondRequest)
 | 
						resp, secondErr := client.Do(secondRequest)
 | 
				
			||||||
	if secondErr == nil { // Worked this time--advise the user to switch protocols
 | 
						if secondErr == nil { // Worked this time--advise the user to switch protocols
 | 
				
			||||||
 | 
							resp.Body.Close()
 | 
				
			||||||
		return errors.Errorf(
 | 
							return errors.Errorf(
 | 
				
			||||||
			"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
 | 
								"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
 | 
				
			||||||
			parsedURL.Host,
 | 
								parsedURL.Host,
 | 
				
			||||||
| 
						 | 
					@ -198,7 +203,6 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
 | 
				
			||||||
			parsedURL,
 | 
								parsedURL,
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return initialErr
 | 
						return initialErr
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue