TUN-3500: Integrate replace h2mux by http2 work with multiple origin support

This commit is contained in:
cthuang 2020-11-02 11:21:34 +00:00
parent eef5b78eac
commit 5974fb4cfd
16 changed files with 252 additions and 716 deletions

View File

@ -190,9 +190,9 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
// ValidateUrl will validate url flag correctness. It can be either from --url or argument // ValidateUrl will validate url flag correctness. It can be either from --url or argument
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument // Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
func ValidateUrl(c *cli.Context, allowFromArgs bool) (*url.URL, error) { func ValidateUrl(c *cli.Context, allowURLFromArgs bool) (*url.URL, error) {
var url = c.String("url") var url = c.String("url")
if allowFromArgs && c.NArg() > 0 { if allowURLFromArgs && c.NArg() > 0 {
if c.IsSet("url") { if c.IsSet("url") {
return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
} }

View File

@ -367,12 +367,12 @@ func StartServer(
return errors.Wrap(err, "error setting up transport logger") return errors.Wrap(err, "error setting up transport logger")
} }
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled) tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled)
if err != nil { if err != nil {
return err return err
} }
tunnelConfig.IngressRules.StartOrigins(&wg, log, shutdownC, errC) ingressRules.StartOrigins(&wg, log, shutdownC, errC)
reconnectCh := make(chan origin.ReconnectSignal, 1) reconnectCh := make(chan origin.ReconnectSignal, 1)
if c.IsSet("stdin-control") { if c.IsSet("stdin-control") {
@ -391,8 +391,7 @@ func StartServer(
version, version,
hostname, hostname,
metricsListener.Addr().String(), metricsListener.Addr().String(),
// TODO (TUN-3461): Update UI to show multiple origin URLs &ingressRules,
&tunnelConfig.IngressRules,
tunnelConfig.HAConnections, tunnelConfig.HAConnections,
) )
logLevels, err := logger.ParseLevelString(c.String("loglevel")) logLevels, err := logger.ParseLevelString(c.String("loglevel"))

View File

@ -1,6 +1,7 @@
package tunnel package tunnel
import ( import (
"crypto/tls"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -160,27 +161,27 @@ func prepareTunnelConfig(
transportLogger logger.Service, transportLogger logger.Service,
namedTunnel *connection.NamedTunnelConfig, namedTunnel *connection.NamedTunnelConfig,
uiIsEnabled bool, uiIsEnabled bool,
) (*origin.TunnelConfig, error) { ) (*origin.TunnelConfig, ingress.Ingress, error) {
isNamedTunnel := namedTunnel != nil isNamedTunnel := namedTunnel != nil
hostname, err := validation.ValidateHostname(c.String("hostname")) hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil { if err != nil {
logger.Errorf("Invalid hostname: %s", err) logger.Errorf("Invalid hostname: %s", err)
return nil, errors.Wrap(err, "Invalid hostname") return nil, ingress.Ingress{}, errors.Wrap(err, "Invalid hostname")
} }
isFreeTunnel := hostname == "" isFreeTunnel := hostname == ""
clientID := c.String("id") clientID := c.String("id")
if !c.IsSet("id") { if !c.IsSet("id") {
clientID, err = generateRandomClientID(logger) clientID, err = generateRandomClientID(logger)
if err != nil { if err != nil {
return nil, err return nil, ingress.Ingress{}, err
} }
} }
tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil { if err != nil {
logger.Errorf("Tag parse failure: %s", err) logger.Errorf("Tag parse failure: %s", err)
return nil, errors.Wrap(err, "Tag parse failure") return nil, ingress.Ingress{}, errors.Wrap(err, "Tag parse failure")
} }
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
@ -189,7 +190,7 @@ func prepareTunnelConfig(
if !isFreeTunnel { if !isFreeTunnel {
originCert, err = getOriginCert(c, logger) originCert, err = getOriginCert(c, logger)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error getting origin cert") return nil, ingress.Ingress{}, errors.Wrap(err, "Error getting origin cert")
} }
} }
@ -200,7 +201,7 @@ func prepareTunnelConfig(
if isNamedTunnel { if isNamedTunnel {
clientUUID, err := uuid.NewRandom() clientUUID, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "can't generate clientUUID") return nil, ingress.Ingress{}, errors.Wrap(err, "can't generate clientUUID")
} }
namedTunnel.Client = tunnelpogs.ClientInfo{ namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientUUID[:], ClientID: clientUUID[:],
@ -210,10 +211,10 @@ func prepareTunnelConfig(
} }
ingressRules, err = ingress.ParseIngress(config.GetConfiguration()) ingressRules, err = ingress.ParseIngress(config.GetConfiguration())
if err != nil && err != ingress.ErrNoIngressRules { if err != nil && err != ingress.ErrNoIngressRules {
return nil, err return nil, ingress.Ingress{}, err
} }
if !ingressRules.IsEmpty() && c.IsSet("url") { if !ingressRules.IsEmpty() && c.IsSet("url") {
return nil, ingress.ErrURLIncompatibleWithIngress return nil, ingress.Ingress{}, ingress.ErrURLIncompatibleWithIngress
} }
} else { } else {
classicTunnel = &connection.ClassicTunnelConfig{ classicTunnel = &connection.ClassicTunnelConfig{
@ -226,15 +227,15 @@ func prepareTunnelConfig(
// Convert single-origin configuration into multi-origin configuration. // Convert single-origin configuration into multi-origin configuration.
if ingressRules.IsEmpty() { if ingressRules.IsEmpty() {
ingressRules, err = ingress.NewSingleOrigin(c, compatibilityMode, logger) ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel, logger)
if err != nil { if err != nil {
return nil, err return nil, ingress.Ingress{}, err
} }
} }
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, logger) protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, logger)
if err != nil { if err != nil {
return nil, err return nil, ingress.Ingress{}, err
} }
logger.Infof("Initial protocol %s", protocolSelector.Current()) logger.Infof("Initial protocol %s", protocolSelector.Current())
@ -242,20 +243,12 @@ func prepareTunnelConfig(
for _, p := range connection.ProtocolList { for _, p := range connection.ProtocolList {
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName()) edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName())
if err != nil { if err != nil {
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge")
} }
edgeTLSConfigs[p] = edgeTLSConfig edgeTLSConfigs[p] = edgeTLSConfig
} }
proxyConfig := &origin.ProxyConfig{ originClient := origin.NewClient(ingressRules, tags, logger)
Client: httpTransport,
URL: originURL,
TLSConfig: httpTransport.TLSClientConfig,
HostHeader: c.String("http-host-header"),
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
Tags: tags,
}
originClient := origin.NewClient(proxyConfig, logger)
connectionConfig := &connection.Config{ connectionConfig := &connection.Config{
OriginClient: originClient, OriginClient: originClient,
GracePeriod: c.Duration("grace-period"), GracePeriod: c.Duration("grace-period"),
@ -275,7 +268,6 @@ func prepareTunnelConfig(
return &origin.TunnelConfig{ return &origin.TunnelConfig{
ConnectionConfig: connectionConfig, ConnectionConfig: connectionConfig,
ProxyConfig: proxyConfig,
BuildInfo: buildInfo, BuildInfo: buildInfo,
ClientID: clientID, ClientID: clientID,
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
@ -284,6 +276,7 @@ func prepareTunnelConfig(
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel, IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"), LBPool: c.String("lb-pool"),
Tags: tags,
Logger: logger, Logger: logger,
Observer: connection.NewObserver(transportLogger, tunnelEventChan), Observer: connection.NewObserver(transportLogger, tunnelEventChan),
ReportedVersion: version, ReportedVersion: version,
@ -293,10 +286,9 @@ func prepareTunnelConfig(
ClassicTunnel: classicTunnel, ClassicTunnel: classicTunnel,
MuxerConfig: muxerConfig, MuxerConfig: muxerConfig,
TunnelEventChan: tunnelEventChan, TunnelEventChan: tunnelEventChan,
IngressRules: ingressRules,
ProtocolSelector: protocolSelector, ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs, EdgeTLSConfigs: edgeTLSConfigs,
}, nil }, ingressRules, nil
} }
func isRunningFromTerminal() bool { func isRunningFromTerminal() bool {

View File

@ -22,7 +22,6 @@ const (
type h2muxConnection struct { type h2muxConnection struct {
config *Config config *Config
muxerConfig *MuxerConfig muxerConfig *MuxerConfig
originURL string
muxer *h2mux.Muxer muxer *h2mux.Muxer
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
connIndexStr string connIndexStr string
@ -54,7 +53,6 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, logger logger.S
func NewH2muxConnection(ctx context.Context, func NewH2muxConnection(ctx context.Context,
config *Config, config *Config,
muxerConfig *MuxerConfig, muxerConfig *MuxerConfig,
originURL string,
edgeConn net.Conn, edgeConn net.Conn,
connIndex uint8, connIndex uint8,
observer *Observer, observer *Observer,
@ -62,7 +60,6 @@ func NewH2muxConnection(ctx context.Context,
h := &h2muxConnection{ h := &h2muxConnection{
config: config, config: config,
muxerConfig: muxerConfig, muxerConfig: muxerConfig,
originURL: originURL,
connIndexStr: uint8ToString(connIndex), connIndexStr: uint8ToString(connIndex),
connIndex: connIndex, connIndex: connIndex,
observer: observer, observer: observer,
@ -188,7 +185,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
} }
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) { func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
req, err := http.NewRequest("GET", h.originURL, h2mux.MuxedStreamReader{MuxedStream: stream}) req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
} }

View File

@ -7,7 +7,6 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
@ -31,7 +30,6 @@ type HTTP2Connection struct {
conn net.Conn conn net.Conn
server *http2.Server server *http2.Server
config *Config config *Config
originURL *url.URL
namedTunnel *NamedTunnelConfig namedTunnel *NamedTunnelConfig
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
observer *Observer observer *Observer
@ -44,7 +42,6 @@ type HTTP2Connection struct {
func NewHTTP2Connection( func NewHTTP2Connection(
conn net.Conn, conn net.Conn,
config *Config, config *Config,
originURL *url.URL,
namedTunnelConfig *NamedTunnelConfig, namedTunnelConfig *NamedTunnelConfig,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
observer *Observer, observer *Observer,
@ -57,7 +54,6 @@ func NewHTTP2Connection(
MaxConcurrentStreams: math.MaxUint32, MaxConcurrentStreams: math.MaxUint32,
}, },
config: config, config: config,
originURL: originURL,
namedTunnel: namedTunnelConfig, namedTunnel: namedTunnelConfig,
connOptions: connOptions, connOptions: connOptions,
observer: observer, observer: observer,
@ -83,9 +79,6 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.wg.Add(1) c.wg.Add(1)
defer c.wg.Done() defer c.wg.Done()
r.URL.Scheme = c.originURL.Scheme
r.URL.Host = c.originURL.Host
respWriter := &http2RespWriter{ respWriter := &http2RespWriter{
r: r.Body, r: r.Body,
w: w, w: w,

View File

@ -22,6 +22,7 @@ const (
UptimeRoute = "/uptime" UptimeRoute = "/uptime"
WSRoute = "/ws" WSRoute = "/ws"
SSERoute = "/sse" SSERoute = "/sse"
HealthRoute = "/_health"
defaultSSEFreq = time.Second * 10 defaultSSEFreq = time.Second * 10
) )
@ -114,6 +115,7 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now())) muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader)) muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
muxer.HandleFunc(SSERoute, sseHandler(logger)) muxer.HandleFunc(SSERoute, sseHandler(logger))
muxer.HandleFunc(HealthRoute, healthHandler())
muxer.HandleFunc("/", rootHandler(serverName)) muxer.HandleFunc("/", rootHandler(serverName))
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer} httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
go func() { go func() {
@ -221,6 +223,12 @@ func sseHandler(logger logger.Service) http.HandlerFunc {
} }
} }
func healthHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
}
}
func rootHandler(serverName string) http.HandlerFunc { func rootHandler(serverName string) http.HandlerFunc {
responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {

View File

@ -63,9 +63,9 @@ type Ingress struct {
// NewSingleOrigin constructs an Ingress set with only one rule, constructed from // NewSingleOrigin constructs an Ingress set with only one rule, constructed from
// legacy CLI parameters like --url or --no-chunked-encoding. // legacy CLI parameters like --url or --no-chunked-encoding.
func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Service) (Ingress, error) { func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool, logger logger.Service) (Ingress, error) {
service, err := parseSingleOriginService(c, compatibilityMode) service, err := parseSingleOriginService(c, allowURLFromArgs)
if err != nil { if err != nil {
return Ingress{}, err return Ingress{}, err
} }
@ -85,19 +85,15 @@ func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Servi
} }
// Get a single origin service from the CLI/config. // Get a single origin service from the CLI/config.
func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) { func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
if c.IsSet("hello-world") { if c.IsSet("hello-world") {
return new(helloWorld), nil return new(helloWorld), nil
} }
if c.IsSet("url") { if c.IsSet("url") {
originURLStr, err := config.ValidateUrl(c, compatibilityMode) originURL, err := config.ValidateUrl(c, allowURLFromArgs)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error validating origin URL") return nil, errors.Wrap(err, "Error validating origin URL")
} }
originURL, err := url.Parse(originURLStr)
if err != nil {
return nil, errors.Wrap(err, "couldn't parse origin URL")
}
return &localService{URL: originURL, RootURL: originURL}, nil return &localService{URL: originURL, RootURL: originURL}, nil
} }
if c.IsSet("unix-socket") { if c.IsSet("unix-socket") {

View File

@ -245,7 +245,7 @@ type statusCode struct {
func newStatusCode(status int) statusCode { func newStatusCode(status int) statusCode {
resp := &http.Response{ resp := &http.Response{
StatusCode: status, StatusCode: status,
Status: http.StatusText(status), Status: fmt.Sprintf("%d %s", status, http.StatusText(status)),
Body: new(NopReadCloser), Body: new(NopReadCloser),
} }
return statusCode{resp: resp} return statusCode{resp: resp}

View File

@ -1,229 +0,0 @@
package origin
import (
"bufio"
"context"
"io"
"net"
"net/http"
"strconv"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
type TunnelHandler struct {
ingressRules ingress.Ingress
muxer *h2mux.Muxer
tags []tunnelpogs.Tag
metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string
logger logger.Service
bufferPool *buffer.Pool
}
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewTunnelHandler(ctx context.Context,
config *TunnelConfig,
addr *net.TCPAddr,
connectionID uint8,
bufferPool *buffer.Pool,
) (*TunnelHandler, string, error) {
h := &TunnelHandler{
ingressRules: config.IngressRules,
tags: config.Tags,
metrics: config.Metrics,
connectionID: uint8ToString(connectionID),
logger: config.Logger,
bufferPool: bufferPool,
}
edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
if err != nil {
return nil, "", err
}
// Establish a muxed connection with the edge
// Client mux handshake with agent server
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
if err != nil {
return nil, "", errors.Wrap(err, "h2mux handshake with edge error")
}
return h, edgeConn.LocalAddr().String(), nil
}
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
for _, tag := range h.tags {
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
}
}
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
h.metrics.incrementRequests(h.connectionID)
defer h.metrics.decrementConcurrentRequests(h.connectionID)
req, rule, reqErr := h.createRequest(stream)
if reqErr != nil {
h.writeErrorResponse(stream, reqErr)
return reqErr
}
cfRay := findCfRayHeader(req)
lbProbe := isLBProbeRequest(req)
h.logRequest(req, cfRay, lbProbe)
var resp *http.Response
var respErr error
if websocket.IsWebSocketUpgrade(req) {
resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule)
} else {
resp, respErr = h.serveHTTP(stream, req, rule)
}
if respErr != nil {
h.writeErrorResponse(stream, respErr)
return respErr
}
h.logResponseOk(resp, cfRay, lbProbe)
return nil
}
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) {
req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
}
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
return nil, nil, errors.Wrap(err, "invalid request received")
}
rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
rule.Service.RewriteOriginURL(req.URL)
return req, rule, nil
}
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
req.ContentLength = int64(cLength)
}
}
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", hostHeader)
req.Host = hostHeader
}
response, err := h.httpClient.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin")
}
defer response.Body.Close()
headers := h2mux.H1ResponseToH2ResponseHeaders(response)
headers = append(headers, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceOrigin))
err = stream.WriteHeaders(headers)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
if h.isEventStream(response) {
h.writeEventStream(stream, response.Body)
} else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
buf := h.bufferPool.Get()
defer h.bufferPool.Put(buf)
io.CopyBuffer(stream, response.Body, buf)
}
return response, nil
}
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
reader := bufio.NewReader(responseBody)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
break
}
stream.Write(line)
}
}
func (h *TunnelHandler) isEventStream(response *http.Response) bool {
if response.Header.Get("content-type") == "text/event-stream" {
h.logger.Debug("Detected Server-Side Events from Origin")
return true
}
return false
}
func (h *TunnelHandler) writeErrorResponse(stream *h2mux.MuxedStream, err error) {
h.logger.Errorf("HTTP request error: %s", err)
stream.WriteHeaders([]h2mux.Header{
{Name: ":status", Value: "502"},
h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared),
})
stream.Write([]byte("502 Bad Gateway"))
h.metrics.incrementResponses(h.connectionID, "502")
}
func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) {
logger := h.logger
if cfRay != "" {
logger.Debugf("CF-RAY: %s %s %s %s", cfRay, req.Method, req.URL, req.Proto)
} else if lbProbe {
logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, req.Method, req.URL, req.Proto)
} else {
logger.Infof("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, req.Method, req.URL, req.Proto)
}
logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, req.Header)
if contentLen := req.ContentLength; contentLen == -1 {
logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
} else {
logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
}
}
func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
h.metrics.incrementResponses(h.connectionID, "200")
logger := h.logger
if cfRay != "" {
logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
} else if lbProbe {
logger.Debugf("Response to Load Balancer health check %s", r.Status)
} else {
logger.Infof("%s", r.Status)
}
logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
} else {
logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
}
}
func (h *TunnelHandler) UpdateMetrics(connectionID string) {
h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
}
type h2muxWebsocketResp struct {
*h2mux.MuxedStream
}
func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error {
return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp))
}

