- Implement requested changes to resolve #111

- Add functionality to serveWebsocket() as well
This commit is contained in:
David Barr 2019-07-04 09:40:45 +10:00
parent 9ff1611a6a
commit 4619d380e0
No known key found for this signature in database
GPG Key ID: 8BC1E18438835BB3
3 changed files with 20 additions and 15 deletions

View File

@ -489,9 +489,9 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "httphost", Name: "http-host-header",
Usage: "Sets the HTTP Host header for the local webserver.", Usage: "Sets the HTTP Host header for the local webserver.",
EnvVars: []string{"TUNNEL_HTTPHOST"}, EnvVars: []string{"TUNNEL_HTTP_HOST_HEADER"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{

View File

@ -251,7 +251,7 @@ func prepareTunnelConfig(
HTTPTransport: httpTransport, HTTPTransport: httpTransport,
HeartbeatInterval: c.Duration("heartbeat-interval"), HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname, Hostname: hostname,
HTTPHost: c.String("httphost"), HTTPHostHeader: c.String("http-host-header"),
IncidentLookup: origin.NewIncidentLookup(), IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel, IsFreeTunnel: isFreeTunnel,

View File

@ -52,7 +52,7 @@ type TunnelConfig struct {
HTTPTransport http.RoundTripper HTTPTransport http.RoundTripper
HeartbeatInterval time.Duration HeartbeatInterval time.Duration
Hostname string Hostname string
HTTPHost string HTTPHostHeader string
IncidentLookup IncidentLookup IncidentLookup IncidentLookup
IsAutoupdated bool IsAutoupdated bool
IsFreeTunnel bool IsFreeTunnel bool
@ -520,13 +520,13 @@ func FindCfRayHeader(h1 *http.Request) string {
} }
type TunnelHandler struct { type TunnelHandler struct {
originUrl string originUrl string
httpHost string httpHostHeader string
muxer *h2mux.Muxer muxer *h2mux.Muxer
httpClient http.RoundTripper httpClient http.RoundTripper
tlsConfig *tls.Config tlsConfig *tls.Config
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
metrics *TunnelMetrics metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string connectionID string
logger *log.Logger logger *log.Logger
@ -547,7 +547,7 @@ func NewTunnelHandler(ctx context.Context,
} }
h := &TunnelHandler{ h := &TunnelHandler{
originUrl: originURL, originUrl: originURL,
httpHost: config.HTTPHost, httpHostHeader: config.HTTPHostHeader,
httpClient: config.HTTPTransport, httpClient: config.HTTPTransport,
tlsConfig: config.ClientTlsConfig, tlsConfig: config.ClientTlsConfig,
tags: config.Tags, tags: config.Tags,
@ -642,6 +642,11 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request,
} }
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
if h.httpHostHeader != "" {
req.Header.Set("Host", h.httpHostHeader)
req.Host = h.httpHostHeader
}
conn, response, err := websocket.ClientConnect(req, h.tlsConfig) conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -670,9 +675,9 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request)
// Request origin to keep connection alive to improve performance // Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive") req.Header.Set("Connection", "keep-alive")
if h.httpHost != "" { if h.httpHostHeader != "" {
req.Header.Set("Host", h.httpHost) req.Header.Set("Host", h.httpHostHeader)
req.Host = h.httpHost req.Host = h.httpHostHeader
} }
response, err := h.httpClient.RoundTrip(req) response, err := h.httpClient.RoundTrip(req)