added option to map a cname host to a local origin
This commit is contained in:
parent
976eb24883
commit
cfa1523b3b
|
@ -768,6 +768,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
EnvVars: []string{"TUNNEL_HOSTNAME"},
|
EnvVars: []string{"TUNNEL_HOSTNAME"},
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||||
|
Name: "host-to-origin",
|
||||||
|
Usage: "Decalre a direct host to origin mapping i.e. host1.example.com=http://localhost:8080. This takes precedence over the single url mapping.",
|
||||||
|
EnvVars: []string{"HOST_TO_ORIGIN"},
|
||||||
|
Hidden: shouldHide,
|
||||||
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "http-host-header",
|
Name: "http-host-header",
|
||||||
Usage: "Sets the HTTP Host header for the local webserver.",
|
Usage: "Sets the HTTP Host header for the local webserver.",
|
||||||
|
|
|
@ -71,7 +71,7 @@ func logClientOptions(c *cli.Context) {
|
||||||
flags[flag] = c.Generic(flag)
|
flags[flag] = c.Generic(flag)
|
||||||
}
|
}
|
||||||
|
|
||||||
sliceFlags := []string{"header", "tag", "proxy-dns-upstream", "upstream", "edge"}
|
sliceFlags := []string{"header", "tag", "proxy-dns-upstream", "upstream", "edge", "host-to-origin"}
|
||||||
for _, sliceFlag := range sliceFlags {
|
for _, sliceFlag := range sliceFlags {
|
||||||
if len(c.StringSlice(sliceFlag)) > 0 {
|
if len(c.StringSlice(sliceFlag)) > 0 {
|
||||||
flags[sliceFlag] = strings.Join(c.StringSlice(sliceFlag), ", ")
|
flags[sliceFlag] = strings.Join(c.StringSlice(sliceFlag), ", ")
|
||||||
|
@ -179,6 +179,15 @@ func prepareTunnelConfig(
|
||||||
return nil, errors.Wrap(err, "Error validating origin URL")
|
return nil, errors.Wrap(err, "Error validating origin URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hostToOriginUrls := make(map[string]string)
|
||||||
|
hostToOrigins := c.StringSlice("host-to-origin")
|
||||||
|
for _, hostToOrigin := range hostToOrigins {
|
||||||
|
hostToOriginArr := strings.Split(hostToOrigin, "=")
|
||||||
|
if len(hostToOriginArr) == 2 {
|
||||||
|
hostToOriginUrls[hostToOriginArr[0]] = hostToOriginArr[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var originCert []byte
|
var originCert []byte
|
||||||
if !isFreeTunnel {
|
if !isFreeTunnel {
|
||||||
originCert, err = getOriginCert(c)
|
originCert, err = getOriginCert(c)
|
||||||
|
@ -268,6 +277,7 @@ func prepareTunnelConfig(
|
||||||
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
|
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
|
||||||
OriginCert: originCert,
|
OriginCert: originCert,
|
||||||
OriginUrl: originURL,
|
OriginUrl: originURL,
|
||||||
|
HostToOriginUrls: hostToOriginUrls,
|
||||||
ReportedVersion: version,
|
ReportedVersion: version,
|
||||||
Retries: c.Uint("retries"),
|
Retries: c.Uint("retries"),
|
||||||
RunFromTerminal: isRunningFromTerminal(),
|
RunFromTerminal: isRunningFromTerminal(),
|
||||||
|
|
|
@ -85,6 +85,8 @@ type TunnelConfig struct {
|
||||||
// OriginUrl may not be used if a user specifies a unix socket.
|
// OriginUrl may not be used if a user specifies a unix socket.
|
||||||
OriginUrl string
|
OriginUrl string
|
||||||
|
|
||||||
|
// Optional mapping to enable routing to multiple OriginUrls based on Host header.
|
||||||
|
HostToOriginUrls map[string]string
|
||||||
// feature-flag to use new edge reconnect tokens
|
// feature-flag to use new edge reconnect tokens
|
||||||
UseReconnectToken bool
|
UseReconnectToken bool
|
||||||
// feature-flag for using ConnectionDigest
|
// feature-flag for using ConnectionDigest
|
||||||
|
@ -540,6 +542,7 @@ func LogServerInfo(
|
||||||
|
|
||||||
type TunnelHandler struct {
|
type TunnelHandler struct {
|
||||||
originUrl string
|
originUrl string
|
||||||
|
hostToOriginUrls map[string]string
|
||||||
httpHostHeader string
|
httpHostHeader string
|
||||||
muxer *h2mux.Muxer
|
muxer *h2mux.Muxer
|
||||||
httpClient http.RoundTripper
|
httpClient http.RoundTripper
|
||||||
|
@ -565,8 +568,16 @@ func NewTunnelHandler(ctx context.Context,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL)
|
return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL)
|
||||||
}
|
}
|
||||||
|
for k, _ := range config.HostToOriginUrls {
|
||||||
|
hostOriginURL, err := validation.ValidateUrl(config.HostToOriginUrls[k])
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("unable to parse origin URL %#v", hostOriginURL)
|
||||||
|
}
|
||||||
|
config.HostToOriginUrls[k] = hostOriginURL
|
||||||
|
}
|
||||||
h := &TunnelHandler{
|
h := &TunnelHandler{
|
||||||
originUrl: originURL,
|
originUrl: originURL,
|
||||||
|
hostToOriginUrls: config.HostToOriginUrls,
|
||||||
httpHostHeader: config.HTTPHostHeader,
|
httpHostHeader: config.HTTPHostHeader,
|
||||||
httpClient: config.HTTPTransport,
|
httpClient: config.HTTPTransport,
|
||||||
tlsConfig: config.ClientTlsConfig,
|
tlsConfig: config.ClientTlsConfig,
|
||||||
|
@ -629,8 +640,27 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getOriginUrlForHost(h2 []h2mux.Header, hostToOriginUrls map[string]string) (string, string) {
|
||||||
|
host := ""
|
||||||
|
for _, header := range h2 {
|
||||||
|
switch strings.ToLower(header.Name) {
|
||||||
|
case ":authority":
|
||||||
|
host := header.Value
|
||||||
|
if val, ok := hostToOriginUrls[header.Value]; ok {
|
||||||
|
return host, val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return host, ""
|
||||||
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
host, origin := getOriginUrlForHost(stream.Headers, h.hostToOriginUrls)
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
origin = h.originUrl
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("GET", origin, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue