- Implement requested changes to resolve #111
- Add functionality to serveWebsocket() as well
This commit is contained in:
parent
9ff1611a6a
commit
4619d380e0
|
@ -489,9 +489,9 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "httphost",
|
Name: "http-host-header",
|
||||||
Usage: "Sets the HTTP Host header for the local webserver.",
|
Usage: "Sets the HTTP Host header for the local webserver.",
|
||||||
EnvVars: []string{"TUNNEL_HTTPHOST"},
|
EnvVars: []string{"TUNNEL_HTTP_HOST_HEADER"},
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
|
|
@ -251,7 +251,7 @@ func prepareTunnelConfig(
|
||||||
HTTPTransport: httpTransport,
|
HTTPTransport: httpTransport,
|
||||||
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
HTTPHost: c.String("httphost"),
|
HTTPHostHeader: c.String("http-host-header"),
|
||||||
IncidentLookup: origin.NewIncidentLookup(),
|
IncidentLookup: origin.NewIncidentLookup(),
|
||||||
IsAutoupdated: c.Bool("is-autoupdated"),
|
IsAutoupdated: c.Bool("is-autoupdated"),
|
||||||
IsFreeTunnel: isFreeTunnel,
|
IsFreeTunnel: isFreeTunnel,
|
||||||
|
|
|
@ -52,7 +52,7 @@ type TunnelConfig struct {
|
||||||
HTTPTransport http.RoundTripper
|
HTTPTransport http.RoundTripper
|
||||||
HeartbeatInterval time.Duration
|
HeartbeatInterval time.Duration
|
||||||
Hostname string
|
Hostname string
|
||||||
HTTPHost string
|
HTTPHostHeader string
|
||||||
IncidentLookup IncidentLookup
|
IncidentLookup IncidentLookup
|
||||||
IsAutoupdated bool
|
IsAutoupdated bool
|
||||||
IsFreeTunnel bool
|
IsFreeTunnel bool
|
||||||
|
@ -520,13 +520,13 @@ func FindCfRayHeader(h1 *http.Request) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type TunnelHandler struct {
|
type TunnelHandler struct {
|
||||||
originUrl string
|
originUrl string
|
||||||
httpHost string
|
httpHostHeader string
|
||||||
muxer *h2mux.Muxer
|
muxer *h2mux.Muxer
|
||||||
httpClient http.RoundTripper
|
httpClient http.RoundTripper
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
tags []tunnelpogs.Tag
|
tags []tunnelpogs.Tag
|
||||||
metrics *TunnelMetrics
|
metrics *TunnelMetrics
|
||||||
// connectionID is only used by metrics, and prometheus requires labels to be string
|
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||||
connectionID string
|
connectionID string
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
|
@ -547,7 +547,7 @@ func NewTunnelHandler(ctx context.Context,
|
||||||
}
|
}
|
||||||
h := &TunnelHandler{
|
h := &TunnelHandler{
|
||||||
originUrl: originURL,
|
originUrl: originURL,
|
||||||
httpHost: config.HTTPHost,
|
httpHostHeader: config.HTTPHostHeader,
|
||||||
httpClient: config.HTTPTransport,
|
httpClient: config.HTTPTransport,
|
||||||
tlsConfig: config.ClientTlsConfig,
|
tlsConfig: config.ClientTlsConfig,
|
||||||
tags: config.Tags,
|
tags: config.Tags,
|
||||||
|
@ -642,6 +642,11 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||||
|
if h.httpHostHeader != "" {
|
||||||
|
req.Header.Set("Host", h.httpHostHeader)
|
||||||
|
req.Host = h.httpHostHeader
|
||||||
|
}
|
||||||
|
|
||||||
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -670,9 +675,9 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request)
|
||||||
// Request origin to keep connection alive to improve performance
|
// Request origin to keep connection alive to improve performance
|
||||||
req.Header.Set("Connection", "keep-alive")
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
if h.httpHost != "" {
|
if h.httpHostHeader != "" {
|
||||||
req.Header.Set("Host", h.httpHost)
|
req.Header.Set("Host", h.httpHostHeader)
|
||||||
req.Host = h.httpHost
|
req.Host = h.httpHostHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := h.httpClient.RoundTrip(req)
|
response, err := h.httpClient.RoundTrip(req)
|
||||||
|
|
Loading…
Reference in New Issue