added option to map a cname host to a local origin
This commit is contained in:
parent
976eb24883
commit
cfa1523b3b
|
@ -768,6 +768,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
EnvVars: []string{"TUNNEL_HOSTNAME"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||
Name: "host-to-origin",
|
||||
Usage: "Decalre a direct host to origin mapping i.e. host1.example.com=http://localhost:8080. This takes precedence over the single url mapping.",
|
||||
EnvVars: []string{"HOST_TO_ORIGIN"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "http-host-header",
|
||||
Usage: "Sets the HTTP Host header for the local webserver.",
|
||||
|
|
|
@ -71,7 +71,7 @@ func logClientOptions(c *cli.Context) {
|
|||
flags[flag] = c.Generic(flag)
|
||||
}
|
||||
|
||||
sliceFlags := []string{"header", "tag", "proxy-dns-upstream", "upstream", "edge"}
|
||||
sliceFlags := []string{"header", "tag", "proxy-dns-upstream", "upstream", "edge", "host-to-origin"}
|
||||
for _, sliceFlag := range sliceFlags {
|
||||
if len(c.StringSlice(sliceFlag)) > 0 {
|
||||
flags[sliceFlag] = strings.Join(c.StringSlice(sliceFlag), ", ")
|
||||
|
@ -179,6 +179,15 @@ func prepareTunnelConfig(
|
|||
return nil, errors.Wrap(err, "Error validating origin URL")
|
||||
}
|
||||
|
||||
hostToOriginUrls := make(map[string]string)
|
||||
hostToOrigins := c.StringSlice("host-to-origin")
|
||||
for _, hostToOrigin := range hostToOrigins {
|
||||
hostToOriginArr := strings.Split(hostToOrigin, "=")
|
||||
if len(hostToOriginArr) == 2 {
|
||||
hostToOriginUrls[hostToOriginArr[0]] = hostToOriginArr[1]
|
||||
}
|
||||
}
|
||||
|
||||
var originCert []byte
|
||||
if !isFreeTunnel {
|
||||
originCert, err = getOriginCert(c)
|
||||
|
@ -268,6 +277,7 @@ func prepareTunnelConfig(
|
|||
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
|
||||
OriginCert: originCert,
|
||||
OriginUrl: originURL,
|
||||
HostToOriginUrls: hostToOriginUrls,
|
||||
ReportedVersion: version,
|
||||
Retries: c.Uint("retries"),
|
||||
RunFromTerminal: isRunningFromTerminal(),
|
||||
|
|
|
@ -85,6 +85,8 @@ type TunnelConfig struct {
|
|||
// OriginUrl may not be used if a user specifies a unix socket.
|
||||
OriginUrl string
|
||||
|
||||
// Optional mapping to enable routing to multiple OriginUrls based on Host header.
|
||||
HostToOriginUrls map[string]string
|
||||
// feature-flag to use new edge reconnect tokens
|
||||
UseReconnectToken bool
|
||||
// feature-flag for using ConnectionDigest
|
||||
|
@ -540,6 +542,7 @@ func LogServerInfo(
|
|||
|
||||
type TunnelHandler struct {
|
||||
originUrl string
|
||||
hostToOriginUrls map[string]string
|
||||
httpHostHeader string
|
||||
muxer *h2mux.Muxer
|
||||
httpClient http.RoundTripper
|
||||
|
@ -565,8 +568,16 @@ func NewTunnelHandler(ctx context.Context,
|
|||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL)
|
||||
}
|
||||
for k, _ := range config.HostToOriginUrls {
|
||||
hostOriginURL, err := validation.ValidateUrl(config.HostToOriginUrls[k])
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unable to parse origin URL %#v", hostOriginURL)
|
||||
}
|
||||
config.HostToOriginUrls[k] = hostOriginURL
|
||||
}
|
||||
h := &TunnelHandler{
|
||||
originUrl: originURL,
|
||||
hostToOriginUrls: config.HostToOriginUrls,
|
||||
httpHostHeader: config.HTTPHostHeader,
|
||||
httpClient: config.HTTPTransport,
|
||||
tlsConfig: config.ClientTlsConfig,
|
||||
|
@ -629,8 +640,27 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func getOriginUrlForHost(h2 []h2mux.Header, hostToOriginUrls map[string]string) (string, string) {
|
||||
host := ""
|
||||
for _, header := range h2 {
|
||||
switch strings.ToLower(header.Name) {
|
||||
case ":authority":
|
||||
host := header.Value
|
||||
if val, ok := hostToOriginUrls[header.Value]; ok {
|
||||
return host, val
|
||||
}
|
||||
}
|
||||
}
|
||||
return host, ""
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
host, origin := getOriginUrlForHost(stream.Headers, h.hostToOriginUrls)
|
||||
|
||||
if host == "" {
|
||||
origin = h.originUrl
|
||||
}
|
||||
req, err := http.NewRequest("GET", origin, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue