From cfa1523b3bfc324ccf09680e8a359721cf1be2e1 Mon Sep 17 00:00:00 2001 From: Adnan Yunus Date: Tue, 28 Apr 2020 22:35:20 -0700 Subject: [PATCH] added option to map a cname host to a local origin --- cmd/cloudflared/tunnel/cmd.go | 6 +++++ cmd/cloudflared/tunnel/configuration.go | 12 +++++++++- origin/tunnel.go | 32 ++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 9e940126..50dfabfd 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -768,6 +768,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_HOSTNAME"}, Hidden: shouldHide, }), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ + Name: "host-to-origin", + Usage: "Decalre a direct host to origin mapping i.e. host1.example.com=http://localhost:8080. This takes precedence over the single url mapping.", + EnvVars: []string{"HOST_TO_ORIGIN"}, + Hidden: shouldHide, + }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "http-host-header", Usage: "Sets the HTTP Host header for the local webserver.", diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 8cd15d6c..cb98943c 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -71,7 +71,7 @@ func logClientOptions(c *cli.Context) { flags[flag] = c.Generic(flag) } - sliceFlags := []string{"header", "tag", "proxy-dns-upstream", "upstream", "edge"} + sliceFlags := []string{"header", "tag", "proxy-dns-upstream", "upstream", "edge", "host-to-origin"} for _, sliceFlag := range sliceFlags { if len(c.StringSlice(sliceFlag)) > 0 { flags[sliceFlag] = strings.Join(c.StringSlice(sliceFlag), ", ") @@ -179,6 +179,15 @@ func prepareTunnelConfig( return nil, errors.Wrap(err, "Error validating origin URL") } + hostToOriginUrls := make(map[string]string) + hostToOrigins := c.StringSlice("host-to-origin") + for _, hostToOrigin := range hostToOrigins { + hostToOriginArr := strings.Split(hostToOrigin, "=") + if len(hostToOriginArr) == 2 { + hostToOriginUrls[hostToOriginArr[0]] = hostToOriginArr[1] + } + } + var originCert []byte if !isFreeTunnel { originCert, err = getOriginCert(c) @@ -268,6 +277,7 @@ func prepareTunnelConfig( NoChunkedEncoding: c.Bool("no-chunked-encoding"), OriginCert: originCert, OriginUrl: originURL, + HostToOriginUrls: hostToOriginUrls, ReportedVersion: version, Retries: c.Uint("retries"), RunFromTerminal: isRunningFromTerminal(), diff --git a/origin/tunnel.go b/origin/tunnel.go index bc4b9c50..f75bcdea 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -85,6 +85,8 @@ type TunnelConfig struct { // OriginUrl may not be used if a user specifies a unix socket. OriginUrl string + // Optional mapping to enable routing to multiple OriginUrls based on Host header. + HostToOriginUrls map[string]string // feature-flag to use new edge reconnect tokens UseReconnectToken bool // feature-flag for using ConnectionDigest @@ -540,6 +542,7 @@ func LogServerInfo( type TunnelHandler struct { originUrl string + hostToOriginUrls map[string]string httpHostHeader string muxer *h2mux.Muxer httpClient http.RoundTripper @@ -565,8 +568,16 @@ func NewTunnelHandler(ctx context.Context, if err != nil { return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL) } + for k, _ := range config.HostToOriginUrls { + hostOriginURL, err := validation.ValidateUrl(config.HostToOriginUrls[k]) + if err != nil { + return nil, "", fmt.Errorf("unable to parse origin URL %#v", hostOriginURL) + } + config.HostToOriginUrls[k] = hostOriginURL + } h := &TunnelHandler{ originUrl: originURL, + hostToOriginUrls: config.HostToOriginUrls, httpHostHeader: config.HTTPHostHeader, httpClient: config.HTTPTransport, tlsConfig: config.ClientTlsConfig, @@ -629,8 +640,27 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { return nil } +func getOriginUrlForHost(h2 []h2mux.Header, hostToOriginUrls map[string]string) (string, string) { + host := "" + for _, header := range h2 { + switch strings.ToLower(header.Name) { + case ":authority": + host := header.Value + if val, ok := hostToOriginUrls[header.Value]; ok { + return host, val + } + } + } + return host, "" +} + func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) { - req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) + host, origin := getOriginUrlForHost(stream.Headers, h.hostToOriginUrls) + + if host == "" { + origin = h.originUrl + } + req, err := http.NewRequest("GET", origin, h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") }