TUN-3400: Use Go HTTP2 library as transport to connect with the edge

This commit is contained in:
cthuang 2020-09-11 23:02:34 +01:00
parent d7498b0c03
commit 8d7b2575ba
7 changed files with 324 additions and 37 deletions

View File

@ -156,7 +156,8 @@ func prepareTunnelConfig(
transportLogger logger.Service, transportLogger logger.Service,
namedTunnel *origin.NamedTunnelConfig, namedTunnel *origin.NamedTunnelConfig,
) (*origin.TunnelConfig, error) { ) (*origin.TunnelConfig, error) {
compatibilityMode := namedTunnel == nil isNamedTunnel := namedTunnel != nil
compatibilityMode := !isNamedTunnel
hostname, err := validation.ValidateHostname(c.String("hostname")) hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil { if err != nil {
@ -219,7 +220,7 @@ func prepareTunnelConfig(
} }
} }
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c) toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, isNamedTunnel)
if err != nil { if err != nil {
logger.Errorf("unable to create TLS config to connect with edge: %s", err) logger.Errorf("unable to create TLS config to connect with edge: %s", err)
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")

View File

@ -9,7 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// DialEdge makes a TLS connection to a Cloudflare edge node // DialEdgeWithH2Mux makes a TLS connection to a Cloudflare edge node
func DialEdge( func DialEdge(
ctx context.Context, ctx context.Context,
timeout time.Duration, timeout time.Duration,
@ -25,6 +25,7 @@ func DialEdge(
if err != nil { if err != nil {
return nil, newDialError(err, "DialContext error") return nil, newDialError(err, "DialContext error")
} }
tlsEdgeConn := tls.Client(edgeConn, tlsConfig) tlsEdgeConn := tls.Client(edgeConn, tlsConfig)
tlsEdgeConn.SetDeadline(time.Now().Add(timeout)) tlsEdgeConn.SetDeadline(time.Now().Add(timeout))

14
origin/connection.go Normal file
View File

@ -0,0 +1,14 @@
package origin
import (
"net"
)
// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called
type persistentConn struct {
net.Conn
}
func (pc *persistentConn) Close() error {
return nil
}

160
origin/server.go Normal file
View File

@ -0,0 +1,160 @@
package origin
import (
"context"
"encoding/json"
"io"
"net"
"net/http"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
"golang.org/x/net/http2"
)
type cfdServer struct {
httpServer *http2.Server
originClient http.RoundTripper
logger logger.Service
originURL *url.URL
connectionIndex string
config *TunnelConfig
}
func (c *cfdServer) serve(ctx context.Context, conn net.Conn) {
go func() {
<-ctx.Done()
conn.Close()
}()
c.httpServer.ServeConn(conn, &http2.ServeConnOpts{
Context: ctx,
Handler: c,
})
}
func (c *cfdServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.config.Metrics.incrementRequests(c.connectionIndex)
defer c.config.Metrics.decrementConcurrentRequests(c.connectionIndex)
cfRay := findCfRayHeader(r)
lbProbe := isLBProbeRequest(r)
c.logRequest(r, cfRay, lbProbe)
r.URL = c.originURL
// TODO: TUN-3406 support websocket, event stream and WSGI servers.
var resp *http.Response
var err error
if strings.ToLower(r.Header.Get("Cf-Int-Argo-Tunnel-Upgrade")) == "websocket" {
resp, err = serveWebsocket(newWebsocketBody(w, r, c.logger), r, c.config.HTTPHostHeader, c.config.ClientTlsConfig)
} else {
resp, err = c.originClient.RoundTrip(r)
}
if err != nil {
c.writeErrorResponse(w, err)
return
}
defer resp.Body.Close()
w.WriteHeader(resp.StatusCode)
_, err = io.Copy(w, resp.Body)
if err != nil {
c.logger.Errorf("Copy response error, err: %v", err)
w.WriteHeader(http.StatusBadGateway)
return
}
}
func (c *cfdServer) writeErrorResponse(w http.ResponseWriter, err error) {
c.logger.Errorf("HTTP request error: %s", err)
c.config.Metrics.incrementResponses(c.connectionIndex, "502")
jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared})
if err != nil {
panic(err)
}
w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader))
w.WriteHeader(http.StatusBadGateway)
}
func (c *cfdServer) logRequest(r *http.Request, cfRay string, lbProbe bool) {
logger := c.logger
if cfRay != "" {
logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else if lbProbe {
logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else {
logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto)
}
logger.Infof("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
} else {
logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
}
}
func (c *cfdServer) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
c.config.Metrics.incrementResponses(c.connectionIndex, "200")
logger := c.logger
if cfRay != "" {
logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
} else if lbProbe {
logger.Debugf("Response to Load Balancer health check %s", r.Status)
} else {
logger.Infof("%s", r.Status)
}
logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
} else {
logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
}
}
type WebsocketResp interface {
WriteRespHeaders(*http.Response) error
io.ReadWriter
}
type http2WebsocketResp struct {
pr *io.PipeReader
w http.ResponseWriter
}
func newWebsocketBody(w http.ResponseWriter, r *http.Request, logger logger.Service) *http2WebsocketResp {
pr, pw := io.Pipe()
go func() {
n, err := io.Copy(pw, r.Body)
logger.Errorf("websocket body copy ended, err: %v, bytes: %d", err, n)
}()
return &http2WebsocketResp{pr: pr, w: w}
}
func (wr *http2WebsocketResp) WriteRespHeaders(resp *http.Response) error {
dest := wr.w.Header()
for name, values := range resp.Header {
for _, v := range values {
dest.Add(name, v)
}
}
return nil
}
func (wr *http2WebsocketResp) Read(p []byte) (n int, err error) {
return wr.pr.Read(p)
}
func (wr *http2WebsocketResp) Write(p []byte) (n int, err error) {
return wr.w.Write(p)
}
type h2muxWebsocketResp struct {
*h2mux.MuxedStream
}
func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error {
return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp))
}

View File

@ -17,7 +17,9 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/http2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
@ -30,6 +32,7 @@ import (
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
) )
@ -304,7 +307,14 @@ func ServeTunnel(
connectionTag := uint8ToString(connectionIndex) connectionTag := uint8ToString(connectionIndex)
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol { if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol {
return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh) tlsConn, err := RegisterConnection(ctx, config, connectionIndex, uint8(backoff.retries), addr)
if err != nil {
logger.Errorf("Register connectio error: %+v", err)
return err, true
}
connectedFuse.Fuse(true)
backoff.SetGracePeriod()
return serveNamedTunnel(ctx, config, tlsConn, connectionIndex, reconnectCh)
} }
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
@ -332,10 +342,6 @@ func ServeTunnel(
} }
}() }()
if config.NamedTunnel != nil {
return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr, uint8(backoff.retries))
}
if config.UseReconnectToken && connectedFuse.Value() { if config.UseReconnectToken && connectedFuse.Value() {
err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager) err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
if err == nil { if err == nil {
@ -426,7 +432,55 @@ func ServeTunnel(
return nil, true return nil, true
} }
func RegisterConnection( func serveNamedTunnel(
ctx context.Context,
config *TunnelConfig,
tlsConn net.Conn,
connectionIndex uint8,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
originURLStr, err := validation.ValidateUrl(config.OriginUrl)
if err != nil {
return fmt.Errorf("unable to parse origin URL %#v", config.OriginUrl), false
}
originURL, err := url.Parse(originURLStr)
if err != nil {
return fmt.Errorf("unable to parse origin URL %#v", originURLStr), false
}
originClient := config.HTTPTransport
if originClient == nil {
originClient = http.DefaultTransport
}
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
cfdServer := &cfdServer{
httpServer: &http2.Server{},
originClient: originClient,
logger: config.Logger,
originURL: originURL,
connectionIndex: uint8ToString(connectionIndex),
config: config,
}
cfdServer.serve(serveCtx, tlsConn)
return fmt.Errorf("Connection with edge closed")
})
errGroup.Go(func() error {
select {
case reconnect := <-reconnectCh:
return &reconnect
case <-serveCtx.Done():
return nil
}
})
err = errGroup.Wait()
return err, true
}
func RegisterConnectionWithH2Mux(
ctx context.Context, ctx context.Context,
muxer *h2mux.Muxer, muxer *h2mux.Muxer,
config *TunnelConfig, config *TunnelConfig,
@ -470,6 +524,52 @@ func RegisterConnection(
return nil return nil
} }
func RegisterConnection(
ctx context.Context,
config *TunnelConfig,
connectionID uint8,
numPreviousAttempts uint8,
addr *net.TCPAddr,
) (net.Conn, error) {
originCert, err := tls.X509KeyPair(config.OriginCert, config.OriginCert)
if err != nil {
return nil, err
}
tlsConfig := config.TlsConfig
tlsConfig.Certificates = []tls.Certificate{originCert}
tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
if err != nil {
return nil, err
}
rpcTransport := tunnelrpc.NewTransportLogger(config.Logger, rpc.StreamTransport(&persistentConn{tlsServerConn}))
rpcConn := rpc.NewConn(
rpcTransport,
tunnelrpc.ConnLog(config.Logger),
)
rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx), Conn: rpcConn}
connDetail, err := rpcClient.RegisterConnection(
ctx,
config.NamedTunnel.Auth,
config.NamedTunnel.ID,
connectionID,
config.ConnectionOptions(tlsServerConn.LocalAddr().String(), numPreviousAttempts),
)
if err != nil {
return nil, err
}
config.Logger.Infof("Connection %d registered with %s using ID %s", connectionID, connDetail.Location, connDetail.UUID)
rpcTransport.Close()
// Closing the client will also close the connection
rpcClient.Close()
flushMessage := make([]byte, 8)
buf := make([]byte, len(flushMessage))
tlsServerConn.Write(buf)
return tlsServerConn, nil
}
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError { func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
if retryable, ok := err.(*tunnelpogs.RetryableError); ok { if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
return &serverRegisterTunnelError{ return &serverRegisterTunnelError{
@ -698,7 +798,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
var resp *http.Response var resp *http.Response
var respErr error var respErr error
if websocket.IsWebSocketUpgrade(req) { if websocket.IsWebSocketUpgrade(req) {
resp, respErr = h.serveWebsocket(stream, req, rule) resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule)
} else { } else {
resp, respErr = h.serveHTTP(stream, req, rule) resp, respErr = h.serveHTTP(stream, req, rule)
} }
@ -725,7 +825,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request,
return req, rule, nil return req, rule, nil
} }
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", hostHeader) req.Header.Set("Host", hostHeader)
req.Host = hostHeader req.Host = hostHeader
@ -740,13 +840,13 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ
return nil, err return nil, err
} }
defer conn.Close() defer conn.Close()
err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response)) err = wsResp.WriteRespHeaders(response)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error writing response header") return nil, errors.Wrap(err, "Error writing response header")
} }
// Copy to/from stream to the undelying connection. Use the underlying // Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves // connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), stream) websocket.Stream(conn.UnderlyingConn(), wsResp)
return response, nil return response, nil
} }