View File

@ -1,320 +0,0 @@
package origin
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/pkg/errors"
"golang.org/x/net/http2"
"zombiezen.com/go/capnproto2/rpc"
)
const (
internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
websocketUpgrade = "websocket"
controlPlaneUpgrade = "control-plane"
)
type http2Server struct {
server *http2.Server
ingressRules ingress.Ingress
logger logger.Service
connIndexStr string
connIndex uint8
config *TunnelConfig
localAddr net.Addr
shutdownChan chan struct{}
connectedFuse *h2mux.BooleanFuse
}
func newHTTP2Server(config *TunnelConfig, connIndex uint8, localAddr net.Addr, connectedFuse *h2mux.BooleanFuse) (*http2Server, error) {
return &http2Server{
server: &http2.Server{},
ingressRules: config.IngressRules,
logger: config.Logger,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
config: config,
localAddr: localAddr,
shutdownChan: make(chan struct{}),
connectedFuse: connectedFuse,
}, nil
}
func (c *http2Server) serve(ctx context.Context, conn net.Conn) {
go func() {
<-ctx.Done()
c.close(conn)
}()
c.server.ServeConn(conn, &http2.ServeConnOpts{
Context: ctx,
Handler: c,
})
}
func (c *http2Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.config.Metrics.incrementRequests(c.connIndexStr)
defer c.config.Metrics.decrementConcurrentRequests(c.connIndexStr)
cfRay := findCfRayHeader(r)
lbProbe := isLBProbeRequest(r)
c.logRequest(r, cfRay, lbProbe)
rule, _ := c.ingressRules.FindMatchingRule(r.Host, r.URL.Path)
rule.Service.RewriteOriginURL(r.URL)
var resp *http.Response
var err error
if isControlPlaneUpgrade(r) {
stripWebsocketUpgradeHeader(r)
err = c.serveControlPlane(w, r)
} else if isWebsocketUpgrade(r) {
stripWebsocketUpgradeHeader(r)
var respBody BidirectionalStream
respBody, err = newHTTP2Stream(w, r)
if err == nil {
resp, err = serveWebsocket(respBody, r, rule)
}
} else {
resp, err = c.serveHTTP(w, r, rule)
}
if err != nil {
c.writeErrorResponse(w, err)
return
}
if resp != nil {
resp.Body.Close()
}
}
func (c *http2Server) serveHTTP(w http.ResponseWriter, r *http.Request, rule *ingress.Rule) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if rule.Config.DisableChunkedEncoding {
r.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(r.Header.Get("Content-Length"))
if err == nil {
r.ContentLength = int64(cLength)
}
}
// Request origin to keep connection alive to improve performance
r.Header.Set("Connection", "keep-alive")
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
r.Header.Set("Host", hostHeader)
r.Host = hostHeader
}
resp, err := rule.HTTPTransport.RoundTrip(r)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin")
}
w.WriteHeader(resp.StatusCode)
_, err = io.Copy(w, resp.Body)
if err != nil {
return nil, errors.Wrap(err, "Copy response error")
}
return resp, nil
}
func (c *http2Server) serveControlPlane(w http.ResponseWriter, r *http.Request) error {
stream, err := newHTTP2Stream(w, r)
if err != nil {
return err
}
rpcTransport := tunnelrpc.NewTransportLogger(c.logger, rpc.StreamTransport(stream))
rpcConn := rpc.NewConn(
rpcTransport,
tunnelrpc.ConnLog(c.logger),
)
rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(r.Context()), Conn: rpcConn}
if err = c.registerConnection(r.Context(), rpcClient, 0); err != nil {
return err
}
c.connectedFuse.Fuse(true)
<-c.shutdownChan
c.gracefulShutdown(rpcClient)
// Closing the client will also close the connection
rpcClient.Close()
rpcTransport.Close()
close(c.shutdownChan)
return nil
}
func (c *http2Server) registerConnection(
ctx context.Context,
rpcClient tunnelpogs.TunnelServer_PogsClient,
numPreviousAttempts uint8,
) error {
connDetail, err := rpcClient.RegisterConnection(
ctx,
c.config.NamedTunnel.Auth,
c.config.NamedTunnel.ID,
c.connIndex,
c.config.ConnectionOptions(c.localAddr.String(), numPreviousAttempts),
)
if err != nil {
c.logger.Errorf("Cannot register connection, err: %v", err)
return err
}
c.logger.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID)
return nil
}
func (c *http2Server) gracefulShutdown(rpcClient tunnelpogs.TunnelServer_PogsClient) {
ctx, cancel := context.WithTimeout(context.Background(), c.config.GracePeriod)
defer cancel()
err := rpcClient.UnregisterConnection(ctx)
if err != nil {
c.logger.Errorf("Cannot unregister connection gracefully, err: %v", err)
return
}
c.logger.Info("Sent graceful shutdown signal")
<-ctx.Done()
}
func (c *http2Server) writeErrorResponse(w http.ResponseWriter, err error) {
c.logger.Errorf("HTTP request error: %s", err)
c.config.Metrics.incrementResponses(c.connIndexStr, "502")
jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared})
if err != nil {
panic(err)
}
w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader))
w.WriteHeader(http.StatusBadGateway)
}
func (c *http2Server) logRequest(r *http.Request, cfRay string, lbProbe bool) {
logger := c.logger
if cfRay != "" {
logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else if lbProbe {
logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else {
logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto)
}
logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
} else {
logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
}
}
func (c *http2Server) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
c.config.Metrics.incrementResponses(c.connIndexStr, "200")
logger := c.logger
if cfRay != "" {
logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
} else if lbProbe {
logger.Debugf("Response to Load Balancer health check %s", r.Status)
} else {
logger.Infof("%s", r.Status)
}
logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
} else {
logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
}
}
func (c *http2Server) close(conn net.Conn) {
// Send signal to control loop to start graceful shutdown
c.shutdownChan <- struct{}{}
// Wait for control loop to close channel
<-c.shutdownChan
conn.Close()
}
type http2Stream struct {
r io.Reader
w http.ResponseWriter
flusher http.Flusher
}
func newHTTP2Stream(w http.ResponseWriter, r *http.Request) (*http2Stream, error) {
flusher, ok := w.(http.Flusher)
if !ok {
return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher")
}
return &http2Stream{r: r.Body, w: w, flusher: flusher}, nil
}
func (wr *http2Stream) WriteRespHeaders(resp *http.Response) error {
dest := wr.w.Header()
userHeaders := make(http.Header, len(resp.Header))
for header, values := range resp.Header {
// Since these are http2 headers, they're required to be lowercase
h2name := strings.ToLower(header)
for _, v := range values {
if h2name == "content-length" {
// This header has meaning in HTTP/2 and will be used by the edge,
// so it should be sent as an HTTP/2 response header.
dest.Add(h2name, v)
// Since these are http2 headers, they're required to be lowercase
} else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
// User headers, on the other hand, must all be serialized so that
// HTTP/2 header validation won't be applied to HTTP/1 header values
userHeaders.Add(h2name, v)
}
}
}
// Perform user header serialization and set them in the single header
dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
wr.w.WriteHeader(http.StatusOK)
wr.flusher.Flush()
return nil
}
func (wr *http2Stream) Read(p []byte) (n int, err error) {
return wr.r.Read(p)
}
func (wr *http2Stream) Write(p []byte) (n int, err error) {
n, err = wr.w.Write(p)
if err != nil {
return 0, err
}
wr.flusher.Flush()
return
}
func (wr *http2Stream) Close() error {
return nil
}
func isControlPlaneUpgrade(r *http.Request) bool {
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlPlaneUpgrade
}
func isWebsocketUpgrade(r *http.Request) bool {
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade
}
func stripWebsocketUpgradeHeader(r *http.Request) {
r.Header.Del(internalUpgradeHeader)
}

View File

@ -34,6 +34,14 @@ var (
}, },
[]string{"status_code"}, []string{"status_code"},
) )
requestErrors = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: connection.MetricsNamespace,
Subsystem: connection.TunnelSubsystem,
Name: "request_errors",
Help: "Count of error proxying to origin",
},
)
haConnections = prometheus.NewGauge( haConnections = prometheus.NewGauge(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: connection.MetricsNamespace, Namespace: connection.MetricsNamespace,
@ -49,6 +57,7 @@ func init() {
totalRequests, totalRequests,
concurrentRequests, concurrentRequests,
responseByCode, responseByCode,
requestErrors,
haConnections, haConnections,
) )
} }

View File

@ -3,15 +3,15 @@ package origin
import ( import (
"bufio" "bufio"
"context" "context"
"crypto/tls" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
@ -23,28 +23,21 @@ const (
) )
type client struct { type client struct {
config *ProxyConfig ingressRules ingress.Ingress
logger logger.Service tags []tunnelpogs.Tag
bufferPool *buffer.Pool logger logger.Service
bufferPool *buffer.Pool
} }
func NewClient(config *ProxyConfig, logger logger.Service) connection.OriginClient { func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, logger logger.Service) connection.OriginClient {
return &client{ return &client{
config: config, ingressRules: ingressRules,
logger: logger, tags: tags,
bufferPool: buffer.NewPool(512 * 1024), logger: logger,
bufferPool: buffer.NewPool(512 * 1024),
} }
} }
type ProxyConfig struct {
Client http.RoundTripper
URL *url.URL
TLSConfig *tls.Config
HostHeader string
NoChunkedEncoding bool
Tags []tunnelpogs.Tag
}
func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error { func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
incrementRequests() incrementRequests()
defer decrementConcurrentRequests() defer decrementConcurrentRequests()
@ -53,29 +46,30 @@ func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsock
lbProbe := isLBProbeRequest(req) lbProbe := isLBProbeRequest(req)
c.appendTagHeaders(req) c.appendTagHeaders(req)
c.logRequest(req, cfRay, lbProbe) rule, ruleNum := c.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
c.logRequest(req, cfRay, lbProbe, ruleNum)
var ( var (
resp *http.Response resp *http.Response
err error err error
) )
if isWebsocket { if isWebsocket {
resp, err = c.proxyWebsocket(w, req) resp, err = c.proxyWebsocket(w, req, rule)
} else { } else {
resp, err = c.proxyHTTP(w, req) resp, err = c.proxyHTTP(w, req, rule)
} }
if err != nil { if err != nil {
c.logger.Errorf("HTTP request error: %s", err) c.logRequestError(err, cfRay, ruleNum)
responseByCode.WithLabelValues("502").Inc()
w.WriteErrorResponse(err) w.WriteErrorResponse(err)
return err return err
} }
c.logResponseOk(resp, cfRay, lbProbe) c.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
return nil return nil
} }
func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*http.Response, error) { func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if c.config.NoChunkedEncoding { if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"} req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil { if err == nil {
@ -86,9 +80,12 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
// Request origin to keep connection alive to improve performance // Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive") req.Header.Set("Connection", "keep-alive")
c.setHostHeader(req) if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", hostHeader)
req.Host = hostHeader
}
resp, err := c.config.Client.RoundTrip(req) resp, err := rule.Service.RoundTrip(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin") return nil, errors.Wrap(err, "Error proxying request to origin")
} }
@ -111,9 +108,17 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
return resp, nil return resp, nil
} }
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) { func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
c.setHostHeader(req) if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig) req.Header.Set("Host", hostHeader)
req.Host = hostHeader
}
dialler, ok := rule.Service.(websocket.Dialler)
if !ok {
return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service)
}
conn, resp, err := websocket.ClientConnect(req, dialler)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,28 +150,22 @@ func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadC
} }
} }
func (c *client) setHostHeader(req *http.Request) {
if c.config.HostHeader != "" {
req.Header.Set("Host", c.config.HostHeader)
req.Host = c.config.HostHeader
}
}
func (c *client) appendTagHeaders(r *http.Request) { func (c *client) appendTagHeaders(r *http.Request) {
for _, tag := range c.config.Tags { for _, tag := range c.tags {
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
} }
} }
func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) { func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) {
if cfRay != "" { if cfRay != "" {
c.logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) c.logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else if lbProbe { } else if lbProbe {
c.logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) c.logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else { } else {
c.logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto) c.logger.Debugf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
} }
c.logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) c.logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
c.logger.Debugf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum)
if contentLen := r.ContentLength; contentLen == -1 { if contentLen := r.ContentLength; contentLen == -1 {
c.logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) c.logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
@ -175,14 +174,14 @@ func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) {
} }
} }
func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { func (c *client) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) {
responseByCode.WithLabelValues("200").Inc() responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc()
if cfRay != "" { if cfRay != "" {
c.logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) c.logger.Infof("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum)
} else if lbProbe { } else if lbProbe {
c.logger.Debugf("Response to Load Balancer health check %s", r.Status) c.logger.Debugf("Response to Load Balancer health check %s", r.Status)
} else { } else {
c.logger.Infof("%s", r.Status) c.logger.Debugf("Status: %s served by ingress %d", r.Status, ruleNum)
} }
c.logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) c.logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
@ -193,6 +192,16 @@ func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
} }
} }
func (c *client) logRequestError(err error, cfRay string, ruleNum int) {
requestErrors.Inc()
if cfRay != "" {
c.logger.Errorf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err)
} else {
c.logger.Errorf("Proxying to ingress %d error: %v", ruleNum, err)
}
}
func findCfRayHeader(req *http.Request) string { func findCfRayHeader(req *http.Request) string {
return req.Header.Get("Cf-Ray") return req.Header.Get("Cf-Ray")
} }

View File

@ -3,27 +3,32 @@ package origin
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls" "flag"
"crypto/x509"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tlsconfig" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/urfave/cli/v2"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
var (
testTags = []tunnelpogs.Tag(nil)
)
type mockHTTPRespWriter struct { type mockHTTPRespWriter struct {
*httptest.ResponseRecorder *httptest.ResponseRecorder
} }
@ -99,49 +104,39 @@ func (w *mockSSERespWriter) ReadBytes() []byte {
return <-w.writeNotification return <-w.writeNotification
} }
func TestProxy(t *testing.T) { func TestProxySingleOrigin(t *testing.T) {
logger, err := logger.New() logger, err := logger.New()
require.NoError(t, err) require.NoError(t, err)
// let runtime pick an available port
listener, err := hello.CreateTLSListener("127.0.0.1:0")
require.NoError(t, err)
originURL := &url.URL{
Scheme: "https",
Host: listener.Addr().String(),
}
originCA := x509.NewCertPool()
helloCert, err := tlsconfig.GetHelloCertificateX509()
require.NoError(t, err)
originCA.AddCert(helloCert)
clientTLS := &tls.Config{
RootCAs: originCA,
}
proxyConfig := &ProxyConfig{
Client: &http.Transport{
TLSClientConfig: clientTLS,
},
URL: originURL,
TLSConfig: clientTLS,
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go func() { flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
hello.StartHelloWorldServer(logger, listener, ctx.Done()) flagSet.Bool("hello-world", true, "")
}()
client := NewClient(proxyConfig, logger) cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL)) err = cliCtx.Set("hello-world", "true")
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS)) require.NoError(t, err)
t.Run("testProxySSE", testProxySSE(t, client, originURL))
allowURLFromArgs := false
ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs, logger)
require.NoError(t, err)
var wg sync.WaitGroup
errC := make(chan error)
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
client := NewClient(ingressRule, testTags, logger)
t.Run("testProxyHTTP", testProxyHTTP(t, client))
t.Run("testProxyWebsocket", testProxyWebsocket(t, client))
t.Run("testProxySSE", testProxySSE(t, client))
cancel() cancel()
wg.Wait()
} }
func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) { func testProxyHTTP(t *testing.T, client connection.OriginClient) func(t *testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
respWriter := newMockHTTPRespWriter() respWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err) require.NoError(t, err)
err = client.Proxy(respWriter, req, false) err = client.Proxy(respWriter, req, false)
@ -151,11 +146,11 @@ func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.
} }
} }
func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) { func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
// WSRoute is a websocket echo handler // WSRoute is a websocket echo handler
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), nil)
readPipe, writePipe := io.Pipe() readPipe, writePipe := io.Pipe()
respWriter := newMockWSRespWriter(readPipe) respWriter := newMockWSRespWriter(readPipe)
@ -191,7 +186,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
} }
} }
func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) { func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
var ( var (
pushCount = 50 pushCount = 50
@ -199,7 +194,7 @@ func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.U
) )
respWriter := newMockSSERespWriter() respWriter := newMockSSERespWriter()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s?freq=%s", hello.SSERoute, pushFreq), nil)
require.NoError(t, err) require.NoError(t, err)
var wg sync.WaitGroup var wg sync.WaitGroup
@ -225,3 +220,98 @@ func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.U
wg.Wait() wg.Wait()
} }
} }
func TestProxyMultipleOrigins(t *testing.T) {
api := httptest.NewServer(mockAPI{})
defer api.Close()
unvalidatedIngress := []config.UnvalidatedIngressRule{
{
Hostname: "api.example.com",
Service: api.URL,
},
{
Hostname: "hello.example.com",
Service: "hello-world",
},
{
Hostname: "health.example.com",
Path: "/health",
Service: "http_status:200",
},
{
Hostname: "*",
Service: "http_status:404",
},
}
ingress, err := ingress.ParseIngress(&config.Configuration{
TunnelID: t.Name(),
Ingress: unvalidatedIngress,
})
require.NoError(t, err)
logger, err := logger.New()
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
errC := make(chan error)
var wg sync.WaitGroup
ingress.StartOrigins(&wg, logger, ctx.Done(), errC)
client := NewClient(ingress, testTags, logger)
tests := []struct {
url string
expectedStatus int
expectedBody []byte
}{
{
url: "http://api.example.com",
expectedStatus: http.StatusCreated,
expectedBody: []byte("Created"),
},
{
url: fmt.Sprintf("http://hello.example.com%s", hello.HealthRoute),
expectedStatus: http.StatusOK,
expectedBody: []byte("ok"),
},
{
url: "http://health.example.com/health",
expectedStatus: http.StatusOK,
},
{
url: "http://health.example.com/",
expectedStatus: http.StatusNotFound,
},
{
url: "http://not-found.example.com",
expectedStatus: http.StatusNotFound,
},
}
for _, test := range tests {
respWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, test.url, nil)
require.NoError(t, err)
err = client.Proxy(respWriter, req, false)
require.NoError(t, err)
assert.Equal(t, test.expectedStatus, respWriter.Code)
if test.expectedBody != nil {
assert.Equal(t, test.expectedBody, respWriter.Body.Bytes())
} else {
assert.Equal(t, 0, respWriter.Body.Len())
}
}
cancel()
wg.Wait()
}
type mockAPI struct{}
func (ma mockAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
w.Write([]byte("Created"))
}