View File

@ -18,7 +18,10 @@ const (
OriginCAPoolFlag = "origin-ca-pool" OriginCAPoolFlag = "origin-ca-pool"
CaCertFlag = "cacert" CaCertFlag = "cacert"
edgeTLSServerName = "cftunnel.com" // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
) )
// CertReloader can load and reload a TLS certificate from a particular filepath. // CertReloader can load and reload a TLS certificate from a particular filepath.
@ -120,13 +123,17 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
return certPool, nil return certPool, nil
} }
func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) { func CreateTunnelConfig(c *cli.Context, isNamedTunnel bool) (*tls.Config, error) {
var rootCAs []string var rootCAs []string
if c.String(CaCertFlag) != "" { if c.String(CaCertFlag) != "" {
rootCAs = append(rootCAs, c.String(CaCertFlag)) rootCAs = append(rootCAs, c.String(CaCertFlag))
} }
userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: edgeTLSServerName} serverName := edgeH2muxTLSServerName
if isNamedTunnel {
serverName = edgeH2TLSServerName
}
userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
tlsConfig, err := GetConfig(userConfig) tlsConfig, err := GetConfig(userConfig)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -176,8 +176,12 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
return err return err
} }
initialRequest.Host = hostname initialRequest.Host = hostname
_, initialErr := client.Do(initialRequest) resp, initialErr := client.Do(initialRequest)
if initialErr != nil { if initialErr == nil {
resp.Body.Close()
return nil
}
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
oldScheme := parsedURL.Scheme oldScheme := parsedURL.Scheme
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
@ -187,8 +191,9 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
return err return err
} }
secondRequest.Host = hostname secondRequest.Host = hostname
_, secondErr := client.Do(secondRequest) resp, secondErr := client.Do(secondRequest)
if secondErr == nil { // Worked this time--advise the user to switch protocols if secondErr == nil { // Worked this time--advise the user to switch protocols
resp.Body.Close()
return errors.Errorf( return errors.Errorf(
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s", "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
parsedURL.Host, parsedURL.Host,
@ -198,7 +203,6 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
parsedURL, parsedURL,
) )
} }
}
return initialErr return initialErr
} }