View File

@ -20,7 +20,6 @@ import (
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
@ -47,7 +46,6 @@ const (
type TunnelConfig struct { type TunnelConfig struct {
ConnectionConfig *connection.Config ConnectionConfig *connection.Config
ProxyConfig *ProxyConfig
BuildInfo *buildinfo.BuildInfo BuildInfo *buildinfo.BuildInfo
ClientID string ClientID string
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
@ -57,6 +55,7 @@ type TunnelConfig struct {
IsAutoupdated bool IsAutoupdated bool
IsFreeTunnel bool IsFreeTunnel bool
LBPool string LBPool string
Tags []tunnelpogs.Tag
Logger logger.Service Logger logger.Service
Observer *connection.Observer Observer *connection.Observer
ReportedVersion string ReportedVersion string
@ -67,7 +66,6 @@ type TunnelConfig struct {
ClassicTunnel *connection.ClassicTunnelConfig ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress
ProtocolSelector connection.ProtocolSelector ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config EdgeTLSConfigs map[connection.Protocol]*tls.Config
} }
@ -113,7 +111,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch), OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
ExistingTunnelPolicy: policy, ExistingTunnelPolicy: policy,
PoolName: c.LBPool, PoolName: c.LBPool,
Tags: c.ProxyConfig.Tags, Tags: c.Tags,
ConnectionID: connectionID, ConnectionID: connectionID,
OriginLocalIP: OriginLocalIP, OriginLocalIP: OriginLocalIP,
IsAutoupdated: c.IsAutoupdated, IsAutoupdated: c.IsAutoupdated,
@ -324,7 +322,7 @@ func ServeH2mux(
) (err error, recoverable bool) { ) (err error, recoverable bool) {
config.Logger.Debugf("Connecting via h2mux") config.Logger.Debugf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer) handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, edgeConn, connectionIndex, config.Observer)
if err != nil { if err != nil {
return err, recoverable return err, recoverable
} }
@ -388,7 +386,7 @@ func ServeHTTP2(
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
config.Logger.Debugf("Connecting via http2") config.Logger.Debugf("Connecting via http2")
server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse) server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {

View File

@ -166,7 +166,12 @@ func validateIP(scheme, host, port string) (string, error) {
} }
// originURL shouldn't be a pointer, because this function might change the scheme // originURL shouldn't be a pointer, because this function might change the scheme
func ValidateHTTPService(originURL url.URL, hostname string, transport http.RoundTripper) error { func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
parsedURL, err := url.Parse(originURL)
if err != nil {
return err
}
client := &http.Client{ client := &http.Client{
Transport: transport, Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
@ -175,7 +180,7 @@ func ValidateHTTPService(originURL url.URL, hostname string, transport http.Roun
Timeout: validationTimeout, Timeout: validationTimeout,
} }
initialRequest, err := http.NewRequest("GET", originURL.String(), nil) initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
if err != nil { if err != nil {
return err return err
} }
@ -187,10 +192,10 @@ func ValidateHTTPService(originURL url.URL, hostname string, transport http.Roun
} }
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
oldScheme := originURL.Scheme oldScheme := parsedURL.Scheme
originURL.Scheme = toggleProtocol(originURL.Scheme) parsedURL.Scheme = toggleProtocol(oldScheme)
secondRequest, err := http.NewRequest("GET", originURL.String(), nil) secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
if err != nil { if err != nil {
return err return err
} }
@ -200,9 +205,9 @@ func ValidateHTTPService(originURL url.URL, hostname string, transport http.Roun
resp.Body.Close() resp.Body.Close()
return errors.Errorf( return errors.Errorf(
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v", "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v",
originURL.Host, parsedURL.Host,
oldScheme, oldScheme,
originURL.Scheme, parsedURL.Scheme,
initialErr, initialErr,
originURL, originURL,
) )

View File

@ -123,7 +123,7 @@ func TestToggleProtocol(t *testing.T) {
// Happy path 1: originURL is HTTP, and HTTP connections work // Happy path 1: originURL is HTTP, and HTTP connections work
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
originURL := mustParse(t, "http://127.0.0.1/") originURL := "http://127.0.0.1/"
hostname := "example.com" hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@ -151,7 +151,7 @@ func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
// Happy path 2: originURL is HTTPS, and HTTPS connections work // Happy path 2: originURL is HTTPS, and HTTPS connections work
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
originURL := mustParse(t, "https://127.0.0.1:1234/") originURL := "https://127.0.0.1:1234/"
hostname := "example.com" hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@ -179,7 +179,7 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
// Error path 1: originURL is HTTPS, but HTTP connections work // Error path 1: originURL is HTTPS, but HTTP connections work
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
originURL := mustParse(t, "https://127.0.0.1:1234/") originURL := "https://127.0.0.1:1234/"
hostname := "example.com" hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@ -207,13 +207,10 @@ func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
// Error path 2: originURL is HTTP, but HTTPS connections work // Error path 2: originURL is HTTP, but HTTPS connections work
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
originURLWithPort := url.URL{ originURL := "http://127.0.0.1:1234/"
Scheme: "http",
Host: "127.0.0.1:1234",
}
hostname := "example.com" hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname) assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" { if req.URL.Scheme == "http" {
return nil, assert.AnError return nil, assert.AnError
@ -224,7 +221,7 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
panic("Shouldn't reach here") panic("Shouldn't reach here")
}))) })))
assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname) assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" { if req.URL.Scheme == "http" {
return nil, assert.AnError return nil, assert.AnError
@ -253,14 +250,12 @@ func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer redirectServer.Close() defer redirectServer.Close()
redirectServerURL, err := url.Parse(redirectServer.URL) assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport))
assert.NoError(t, err)
assert.NoError(t, ValidateHTTPService(*redirectServerURL, hostname, redirectClient.Transport))
} }
// Ensure validation times out when origin URL is nonresponsive // Ensure validation times out when origin URL is nonresponsive
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) { func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
originURL := mustParse(t, "http://127.0.0.1/") originURL := "http://127.0.0.1/"
hostname := "example.com" hostname := "example.com"
oldValidationTimeout := validationTimeout oldValidationTimeout := validationTimeout
defer func() { defer func() {
@ -376,9 +371,3 @@ func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *h
return server, client, nil return server, client, nil
} }
func mustParse(t *testing.T, originURL string) url.URL {
parsedURL, err := url.Parse(originURL)
assert.NoError(t, err)
return *parsedURL
}