TUN-3863: Consolidate header handling logic in the connection package; move headers definitions from h2mux to packages that manage them; cleanup header conversions
All header transformation code from h2mux has been consolidated in the connection package since it's used by both h2mux and http2 logic. Exported headers used by proxying between edge and cloudflared so then can be shared by tunnel service on the edge. Moved access-related headers to corresponding packages that have the code that sets/uses these headers. Removed tunnel hostname tracking from h2mux since it wasn't used by anything. We will continue to set the tunnel hostname header from the edge for backward compatibilty, but it's no longer used by cloudflared. Move bastion-related logic into carrier package, untangled dependencies between carrier, origin, and websocket packages.
This commit is contained in:
parent
ebf5292bf9
commit
8ca0d86c85
|
@ -5,20 +5,25 @@ package carrier
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/cloudflare/cloudflared/token"
|
"github.com/cloudflare/cloudflared/token"
|
||||||
)
|
)
|
||||||
|
|
||||||
const LogFieldOriginURL = "originURL"
|
const (
|
||||||
|
LogFieldOriginURL = "originURL"
|
||||||
|
CFAccessTokenHeader = "Cf-Access-Token"
|
||||||
|
cfJumpDestinationHeader = "Cf-Access-Jump-Destination"
|
||||||
|
)
|
||||||
|
|
||||||
type StartOptions struct {
|
type StartOptions struct {
|
||||||
AppInfo *token.AppInfo
|
AppInfo *token.AppInfo
|
||||||
|
@ -32,15 +37,11 @@ type StartOptions struct {
|
||||||
type Connection interface {
|
type Connection interface {
|
||||||
// ServeStream is used to forward data from the client to the edge
|
// ServeStream is used to forward data from the client to the edge
|
||||||
ServeStream(*StartOptions, io.ReadWriter) error
|
ServeStream(*StartOptions, io.ReadWriter) error
|
||||||
|
|
||||||
// StartServer is used to listen for incoming connections from the edge to the origin
|
|
||||||
StartServer(net.Listener, string, <-chan struct{}) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StdinoutStream is empty struct for wrapping stdin/stdout
|
// StdinoutStream is empty struct for wrapping stdin/stdout
|
||||||
// into a single ReadWriter
|
// into a single ReadWriter
|
||||||
type StdinoutStream struct {
|
type StdinoutStream struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// Read will read from Stdin
|
// Read will read from Stdin
|
||||||
func (c *StdinoutStream) Read(p []byte) (int, error) {
|
func (c *StdinoutStream) Read(p []byte) (int, error) {
|
||||||
|
@ -149,7 +150,7 @@ func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Reque
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
originRequest.Header.Set(h2mux.CFAccessTokenHeader, token)
|
originRequest.Header.Set(CFAccessTokenHeader, token)
|
||||||
|
|
||||||
for k, v := range options.Headers {
|
for k, v := range options.Headers {
|
||||||
if len(v) >= 1 {
|
if len(v) >= 1 {
|
||||||
|
@ -159,3 +160,26 @@ func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Reque
|
||||||
|
|
||||||
return originRequest, nil
|
return originRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetBastionDest(header http.Header, destination string) {
|
||||||
|
if destination != "" {
|
||||||
|
header.Set(cfJumpDestinationHeader, destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveBastionDest(r *http.Request) (string, error) {
|
||||||
|
jumpDestination := r.Header.Get(cfJumpDestinationHeader)
|
||||||
|
if jumpDestination == "" {
|
||||||
|
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||||
|
}
|
||||||
|
// Strip scheme and path set by client. Without a scheme
|
||||||
|
// Parsing a hostname and path without scheme might not return an error due to parsing ambiguities
|
||||||
|
if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" {
|
||||||
|
return removePath(jumpURL.Host), nil
|
||||||
|
}
|
||||||
|
return removePath(jumpDestination), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removePath(dest string) string {
|
||||||
|
return strings.SplitN(dest, "/", 2)[0]
|
||||||
|
}
|
||||||
|
|
|
@ -156,3 +156,99 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
||||||
|
|
||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBastionDestination(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
header http.Header
|
||||||
|
expectedDest string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "hostname destination",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"localhost"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname destination with port",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"localhost:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname destination with scheme and port",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"ssh://localhost:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full hostname url",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname destination with port and path",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"localhost:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"127.0.0.1"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with port",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"127.0.0.1:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with port and path",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with schem and port",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full ip url",
|
||||||
|
header: http.Header{
|
||||||
|
cfJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no destination",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
r := &http.Request{
|
||||||
|
Header: test.header,
|
||||||
|
}
|
||||||
|
dest, err := ResolveBastionDest(r)
|
||||||
|
if test.wantErr {
|
||||||
|
assert.Error(t, err, "Test %s expects error", test.name)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err)
|
||||||
|
assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,17 +1,13 @@
|
||||||
package carrier
|
package carrier
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
|
||||||
"github.com/cloudflare/cloudflared/socks"
|
|
||||||
"github.com/cloudflare/cloudflared/token"
|
"github.com/cloudflare/cloudflared/token"
|
||||||
cfwebsocket "github.com/cloudflare/cloudflared/websocket"
|
cfwebsocket "github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
@ -23,20 +19,6 @@ type Websocket struct {
|
||||||
isSocks bool
|
isSocks bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type wsdialer struct {
|
|
||||||
conn *cfwebsocket.GorillaConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, error) {
|
|
||||||
local, ok := d.conn.LocalAddr().(*net.TCPAddr)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("not a tcp connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := socks.AddrSpec{IP: local.IP, Port: local.Port}
|
|
||||||
return d.conn, &addr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWSConnection returns a new connection object
|
// NewWSConnection returns a new connection object
|
||||||
func NewWSConnection(log *zerolog.Logger) Connection {
|
func NewWSConnection(log *zerolog.Logger) Connection {
|
||||||
return &Websocket{
|
return &Websocket{
|
||||||
|
@ -54,16 +36,10 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro
|
||||||
}
|
}
|
||||||
defer wsConn.Close()
|
defer wsConn.Close()
|
||||||
|
|
||||||
ingress.Stream(wsConn, conn, ws.log)
|
cfwebsocket.Stream(wsConn, conn, ws.log)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartServer creates a Websocket server to listen for connections.
|
|
||||||
// This is used on the origin (tunnel) side to take data from the muxer and send it to the origin
|
|
||||||
func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC <-chan struct{}) error {
|
|
||||||
return cfwebsocket.StartProxyServer(ws.log, listener, remote, shutdownC, ingress.DefaultStreamHandler)
|
|
||||||
}
|
|
||||||
|
|
||||||
// createWebsocketStream will create a WebSocket connection to stream data over
|
// createWebsocketStream will create a WebSocket connection to stream data over
|
||||||
// It also handles redirects from Access and will present that flow if
|
// It also handles redirects from Access and will present that flow if
|
||||||
// the token is not present on the request
|
// the token is not present on the request
|
||||||
|
|
|
@ -6,19 +6,20 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
|
||||||
"github.com/cloudflare/cloudflared/validation"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
|
"github.com/cloudflare/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/cloudflare/cloudflared/validation"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
LogFieldHost = "host"
|
LogFieldHost = "host"
|
||||||
|
cfAccessClientIDHeader = "Cf-Access-Client-Id"
|
||||||
|
cfAccessClientSecretHeader = "Cf-Access-Client-Secret"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StartForwarder starts a client side websocket forward
|
// StartForwarder starts a client side websocket forward
|
||||||
|
@ -31,16 +32,14 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z
|
||||||
// get the headers from the config file and add to the request
|
// get the headers from the config file and add to the request
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
if forwarder.TokenClientID != "" {
|
if forwarder.TokenClientID != "" {
|
||||||
headers.Set(h2mux.CFAccessClientIDHeader, forwarder.TokenClientID)
|
headers.Set(cfAccessClientIDHeader, forwarder.TokenClientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if forwarder.TokenSecret != "" {
|
if forwarder.TokenSecret != "" {
|
||||||
headers.Set(h2mux.CFAccessClientSecretHeader, forwarder.TokenSecret)
|
headers.Set(cfAccessClientSecretHeader, forwarder.TokenSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
if forwarder.Destination != "" {
|
carrier.SetBastionDest(headers, forwarder.Destination)
|
||||||
headers.Add(h2mux.CFJumpDestinationHeader, forwarder.Destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
options := &carrier.StartOptions{
|
options := &carrier.StartOptions{
|
||||||
OriginURL: forwarder.URL,
|
OriginURL: forwarder.URL,
|
||||||
|
@ -72,16 +71,13 @@ func ssh(c *cli.Context) error {
|
||||||
// get the headers from the cmdline and add them
|
// get the headers from the cmdline and add them
|
||||||
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
|
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||||
if c.IsSet(sshTokenIDFlag) {
|
if c.IsSet(sshTokenIDFlag) {
|
||||||
headers.Set(h2mux.CFAccessClientIDHeader, c.String(sshTokenIDFlag))
|
headers.Set(cfAccessClientIDHeader, c.String(sshTokenIDFlag))
|
||||||
}
|
}
|
||||||
if c.IsSet(sshTokenSecretFlag) {
|
if c.IsSet(sshTokenSecretFlag) {
|
||||||
headers.Set(h2mux.CFAccessClientSecretHeader, c.String(sshTokenSecretFlag))
|
headers.Set(cfAccessClientSecretHeader, c.String(sshTokenSecretFlag))
|
||||||
}
|
}
|
||||||
|
|
||||||
destination := c.String(sshDestinationFlag)
|
carrier.SetBastionDest(headers, c.String(sshDestinationFlag))
|
||||||
if destination != "" {
|
|
||||||
headers.Add(h2mux.CFJumpDestinationHeader, destination)
|
|
||||||
}
|
|
||||||
|
|
||||||
options := &carrier.StartOptions{
|
options := &carrier.StartOptions{
|
||||||
OriginURL: originURL,
|
OriginURL: originURL,
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/sshgen"
|
"github.com/cloudflare/cloudflared/sshgen"
|
||||||
"github.com/cloudflare/cloudflared/token"
|
"github.com/cloudflare/cloudflared/token"
|
||||||
|
@ -286,7 +285,7 @@ func curl(c *cli.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdArgs = append(cmdArgs, "-H")
|
cmdArgs = append(cmdArgs, "-H")
|
||||||
cmdArgs = append(cmdArgs, fmt.Sprintf("%s: %s", h2mux.CFAccessTokenHeader, tok))
|
cmdArgs = append(cmdArgs, fmt.Sprintf("%s: %s", carrier.CFAccessTokenHeader, tok))
|
||||||
return run("curl", cmdArgs...)
|
return run("curl", cmdArgs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -472,10 +471,10 @@ func isFileThere(candidate string) bool {
|
||||||
func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error {
|
func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error {
|
||||||
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
|
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||||
if c.IsSet(sshTokenIDFlag) {
|
if c.IsSet(sshTokenIDFlag) {
|
||||||
headers.Add(h2mux.CFAccessClientIDHeader, c.String(sshTokenIDFlag))
|
headers.Add(cfAccessClientIDHeader, c.String(sshTokenIDFlag))
|
||||||
}
|
}
|
||||||
if c.IsSet(sshTokenSecretFlag) {
|
if c.IsSet(sshTokenSecretFlag) {
|
||||||
headers.Add(h2mux.CFAccessClientSecretHeader, c.String(sshTokenSecretFlag))
|
headers.Add(cfAccessClientSecretHeader, c.String(sshTokenSecretFlag))
|
||||||
}
|
}
|
||||||
options := &carrier.StartOptions{AppInfo: appInfo, OriginURL: appUrl.String(), Headers: headers}
|
options := &carrier.StartOptions{AppInfo: appInfo, OriginURL: appUrl.String(), Headers: headers}
|
||||||
|
|
||||||
|
|
|
@ -234,7 +234,7 @@ func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||||
}
|
}
|
||||||
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
|
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "invalid request received")
|
return nil, errors.Wrap(err, "invalid request received")
|
||||||
}
|
}
|
||||||
|
@ -246,15 +246,15 @@ type h2muxRespWriter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
headers := h2mux.H1ResponseToH2ResponseHeaders(status, header)
|
headers := H1ResponseToH2ResponseHeaders(status, header)
|
||||||
headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin})
|
headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})
|
||||||
return rp.WriteHeaders(headers)
|
return rp.WriteHeaders(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *h2muxRespWriter) WriteErrorResponse() {
|
func (rp *h2muxRespWriter) WriteErrorResponse() {
|
||||||
_ = rp.WriteHeaders([]h2mux.Header{
|
_ = rp.WriteHeaders([]h2mux.Header{
|
||||||
{Name: ":status", Value: "502"},
|
{Name: ":status", Value: "502"},
|
||||||
{Name: ResponseMetaHeaderField, Value: responseMetaHeaderCfd},
|
{Name: ResponseMetaHeader, Value: responseMetaHeaderCfd},
|
||||||
})
|
})
|
||||||
_, _ = rp.Write([]byte("502 Bad Gateway"))
|
_, _ = rp.Write([]byte("502 Bad Gateway"))
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,9 +115,9 @@ func TestServeStreamHTTP(t *testing.T) {
|
||||||
require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))
|
require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))
|
||||||
|
|
||||||
if test.isProxyError {
|
if test.isProxyError {
|
||||||
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderCfd))
|
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderCfd))
|
||||||
} else {
|
} else {
|
||||||
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
|
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||||
body := make([]byte, len(test.expectedBody))
|
body := make([]byte, len(test.expectedBody))
|
||||||
_, err = stream.Read(body)
|
_, err = stream.Read(body)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -164,7 +164,7 @@ func TestServeStreamWS(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
|
require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
|
||||||
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
|
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||||
|
|
||||||
data := []byte("test websocket")
|
data := []byte("test websocket")
|
||||||
err = wsutil.WriteClientText(writePipe, data)
|
err = wsutil.WriteClientText(writePipe, data)
|
||||||
|
@ -268,7 +268,7 @@ func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) {
|
||||||
b.StopTimer()
|
b.StopTimer()
|
||||||
|
|
||||||
require.NoError(b, openstreamErr)
|
require.NoError(b, openstreamErr)
|
||||||
assert.True(b, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
|
assert.True(b, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||||
require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK)))
|
require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK)))
|
||||||
require.NoError(b, readBodyErr)
|
require.NoError(b, readBodyErr)
|
||||||
require.Equal(b, test.expectedBody, body)
|
require.Equal(b, test.expectedBody, body)
|
||||||
|
|
|
@ -1,19 +1,31 @@
|
||||||
package connection
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
ResponseMetaHeaderField = "cf-cloudflared-response-meta"
|
// h2mux-style special headers
|
||||||
|
RequestUserHeaders = "cf-cloudflared-request-headers"
|
||||||
|
ResponseUserHeaders = "cf-cloudflared-response-headers"
|
||||||
|
ResponseMetaHeader = "cf-cloudflared-response-meta"
|
||||||
|
|
||||||
|
// h2mux-style special headers
|
||||||
|
CanonicalResponseUserHeaders = http.CanonicalHeaderKey(ResponseUserHeaders)
|
||||||
|
CanonicalResponseMetaHeader = http.CanonicalHeaderKey(ResponseMetaHeader)
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField)
|
// pre-generate possible values for res
|
||||||
canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(ResponseMetaHeaderField)
|
|
||||||
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
||||||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
||||||
)
|
)
|
||||||
|
@ -29,3 +41,204 @@ func mustInitRespMetaHeader(src string) string {
|
||||||
}
|
}
|
||||||
return string(header)
|
return string(header)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var headerEncoding = base64.RawStdEncoding
|
||||||
|
|
||||||
|
// note: all h2mux headers should be lower-case (http/2 style)
|
||||||
|
const ()
|
||||||
|
|
||||||
|
// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld
|
||||||
|
// to an HTTP/1 Request object destined for the local origin web service.
|
||||||
|
// This operation includes conversion of the pseudo-headers into their closest
|
||||||
|
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
|
||||||
|
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||||
|
for _, header := range h2 {
|
||||||
|
name := strings.ToLower(header.Name)
|
||||||
|
if !IsControlHeader(name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case ":method":
|
||||||
|
h1.Method = header.Value
|
||||||
|
case ":scheme":
|
||||||
|
// noop - use the preexisting scheme from h1.URL
|
||||||
|
case ":authority":
|
||||||
|
// Otherwise the host header will be based on the origin URL
|
||||||
|
h1.Host = header.Value
|
||||||
|
case ":path":
|
||||||
|
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
|
||||||
|
// However, this HTTP/1 Request object belongs to the Go standard library,
|
||||||
|
// whose URL package makes some opinionated decisions about the encoding of
|
||||||
|
// URL characters: see the docs of https://godoc.org/net/url#URL,
|
||||||
|
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
|
||||||
|
// which is always used when computing url.URL.String(), whether we'd like it or not.
|
||||||
|
//
|
||||||
|
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
|
||||||
|
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
|
||||||
|
// is treated differently when HTTP_PROXY is set!
|
||||||
|
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
|
||||||
|
//
|
||||||
|
// This means we are subject to the behavior of net/url's function `shouldEscape`
|
||||||
|
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
|
||||||
|
|
||||||
|
if header.Value == "*" {
|
||||||
|
h1.URL.Path = "*"
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Due to the behavior of validation.ValidateUrl, h1.URL may
|
||||||
|
// already have a partial value, with or without a trailing slash.
|
||||||
|
base := h1.URL.String()
|
||||||
|
base = strings.TrimRight(base, "/")
|
||||||
|
// But we know :path begins with '/', because we handled '*' above - see RFC7540
|
||||||
|
requestURL, err := url.Parse(base + header.Value)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
|
||||||
|
}
|
||||||
|
h1.URL = requestURL
|
||||||
|
case "content-length":
|
||||||
|
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unparseable content length")
|
||||||
|
}
|
||||||
|
h1.ContentLength = contentLength
|
||||||
|
case RequestUserHeaders:
|
||||||
|
// Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version
|
||||||
|
// Find and parse user headers serialized into a single one
|
||||||
|
userHeaders, err := DeserializeHeaders(header.Value)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Unable to parse user headers")
|
||||||
|
}
|
||||||
|
for _, userHeader := range userHeaders {
|
||||||
|
h1.Header.Add(userHeader.Name, userHeader.Value)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// All other control headers shall just be proxied transparently
|
||||||
|
h1.Header.Add(header.Name, header.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsControlHeader(headerName string) bool {
|
||||||
|
return headerName == "content-length" ||
|
||||||
|
headerName == "connection" || headerName == "upgrade" || // Websocket headers
|
||||||
|
strings.HasPrefix(headerName, ":") ||
|
||||||
|
strings.HasPrefix(headerName, "cf-")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isWebsocketClientHeader returns true if the header name is required by the client to upgrade properly
|
||||||
|
func IsWebsocketClientHeader(headerName string) bool {
|
||||||
|
return headerName == "sec-websocket-accept" ||
|
||||||
|
headerName == "connection" ||
|
||||||
|
headerName == "upgrade"
|
||||||
|
}
|
||||||
|
|
||||||
|
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) {
|
||||||
|
h2 = []h2mux.Header{
|
||||||
|
{Name: ":status", Value: strconv.Itoa(status)},
|
||||||
|
}
|
||||||
|
userHeaders := make(http.Header, len(h1))
|
||||||
|
for header, values := range h1 {
|
||||||
|
h2name := strings.ToLower(header)
|
||||||
|
if h2name == "content-length" {
|
||||||
|
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||||
|
// so it should be sent as an HTTP/2 response header.
|
||||||
|
|
||||||
|
// Since these are http2 headers, they're required to be lowercase
|
||||||
|
h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]})
|
||||||
|
} else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) {
|
||||||
|
// User headers, on the other hand, must all be serialized so that
|
||||||
|
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
||||||
|
userHeaders[header] = values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform user header serialization and set them in the single header
|
||||||
|
h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)})
|
||||||
|
return h2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize HTTP1.x headers by base64-encoding each header name and value,
|
||||||
|
// and then joining them in the format of [key:value;]
|
||||||
|
func SerializeHeaders(h1Headers http.Header) string {
|
||||||
|
// compute size of the fully serialized value and largest temp buffer we will need
|
||||||
|
serializedLen := 0
|
||||||
|
maxTempLen := 0
|
||||||
|
for headerName, headerValues := range h1Headers {
|
||||||
|
for _, headerValue := range headerValues {
|
||||||
|
nameLen := headerEncoding.EncodedLen(len(headerName))
|
||||||
|
valueLen := headerEncoding.EncodedLen(len(headerValue))
|
||||||
|
const delims = 2
|
||||||
|
serializedLen += delims + nameLen + valueLen
|
||||||
|
if nameLen > maxTempLen {
|
||||||
|
maxTempLen = nameLen
|
||||||
|
}
|
||||||
|
if valueLen > maxTempLen {
|
||||||
|
maxTempLen = valueLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var buf strings.Builder
|
||||||
|
buf.Grow(serializedLen)
|
||||||
|
|
||||||
|
temp := make([]byte, maxTempLen)
|
||||||
|
writeB64 := func(s string) {
|
||||||
|
n := headerEncoding.EncodedLen(len(s))
|
||||||
|
if n > len(temp) {
|
||||||
|
temp = make([]byte, n)
|
||||||
|
}
|
||||||
|
headerEncoding.Encode(temp[:n], []byte(s))
|
||||||
|
buf.Write(temp[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
for headerName, headerValues := range h1Headers {
|
||||||
|
for _, headerValue := range headerValues {
|
||||||
|
if buf.Len() > 0 {
|
||||||
|
buf.WriteByte(';')
|
||||||
|
}
|
||||||
|
writeB64(headerName)
|
||||||
|
buf.WriteByte(':')
|
||||||
|
writeB64(headerValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deserialize headers serialized by `SerializeHeader`
|
||||||
|
func DeserializeHeaders(serializedHeaders string) ([]h2mux.Header, error) {
|
||||||
|
const unableToDeserializeErr = "Unable to deserialize headers"
|
||||||
|
|
||||||
|
var deserialized []h2mux.Header
|
||||||
|
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
|
||||||
|
if len(serializedPair) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
serializedHeaderParts := strings.Split(serializedPair, ":")
|
||||||
|
if len(serializedHeaderParts) != 2 {
|
||||||
|
return nil, errors.New(unableToDeserializeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
serializedName := serializedHeaderParts[0]
|
||||||
|
serializedValue := serializedHeaderParts[1]
|
||||||
|
deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName)))
|
||||||
|
deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue)))
|
||||||
|
|
||||||
|
if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil {
|
||||||
|
return nil, errors.Wrap(err, unableToDeserializeErr)
|
||||||
|
}
|
||||||
|
if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil {
|
||||||
|
return nil, errors.Wrap(err, unableToDeserializeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
deserialized = append(deserialized, h2mux.Header{
|
||||||
|
Name: string(deserializedName),
|
||||||
|
Value: string(deserializedValue),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return deserialized, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package h2mux
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -14,9 +14,11 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ByName []Header
|
type ByName []h2mux.Header
|
||||||
|
|
||||||
func (a ByName) Len() int { return len(a) }
|
func (a ByName) Len() int { return len(a) }
|
||||||
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
|
@ -37,16 +39,16 @@ func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
|
||||||
"Mock header 2": {"Mock value 2"},
|
"Mock header 2": {"Mock value 2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeadersField, mockHeaders), request)
|
headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request)
|
||||||
|
|
||||||
assert.True(t, reflect.DeepEqual(mockHeaders, request.Header))
|
assert.True(t, reflect.DeepEqual(mockHeaders, request.Header))
|
||||||
assert.NoError(t, headersConversionErr)
|
assert.NoError(t, headersConversionErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSerializedHeaders(headersField string, headers http.Header) []Header {
|
func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header {
|
||||||
return []Header{{
|
return []h2mux.Header{{
|
||||||
headersField,
|
Name: headersField,
|
||||||
SerializeHeaders(headers),
|
Value: SerializeHeaders(headers),
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,15 +56,16 @@ func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
|
||||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
emptyHeaders := make(http.Header)
|
||||||
headersConversionErr := H2RequestHeadersToH1Request(
|
headersConversionErr := H2RequestHeadersToH1Request(
|
||||||
[]Header{{
|
[]h2mux.Header{{
|
||||||
RequestUserHeadersField,
|
Name: RequestUserHeaders,
|
||||||
SerializeHeaders(http.Header{}),
|
Value: SerializeHeaders(emptyHeaders),
|
||||||
}},
|
}},
|
||||||
request,
|
request,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert.True(t, reflect.DeepEqual(http.Header{}, request.Header))
|
assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header))
|
||||||
assert.NoError(t, headersConversionErr)
|
assert.NoError(t, headersConversionErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,9 +73,9 @@ func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
|
||||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
mockRequestHeaders := []Header{
|
mockRequestHeaders := []h2mux.Header{
|
||||||
{Name: ":path", Value: "//bad_path/"},
|
{Name: ":path", Value: "//bad_path/"},
|
||||||
{Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||||
}
|
}
|
||||||
|
|
||||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||||
|
@ -90,9 +93,9 @@ func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
|
||||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
mockRequestHeaders := []Header{
|
mockRequestHeaders := []h2mux.Header{
|
||||||
{Name: ":path", Value: "/?query=mock%20value"},
|
{Name: ":path", Value: "/?query=mock%20value"},
|
||||||
{Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||||
}
|
}
|
||||||
|
|
||||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||||
|
@ -110,9 +113,9 @@ func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
|
||||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
mockRequestHeaders := []Header{
|
mockRequestHeaders := []h2mux.Header{
|
||||||
{Name: ":path", Value: "/mock%20path"},
|
{Name: ":path", Value: "/mock%20path"},
|
||||||
{Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||||
}
|
}
|
||||||
|
|
||||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||||
|
@ -267,9 +270,9 @@ func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
|
||||||
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
mockRequestHeaders := []Header{
|
mockRequestHeaders := []h2mux.Header{
|
||||||
{Name: ":path", Value: testCase.path},
|
{Name: ":path", Value: testCase.path},
|
||||||
{Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||||
}
|
}
|
||||||
|
|
||||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||||
|
@ -337,12 +340,12 @@ func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
|
||||||
const expectedMethod = "POST"
|
const expectedMethod = "POST"
|
||||||
const expectedHostname = "request.hostname.example.com"
|
const expectedHostname = "request.hostname.example.com"
|
||||||
|
|
||||||
h2 := []Header{
|
h2 := []h2mux.Header{
|
||||||
{Name: ":method", Value: expectedMethod},
|
{Name: ":method", Value: expectedMethod},
|
||||||
{Name: ":scheme", Value: testScheme},
|
{Name: ":scheme", Value: testScheme},
|
||||||
{Name: ":authority", Value: expectedHostname},
|
{Name: ":authority", Value: expectedHostname},
|
||||||
{Name: ":path", Value: testPath},
|
{Name: ":path", Value: testPath},
|
||||||
{Name: RequestUserHeadersField, Value: ""},
|
{Name: RequestUserHeaders, Value: ""},
|
||||||
}
|
}
|
||||||
h1, err := http.NewRequest("GET", testOrigin.url, nil)
|
h1, err := http.NewRequest("GET", testOrigin.url, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -424,10 +427,10 @@ func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []Header) {
|
func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) {
|
||||||
for name, values := range headers {
|
for name, values := range headers {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
h2muxHeaders = append(h2muxHeaders, Header{name, value})
|
h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -515,14 +518,14 @@ func TestParseHeaders(t *testing.T) {
|
||||||
"Mock-Header-Three": {"3"},
|
"Mock-Header-Three": {"3"},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockHeaders := []Header{
|
mockHeaders := []h2mux.Header{
|
||||||
{Name: "One", Value: "1"}, // will be dropped
|
{Name: "One", Value: "1"}, // will be dropped
|
||||||
{Name: "Cf-Two", Value: "cf-value-1"},
|
{Name: "Cf-Two", Value: "cf-value-1"},
|
||||||
{Name: "Cf-Two", Value: "cf-value-2"},
|
{Name: "Cf-Two", Value: "cf-value-2"},
|
||||||
{Name: RequestUserHeadersField, Value: SerializeHeaders(mockUserHeadersToSerialize)},
|
{Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)},
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedHeaders := []Header{
|
expectedHeaders := []h2mux.Header{
|
||||||
{Name: "Cf-Two", Value: "cf-value-1"},
|
{Name: "Cf-Two", Value: "cf-value-1"},
|
||||||
{Name: "Cf-Two", Value: "cf-value-2"},
|
{Name: "Cf-Two", Value: "cf-value-2"},
|
||||||
{Name: "Mock-Header-One", Value: "1"},
|
{Name: "Mock-Header-One", Value: "1"},
|
||||||
|
@ -583,7 +586,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
|
||||||
|
|
||||||
serializedHeadersIndex := -1
|
serializedHeadersIndex := -1
|
||||||
for i, header := range headers {
|
for i, header := range headers {
|
||||||
if header.Name == ResponseUserHeadersField {
|
if header.Name == ResponseUserHeaders {
|
||||||
serializedHeadersIndex = i
|
serializedHeadersIndex = i
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -593,7 +596,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
|
||||||
headers[:serializedHeadersIndex],
|
headers[:serializedHeadersIndex],
|
||||||
headers[serializedHeadersIndex+1:]...,
|
headers[serializedHeadersIndex+1:]...,
|
||||||
)
|
)
|
||||||
expectedControlHeaders := []Header{
|
expectedControlHeaders := []h2mux.Header{
|
||||||
{Name: ":status", Value: "200"},
|
{Name: ":status", Value: "200"},
|
||||||
{Name: "content-length", Value: "123"},
|
{Name: "content-length", Value: "123"},
|
||||||
}
|
}
|
||||||
|
@ -601,7 +604,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
|
||||||
assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders)
|
assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders)
|
||||||
|
|
||||||
actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value)
|
actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value)
|
||||||
expectedUserHeaders := []Header{
|
expectedUserHeaders := []h2mux.Header{
|
||||||
{Name: "User-header-one", Value: ""},
|
{Name: "User-header-one", Value: ""},
|
||||||
{Name: "User-header-two", Value: "1"},
|
{Name: "User-header-two", Value: "1"},
|
||||||
{Name: "User-header-two", Value: "2"},
|
{Name: "User-header-two", Value: "2"},
|
||||||
|
@ -630,7 +633,7 @@ func TestHeaderSize(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, header := range serializedHeaders {
|
for _, header := range serializedHeaders {
|
||||||
if header.Name != ResponseUserHeadersField {
|
if header.Name != ResponseUserHeaders {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,15 +13,15 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// note: these constants are exported so we can reuse them in the edge-side code
|
||||||
const (
|
const (
|
||||||
internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
|
InternalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
|
||||||
tcpStreamHeader = "Cf-Cloudflared-Proxy-Src"
|
InternalTCPProxySrcHeader = "Cf-Cloudflared-Proxy-Src"
|
||||||
websocketUpgrade = "websocket"
|
WebsocketUpgrade = "websocket"
|
||||||
controlStreamUpgrade = "control-stream"
|
ControlStreamUpgrade = "control-stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
|
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
|
||||||
|
@ -178,25 +178,23 @@ func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (
|
||||||
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
dest := rp.w.Header()
|
dest := rp.w.Header()
|
||||||
userHeaders := make(http.Header, len(header))
|
userHeaders := make(http.Header, len(header))
|
||||||
for header, values := range header {
|
for name, values := range header {
|
||||||
// Since these are http2 headers, they're required to be lowercase
|
// Since these are http2 headers, they're required to be lowercase
|
||||||
h2name := strings.ToLower(header)
|
h2name := strings.ToLower(name)
|
||||||
for _, v := range values {
|
|
||||||
if h2name == "content-length" {
|
if h2name == "content-length" {
|
||||||
// This header has meaning in HTTP/2 and will be used by the edge,
|
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||||
// so it should be sent as an HTTP/2 response header.
|
// so it should be sent as an HTTP/2 response header.
|
||||||
dest.Add(h2name, v)
|
dest[name] = values
|
||||||
// Since these are http2 headers, they're required to be lowercase
|
// Since these are http2 headers, they're required to be lowercase
|
||||||
} else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
|
} else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) {
|
||||||
// User headers, on the other hand, must all be serialized so that
|
// User headers, on the other hand, must all be serialized so that
|
||||||
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
||||||
userHeaders.Add(h2name, v)
|
userHeaders[name] = values
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform user header serialization and set them in the single header
|
// Perform user header serialization and set them in the single header
|
||||||
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders))
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
||||||
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
||||||
if status == http.StatusSwitchingProtocols {
|
if status == http.StatusSwitchingProtocols {
|
||||||
|
@ -218,7 +216,7 @@ func (rp *http2RespWriter) WriteErrorResponse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) setResponseMetaHeader(value string) {
|
func (rp *http2RespWriter) setResponseMetaHeader(value string) {
|
||||||
rp.w.Header().Set(canonicalResponseMetaHeaderField, value)
|
rp.w.Header().Set(CanonicalResponseMetaHeader, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
|
func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
|
||||||
|
@ -258,18 +256,18 @@ func determineHTTP2Type(r *http.Request) Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isControlStreamUpgrade(r *http.Request) bool {
|
func isControlStreamUpgrade(r *http.Request) bool {
|
||||||
return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade
|
return r.Header.Get(InternalUpgradeHeader) == ControlStreamUpgrade
|
||||||
}
|
}
|
||||||
|
|
||||||
func isWebsocketUpgrade(r *http.Request) bool {
|
func isWebsocketUpgrade(r *http.Request) bool {
|
||||||
return r.Header.Get(internalUpgradeHeader) == websocketUpgrade
|
return r.Header.Get(InternalUpgradeHeader) == WebsocketUpgrade
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsTCPStream discerns if the connection request needs a tcp stream proxy.
|
// IsTCPStream discerns if the connection request needs a tcp stream proxy.
|
||||||
func IsTCPStream(r *http.Request) bool {
|
func IsTCPStream(r *http.Request) bool {
|
||||||
return r.Header.Get(tcpStreamHeader) != ""
|
return r.Header.Get(InternalTCPProxySrcHeader) != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func stripWebsocketUpgradeHeader(r *http.Request) {
|
func stripWebsocketUpgradeHeader(r *http.Request) {
|
||||||
r.Header.Del(internalUpgradeHeader)
|
r.Header.Del(InternalUpgradeHeader)
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,9 +103,9 @@ func TestServeHTTP(t *testing.T) {
|
||||||
require.Equal(t, test.expectedBody, respBody)
|
require.Equal(t, test.expectedBody, respBody)
|
||||||
}
|
}
|
||||||
if test.isProxyError {
|
if test.isProxyError {
|
||||||
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeaderField))
|
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
|
||||||
} else {
|
} else {
|
||||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField))
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
|
@ -191,7 +191,7 @@ func TestServeWS(t *testing.T) {
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(internalUpgradeHeader, websocketUpgrade)
|
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -211,7 +211,7 @@ func TestServeWS(t *testing.T) {
|
||||||
resp := respWriter.Result()
|
resp := respWriter.Result()
|
||||||
// http2RespWriter should rewrite status 101 to 200
|
// http2RespWriter should rewrite status 101 to 200
|
||||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField))
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
@ -235,7 +235,7 @@ func TestServeControlStream(t *testing.T) {
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
||||||
|
|
||||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -274,7 +274,7 @@ func TestFailRegistration(t *testing.T) {
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
||||||
|
|
||||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -310,7 +310,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
||||||
|
|
||||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
234
h2mux/header.go
234
h2mux/header.go
|
@ -1,234 +0,0 @@
|
||||||
package h2mux
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Header struct {
|
|
||||||
Name, Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
var headerEncoding = base64.RawStdEncoding
|
|
||||||
|
|
||||||
const (
|
|
||||||
RequestUserHeadersField = "cf-cloudflared-request-headers"
|
|
||||||
ResponseUserHeadersField = "cf-cloudflared-response-headers"
|
|
||||||
|
|
||||||
CFAccessTokenHeader = "cf-access-token"
|
|
||||||
CFJumpDestinationHeader = "CF-Access-Jump-Destination"
|
|
||||||
CFAccessClientIDHeader = "CF-Access-Client-Id"
|
|
||||||
CFAccessClientSecretHeader = "CF-Access-Client-Secret"
|
|
||||||
)
|
|
||||||
|
|
||||||
// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld
|
|
||||||
// to an HTTP/1 Request object destined for the local origin web service.
|
|
||||||
// This operation includes conversion of the pseudo-headers into their closest
|
|
||||||
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
|
|
||||||
func H2RequestHeadersToH1Request(h2 []Header, h1 *http.Request) error {
|
|
||||||
for _, header := range h2 {
|
|
||||||
name := strings.ToLower(header.Name)
|
|
||||||
if !IsControlHeader(name) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch name {
|
|
||||||
case ":method":
|
|
||||||
h1.Method = header.Value
|
|
||||||
case ":scheme":
|
|
||||||
// noop - use the preexisting scheme from h1.URL
|
|
||||||
case ":authority":
|
|
||||||
// Otherwise the host header will be based on the origin URL
|
|
||||||
h1.Host = header.Value
|
|
||||||
case ":path":
|
|
||||||
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
|
|
||||||
// However, this HTTP/1 Request object belongs to the Go standard library,
|
|
||||||
// whose URL package makes some opinionated decisions about the encoding of
|
|
||||||
// URL characters: see the docs of https://godoc.org/net/url#URL,
|
|
||||||
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
|
|
||||||
// which is always used when computing url.URL.String(), whether we'd like it or not.
|
|
||||||
//
|
|
||||||
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
|
|
||||||
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
|
|
||||||
// is treated differently when HTTP_PROXY is set!
|
|
||||||
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
|
|
||||||
//
|
|
||||||
// This means we are subject to the behavior of net/url's function `shouldEscape`
|
|
||||||
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
|
|
||||||
|
|
||||||
if header.Value == "*" {
|
|
||||||
h1.URL.Path = "*"
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Due to the behavior of validation.ValidateUrl, h1.URL may
|
|
||||||
// already have a partial value, with or without a trailing slash.
|
|
||||||
base := h1.URL.String()
|
|
||||||
base = strings.TrimRight(base, "/")
|
|
||||||
// But we know :path begins with '/', because we handled '*' above - see RFC7540
|
|
||||||
requestURL, err := url.Parse(base + header.Value)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
|
|
||||||
}
|
|
||||||
h1.URL = requestURL
|
|
||||||
case "content-length":
|
|
||||||
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unparseable content length")
|
|
||||||
}
|
|
||||||
h1.ContentLength = contentLength
|
|
||||||
case RequestUserHeadersField:
|
|
||||||
// Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version
|
|
||||||
// Find and parse user headers serialized into a single one
|
|
||||||
userHeaders, err := ParseUserHeaders(RequestUserHeadersField, h2)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Unable to parse user headers")
|
|
||||||
}
|
|
||||||
for _, userHeader := range userHeaders {
|
|
||||||
h1.Header.Add(http.CanonicalHeaderKey(userHeader.Name), userHeader.Value)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// All other control headers shall just be proxied transparently
|
|
||||||
h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseUserHeaders(headerNameToParseFrom string, headers []Header) ([]Header, error) {
|
|
||||||
for _, header := range headers {
|
|
||||||
if header.Name == headerNameToParseFrom {
|
|
||||||
return DeserializeHeaders(header.Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("%v header not found", RequestUserHeadersField)
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsControlHeader(headerName string) bool {
|
|
||||||
return headerName == "content-length" ||
|
|
||||||
headerName == "connection" || headerName == "upgrade" || // Websocket headers
|
|
||||||
strings.HasPrefix(headerName, ":") ||
|
|
||||||
strings.HasPrefix(headerName, "cf-")
|
|
||||||
}
|
|
||||||
|
|
||||||
// isWebsocketClientHeader returns true if the header name is required by the client to upgrade properly
|
|
||||||
func IsWebsocketClientHeader(headerName string) bool {
|
|
||||||
return headerName == "sec-websocket-accept" ||
|
|
||||||
headerName == "connection" ||
|
|
||||||
headerName == "upgrade"
|
|
||||||
}
|
|
||||||
|
|
||||||
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []Header) {
|
|
||||||
h2 = []Header{
|
|
||||||
{Name: ":status", Value: strconv.Itoa(status)},
|
|
||||||
}
|
|
||||||
userHeaders := make(http.Header, len(h1))
|
|
||||||
for header, values := range h1 {
|
|
||||||
h2name := strings.ToLower(header)
|
|
||||||
if h2name == "content-length" {
|
|
||||||
// This header has meaning in HTTP/2 and will be used by the edge,
|
|
||||||
// so it should be sent as an HTTP/2 response header.
|
|
||||||
|
|
||||||
// Since these are http2 headers, they're required to be lowercase
|
|
||||||
h2 = append(h2, Header{Name: "content-length", Value: values[0]})
|
|
||||||
} else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) {
|
|
||||||
// User headers, on the other hand, must all be serialized so that
|
|
||||||
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
|
||||||
userHeaders[header] = values
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform user header serialization and set them in the single header
|
|
||||||
h2 = append(h2, Header{ResponseUserHeadersField, SerializeHeaders(userHeaders)})
|
|
||||||
return h2
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize HTTP1.x headers by base64-encoding each header name and value,
|
|
||||||
// and then joining them in the format of [key:value;]
|
|
||||||
func SerializeHeaders(h1Headers http.Header) string {
|
|
||||||
// compute size of the fully serialized value and largest temp buffer we will need
|
|
||||||
serializedLen := 0
|
|
||||||
maxTempLen := 0
|
|
||||||
for headerName, headerValues := range h1Headers {
|
|
||||||
for _, headerValue := range headerValues {
|
|
||||||
nameLen := headerEncoding.EncodedLen(len(headerName))
|
|
||||||
valueLen := headerEncoding.EncodedLen(len(headerValue))
|
|
||||||
const delims = 2
|
|
||||||
serializedLen += delims + nameLen + valueLen
|
|
||||||
if nameLen > maxTempLen {
|
|
||||||
maxTempLen = nameLen
|
|
||||||
}
|
|
||||||
if valueLen > maxTempLen {
|
|
||||||
maxTempLen = valueLen
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var buf strings.Builder
|
|
||||||
buf.Grow(serializedLen)
|
|
||||||
|
|
||||||
temp := make([]byte, maxTempLen)
|
|
||||||
writeB64 := func(s string) {
|
|
||||||
n := headerEncoding.EncodedLen(len(s))
|
|
||||||
if n > len(temp) {
|
|
||||||
temp = make([]byte, n)
|
|
||||||
}
|
|
||||||
headerEncoding.Encode(temp[:n], []byte(s))
|
|
||||||
buf.Write(temp[:n])
|
|
||||||
}
|
|
||||||
|
|
||||||
for headerName, headerValues := range h1Headers {
|
|
||||||
for _, headerValue := range headerValues {
|
|
||||||
if buf.Len() > 0 {
|
|
||||||
buf.WriteByte(';')
|
|
||||||
}
|
|
||||||
writeB64(headerName)
|
|
||||||
buf.WriteByte(':')
|
|
||||||
writeB64(headerValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deserialize headers serialized by `SerializeHeader`
|
|
||||||
func DeserializeHeaders(serializedHeaders string) ([]Header, error) {
|
|
||||||
const unableToDeserializeErr = "Unable to deserialize headers"
|
|
||||||
|
|
||||||
var deserialized []Header
|
|
||||||
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
|
|
||||||
if len(serializedPair) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
serializedHeaderParts := strings.Split(serializedPair, ":")
|
|
||||||
if len(serializedHeaderParts) != 2 {
|
|
||||||
return nil, errors.New(unableToDeserializeErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
serializedName := serializedHeaderParts[0]
|
|
||||||
serializedValue := serializedHeaderParts[1]
|
|
||||||
deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName)))
|
|
||||||
deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue)))
|
|
||||||
|
|
||||||
if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil {
|
|
||||||
return nil, errors.Wrap(err, unableToDeserializeErr)
|
|
||||||
}
|
|
||||||
if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil {
|
|
||||||
return nil, errors.Wrap(err, unableToDeserializeErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
deserialized = append(deserialized, Header{
|
|
||||||
Name: string(deserializedName),
|
|
||||||
Value: string(deserializedValue),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return deserialized, nil
|
|
||||||
}
|
|
|
@ -23,6 +23,10 @@ type MuxedStreamDataSignaller interface {
|
||||||
Signal(ID uint32)
|
Signal(ID uint32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Header struct {
|
||||||
|
Name, Value string
|
||||||
|
}
|
||||||
|
|
||||||
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
|
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
|
||||||
type MuxedStream struct {
|
type MuxedStream struct {
|
||||||
streamID uint32
|
streamID uint32
|
||||||
|
@ -74,8 +78,6 @@ type MuxedStream struct {
|
||||||
sentEOF bool
|
sentEOF bool
|
||||||
// true if the peer sent us an EOF
|
// true if the peer sent us an EOF
|
||||||
receivedEOF bool
|
receivedEOF bool
|
||||||
// If valid, tunnelHostname is used to identify which origin service is the intended recipient of the request
|
|
||||||
tunnelHostname TunnelHostname
|
|
||||||
// Compression-related fields
|
// Compression-related fields
|
||||||
receivedUseDict bool
|
receivedUseDict bool
|
||||||
method string
|
method string
|
||||||
|
@ -252,10 +254,6 @@ func (s *MuxedStream) IsRPCStream() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MuxedStream) TunnelHostname() TunnelHostname {
|
|
||||||
return s.tunnelHostname
|
|
||||||
}
|
|
||||||
|
|
||||||
// Block until a value is sent on writeBufferHasSpace.
|
// Block until a value is sent on writeBufferHasSpace.
|
||||||
// Must be called while holding writeLock
|
// Must be called while holding writeLock
|
||||||
func (s *MuxedStream) awaitWriteBufferHasSpace() {
|
func (s *MuxedStream) awaitWriteBufferHasSpace() {
|
||||||
|
|
|
@ -12,10 +12,6 @@ import (
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
CloudflaredProxyTunnelHostnameHeader = "cf-cloudflared-proxy-tunnel-hostname"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MuxReader struct {
|
type MuxReader struct {
|
||||||
// f is used to read HTTP2 frames.
|
// f is used to read HTTP2 frames.
|
||||||
f *http2.Framer
|
f *http2.Framer
|
||||||
|
@ -252,8 +248,6 @@ func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
|
||||||
if r.dictionaries.write != nil {
|
if r.dictionaries.write != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case CloudflaredProxyTunnelHostnameHeader:
|
|
||||||
stream.tunnelHostname = TunnelHostname(header.Value)
|
|
||||||
}
|
}
|
||||||
headers = append(headers, Header{Name: header.Name, Value: header.Value})
|
headers = append(headers, Header{Name: header.Name, Value: header.Value})
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,10 +21,6 @@ var (
|
||||||
Name: ":path",
|
Name: ":path",
|
||||||
Value: "/api/tunnels",
|
Value: "/api/tunnels",
|
||||||
}
|
}
|
||||||
tunnelHostnameHeader = Header{
|
|
||||||
Name: CloudflaredProxyTunnelHostnameHeader,
|
|
||||||
Value: "tunnel.example.com",
|
|
||||||
}
|
|
||||||
respStatusHeader = Header{
|
respStatusHeader = Header{
|
||||||
Name: ":status",
|
Name: ":status",
|
||||||
Value: "200",
|
Value: "200",
|
||||||
|
@ -42,15 +38,6 @@ func (mosh *mockOriginStreamHandler) ServeStream(stream *MuxedStream) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCloudflaredProxyTunnelHostnameHeader(stream *MuxedStream) string {
|
|
||||||
for _, header := range stream.Headers {
|
|
||||||
if header.Name == CloudflaredProxyTunnelHostnameHeader {
|
|
||||||
return header.Value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertOpenStreamSucceed(t *testing.T, stream *MuxedStream, err error) {
|
func assertOpenStreamSucceed(t *testing.T, stream *MuxedStream, err error) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, stream.Headers, 1)
|
assert.Len(t, stream.Headers, 1)
|
||||||
|
@ -72,13 +59,11 @@ func TestMissingHeaders(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request doesn't contain CloudflaredProxyTunnelHostnameHeader
|
|
||||||
stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
|
stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
|
||||||
assertOpenStreamSucceed(t, stream, err)
|
assertOpenStreamSucceed(t, stream, err)
|
||||||
|
|
||||||
assert.Empty(t, originHandler.stream.method)
|
assert.Empty(t, originHandler.stream.method)
|
||||||
assert.Empty(t, originHandler.stream.path)
|
assert.Empty(t, originHandler.stream.path)
|
||||||
assert.False(t, originHandler.stream.TunnelHostname().IsSet())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReceiveHeaderData(t *testing.T) {
|
func TestReceiveHeaderData(t *testing.T) {
|
||||||
|
@ -90,18 +75,14 @@ func TestReceiveHeaderData(t *testing.T) {
|
||||||
methodHeader,
|
methodHeader,
|
||||||
schemeHeader,
|
schemeHeader,
|
||||||
pathHeader,
|
pathHeader,
|
||||||
tunnelHostnameHeader,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
reqHeaders = append(reqHeaders, tunnelHostnameHeader)
|
|
||||||
stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
|
stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
|
||||||
assertOpenStreamSucceed(t, stream, err)
|
assertOpenStreamSucceed(t, stream, err)
|
||||||
|
|
||||||
assert.Equal(t, methodHeader.Value, originHandler.stream.method)
|
assert.Equal(t, methodHeader.Value, originHandler.stream.method)
|
||||||
assert.Equal(t, pathHeader.Value, originHandler.stream.path)
|
assert.Equal(t, pathHeader.Value, originHandler.stream.path)
|
||||||
assert.True(t, originHandler.stream.TunnelHostname().IsSet())
|
|
||||||
assert.Equal(t, tunnelHostnameHeader.Value, originHandler.stream.TunnelHostname().String())
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,34 +25,10 @@ type OriginConnection interface {
|
||||||
|
|
||||||
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
|
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
|
||||||
|
|
||||||
// Stream copies copy data to & from provided io.ReadWriters.
|
|
||||||
func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
|
|
||||||
proxyDone := make(chan struct{}, 2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(conn, backendConn)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Msgf("conn to backendConn copy: %v", err)
|
|
||||||
}
|
|
||||||
proxyDone <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(backendConn, conn)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Msgf("backendConn to conn copy: %v", err)
|
|
||||||
}
|
|
||||||
proxyDone <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// If one side is done, we are done.
|
|
||||||
<-proxyDone
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultStreamHandler is an implementation of streamHandlerFunc that
|
// DefaultStreamHandler is an implementation of streamHandlerFunc that
|
||||||
// performs a two way io.Copy between originConn and remoteConn.
|
// performs a two way io.Copy between originConn and remoteConn.
|
||||||
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) {
|
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) {
|
||||||
Stream(originConn, remoteConn, log)
|
websocket.Stream(originConn, remoteConn, log)
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||||
|
@ -61,7 +37,7 @@ type tcpConnection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
Stream(tunnelConn, tc.conn, log)
|
websocket.Stream(tunnelConn, tc.conn, log)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *tcpConnection) Close() {
|
func (tc *tcpConnection) Close() {
|
||||||
|
@ -89,7 +65,7 @@ type wsConnection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wsc *wsConnection) Close() {
|
func (wsc *wsConnection) Close() {
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/socks"
|
"github.com/cloudflare/cloudflared/socks"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -157,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer wsForwarderInConn.Close()
|
defer wsForwarderInConn.Close()
|
||||||
|
|
||||||
Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
websocket.Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -106,7 +104,7 @@ func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnectio
|
||||||
var err error
|
var err error
|
||||||
dest := o.dest
|
dest := o.dest
|
||||||
if o.isBastion {
|
if o.isBastion {
|
||||||
dest, err = o.bastionDest(r)
|
dest, err = carrier.ResolveBastionDest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -130,23 +128,6 @@ func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnectio
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) {
|
|
||||||
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
|
||||||
if jumpDestination == "" {
|
|
||||||
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
|
||||||
}
|
|
||||||
// Strip scheme and path set by client. Without a scheme
|
|
||||||
// Parsing a hostname and path without scheme might not return an error due to parsing ambiguities
|
|
||||||
if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" {
|
|
||||||
return removePath(jumpURL.Host), nil
|
|
||||||
}
|
|
||||||
return removePath(jumpDestination), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removePath(dest string) string {
|
|
||||||
return strings.SplitN(dest, "/", 2)[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
originConn := o.conn
|
originConn := o.conn
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
|
||||||
bastionReq := baseReq.Clone(context.Background())
|
bastionReq := baseReq.Clone(context.Background())
|
||||||
bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String())
|
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
||||||
|
|
||||||
expectHeader := http.Header{
|
expectHeader := http.Header{
|
||||||
"Connection": {"Upgrade"},
|
"Connection": {"Upgrade"},
|
||||||
|
@ -135,19 +135,23 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
testCase string
|
||||||
service *tcpOverWSService
|
service *tcpOverWSService
|
||||||
req *http.Request
|
req *http.Request
|
||||||
expectErr bool
|
expectErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
testCase: "specific TCP service",
|
||||||
service: newTCPOverWSService(originURL),
|
service: newTCPOverWSService(originURL),
|
||||||
req: baseReq,
|
req: baseReq,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
testCase: "bastion service",
|
||||||
service: newBastionService(),
|
service: newBastionService(),
|
||||||
req: bastionReq,
|
req: bastionReq,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
testCase: "invalid bastion request",
|
||||||
service: newBastionService(),
|
service: newBastionService(),
|
||||||
req: baseReq,
|
req: baseReq,
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
|
@ -155,6 +159,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
t.Run(test.testCase, func(t *testing.T) {
|
||||||
if test.expectErr {
|
if test.expectErr {
|
||||||
_, resp, err := test.service.EstablishConnection(test.req)
|
_, resp, err := test.service.EstablishConnection(test.req)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
@ -162,6 +167,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
} else {
|
} else {
|
||||||
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
originListener.Close()
|
originListener.Close()
|
||||||
|
@ -175,104 +181,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBastionDestination(t *testing.T) {
|
|
||||||
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
header http.Header
|
|
||||||
expectedDest string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "hostname destination",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"localhost"},
|
|
||||||
},
|
|
||||||
expectedDest: "localhost",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hostname destination with port",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"localhost:9000"},
|
|
||||||
},
|
|
||||||
expectedDest: "localhost:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hostname destination with scheme and port",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"ssh://localhost:9000"},
|
|
||||||
},
|
|
||||||
expectedDest: "localhost:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "full hostname url",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"ssh://localhost:9000/metrics"},
|
|
||||||
},
|
|
||||||
expectedDest: "localhost:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hostname destination with port and path",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"localhost:9000/metrics"},
|
|
||||||
},
|
|
||||||
expectedDest: "localhost:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ip destination",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"127.0.0.1"},
|
|
||||||
},
|
|
||||||
expectedDest: "127.0.0.1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ip destination with port",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"127.0.0.1:9000"},
|
|
||||||
},
|
|
||||||
expectedDest: "127.0.0.1:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ip destination with port and path",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"127.0.0.1:9000/metrics"},
|
|
||||||
},
|
|
||||||
expectedDest: "127.0.0.1:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ip destination with schem and port",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"tcp://127.0.0.1:9000"},
|
|
||||||
},
|
|
||||||
expectedDest: "127.0.0.1:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "full ip url",
|
|
||||||
header: http.Header{
|
|
||||||
canonicalJumpDestHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
|
||||||
},
|
|
||||||
expectedDest: "127.0.0.1:9000",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no destination",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
s := newBastionService()
|
|
||||||
for _, test := range tests {
|
|
||||||
r := &http.Request{
|
|
||||||
Header: test.header,
|
|
||||||
}
|
|
||||||
dest, err := s.bastionDest(r)
|
|
||||||
if test.wantErr {
|
|
||||||
assert.Error(t, err, "Test %s expects error", test.name)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err)
|
|
||||||
assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||||
cfg := OriginRequestConfig{
|
cfg := OriginRequestConfig{
|
||||||
HTTPHostHeader: t.Name(),
|
HTTPHostHeader: t.Name(),
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -103,7 +104,7 @@ func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte)
|
||||||
|
|
||||||
func (cm *reconnectCredentialManager) RefreshAuth(
|
func (cm *reconnectCredentialManager) RefreshAuth(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
backoff *BackoffHandler,
|
backoff *retry.BackoffHandler,
|
||||||
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
|
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
|
||||||
) (retryTimer <-chan time.Time, err error) {
|
) (retryTimer <-chan time.Time, err error) {
|
||||||
authOutcome, err := authenticate(ctx, backoff.Retries())
|
authOutcome, err := authenticate(ctx, backoff.Retries())
|
||||||
|
@ -121,11 +122,11 @@ func (cm *reconnectCredentialManager) RefreshAuth(
|
||||||
case tunnelpogs.AuthSuccess:
|
case tunnelpogs.AuthSuccess:
|
||||||
cm.SetReconnectToken(outcome.JWT())
|
cm.SetReconnectToken(outcome.JWT())
|
||||||
cm.authSuccess.Inc()
|
cm.authSuccess.Inc()
|
||||||
return timeAfter(outcome.RefreshAfter()), nil
|
return retry.Clock.After(outcome.RefreshAfter()), nil
|
||||||
case tunnelpogs.AuthUnknown:
|
case tunnelpogs.AuthUnknown:
|
||||||
duration := outcome.RefreshAfter()
|
duration := outcome.RefreshAfter()
|
||||||
cm.authFail.WithLabelValues(outcome.Error()).Inc()
|
cm.authFail.WithLabelValues(outcome.Error()).Inc()
|
||||||
return timeAfter(duration), nil
|
return retry.Clock.After(duration), nil
|
||||||
case tunnelpogs.AuthFail:
|
case tunnelpogs.AuthFail:
|
||||||
cm.authFail.WithLabelValues(outcome.Error()).Inc()
|
cm.authFail.WithLabelValues(outcome.Error()).Inc()
|
||||||
return nil, outcome
|
return nil, outcome
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,11 +18,11 @@ func TestRefreshAuthBackoff(t *testing.T) {
|
||||||
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
||||||
|
|
||||||
var wait time.Duration
|
var wait time.Duration
|
||||||
timeAfter = func(d time.Duration) <-chan time.Time {
|
retry.Clock.After = func(d time.Duration) <-chan time.Time {
|
||||||
wait = d
|
wait = d
|
||||||
return time.After(d)
|
return time.After(d)
|
||||||
}
|
}
|
||||||
backoff := &BackoffHandler{MaxRetries: 3}
|
backoff := &retry.BackoffHandler{MaxRetries: 3}
|
||||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
return nil, fmt.Errorf("authentication failure")
|
return nil, fmt.Errorf("authentication failure")
|
||||||
}
|
}
|
||||||
|
@ -45,7 +46,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
|
||||||
|
|
||||||
// The backoff timer should have been reset. To confirm this, make timeNow
|
// The backoff timer should have been reset. To confirm this, make timeNow
|
||||||
// return a value after the backoff timer's grace period
|
// return a value after the backoff timer's grace period
|
||||||
timeNow = func() time.Time {
|
retry.Clock.Now = func() time.Time {
|
||||||
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
|
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
|
||||||
return time.Now().Add(expectedGracePeriod * 2)
|
return time.Now().Add(expectedGracePeriod * 2)
|
||||||
}
|
}
|
||||||
|
@ -57,12 +58,12 @@ func TestRefreshAuthSuccess(t *testing.T) {
|
||||||
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
||||||
|
|
||||||
var wait time.Duration
|
var wait time.Duration
|
||||||
timeAfter = func(d time.Duration) <-chan time.Time {
|
retry.Clock.After = func(d time.Duration) <-chan time.Time {
|
||||||
wait = d
|
wait = d
|
||||||
return time.After(d)
|
return time.After(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
backoff := &BackoffHandler{MaxRetries: 3}
|
backoff := &retry.BackoffHandler{MaxRetries: 3}
|
||||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
|
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
|
||||||
}
|
}
|
||||||
|
@ -81,12 +82,12 @@ func TestRefreshAuthUnknown(t *testing.T) {
|
||||||
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
||||||
|
|
||||||
var wait time.Duration
|
var wait time.Duration
|
||||||
timeAfter = func(d time.Duration) <-chan time.Time {
|
retry.Clock.After = func(d time.Duration) <-chan time.Time {
|
||||||
wait = d
|
wait = d
|
||||||
return time.After(d)
|
return time.After(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
backoff := &BackoffHandler{MaxRetries: 3}
|
backoff := &retry.BackoffHandler{MaxRetries: 3}
|
||||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||||
}
|
}
|
||||||
|
@ -104,7 +105,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
|
||||||
func TestRefreshAuthFail(t *testing.T) {
|
func TestRefreshAuthFail(t *testing.T) {
|
||||||
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
|
||||||
|
|
||||||
backoff := &BackoffHandler{MaxRetries: 3}
|
backoff := &retry.BackoffHandler{MaxRetries: 3}
|
||||||
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
|
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
@ -112,10 +113,10 @@ func (s *Supervisor) Run(
|
||||||
var tunnelsWaiting []int
|
var tunnelsWaiting []int
|
||||||
tunnelsActive := s.config.HAConnections
|
tunnelsActive := s.config.HAConnections
|
||||||
|
|
||||||
backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
|
backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
|
||||||
var backoffTimer <-chan time.Time
|
var backoffTimer <-chan time.Time
|
||||||
|
|
||||||
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
refreshAuthBackoff := &retry.BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
||||||
var refreshAuthBackoffTimer <-chan time.Time
|
var refreshAuthBackoffTimer <-chan time.Time
|
||||||
|
|
||||||
if s.useReconnectToken {
|
if s.useReconnectToken {
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -138,7 +139,7 @@ func ServeTunnelLoop(
|
||||||
connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
|
connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
|
||||||
|
|
||||||
protocolFallback := &protocolFallback{
|
protocolFallback := &protocolFallback{
|
||||||
BackoffHandler{MaxRetries: config.Retries},
|
retry.BackoffHandler{MaxRetries: config.Retries},
|
||||||
config.ProtocolSelector.Current(),
|
config.ProtocolSelector.Current(),
|
||||||
false,
|
false,
|
||||||
}
|
}
|
||||||
|
@ -195,18 +196,18 @@ func ServeTunnelLoop(
|
||||||
// protocolFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
|
// protocolFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
|
||||||
// max retries
|
// max retries
|
||||||
type protocolFallback struct {
|
type protocolFallback struct {
|
||||||
BackoffHandler
|
retry.BackoffHandler
|
||||||
protocol connection.Protocol
|
protocol connection.Protocol
|
||||||
inFallback bool
|
inFallback bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pf *protocolFallback) reset() {
|
func (pf *protocolFallback) reset() {
|
||||||
pf.resetNow()
|
pf.ResetNow()
|
||||||
pf.inFallback = false
|
pf.inFallback = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pf *protocolFallback) fallback(fallback connection.Protocol) {
|
func (pf *protocolFallback) fallback(fallback connection.Protocol) {
|
||||||
pf.resetNow()
|
pf.ResetNow()
|
||||||
pf.protocol = fallback
|
pf.protocol = fallback
|
||||||
pf.inFallback = true
|
pf.inFallback = true
|
||||||
}
|
}
|
||||||
|
@ -281,7 +282,7 @@ func ServeTunnel(
|
||||||
}
|
}
|
||||||
|
|
||||||
if protocol == connection.HTTP2 {
|
if protocol == connection.HTTP2 {
|
||||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
||||||
err = ServeHTTP2(
|
err = ServeHTTP2(
|
||||||
ctx,
|
ctx,
|
||||||
connLog,
|
connLog,
|
||||||
|
@ -382,7 +383,7 @@ func ServeH2mux(
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
if config.NamedTunnel != nil {
|
if config.NamedTunnel != nil {
|
||||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
|
||||||
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
|
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
|
||||||
}
|
}
|
||||||
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dynamicMockFetcher struct {
|
type dynamicMockFetcher struct {
|
||||||
|
@ -26,7 +27,7 @@ func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
|
||||||
|
|
||||||
func TestWaitForBackoffFallback(t *testing.T) {
|
func TestWaitForBackoffFallback(t *testing.T) {
|
||||||
maxRetries := uint(3)
|
maxRetries := uint(3)
|
||||||
backoff := BackoffHandler{
|
backoff := retry.BackoffHandler{
|
||||||
MaxRetries: maxRetries,
|
MaxRetries: maxRetries,
|
||||||
BaseTime: time.Millisecond * 10,
|
BaseTime: time.Millisecond * 10,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package origin
|
package retry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -7,10 +7,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeclare time functions so they can be overridden in tests.
|
// Redeclare time functions so they can be overridden in tests.
|
||||||
var (
|
type clock struct {
|
||||||
timeNow = time.Now
|
Now func() time.Time
|
||||||
timeAfter = time.After
|
After func(d time.Duration) <-chan time.Time
|
||||||
)
|
}
|
||||||
|
|
||||||
|
var Clock = clock{
|
||||||
|
Now: time.Now,
|
||||||
|
After: time.After,
|
||||||
|
}
|
||||||
|
|
||||||
// BackoffHandler manages exponential backoff and limits the maximum number of retries.
|
// BackoffHandler manages exponential backoff and limits the maximum number of retries.
|
||||||
// The base time period is 1 second, doubling with each retry.
|
// The base time period is 1 second, doubling with each retry.
|
||||||
|
@ -39,7 +44,7 @@ func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duratio
|
||||||
return time.Duration(0), false
|
return time.Duration(0), false
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
|
if !b.resetDeadline.IsZero() && Clock.Now().After(b.resetDeadline) {
|
||||||
// b.retries would be set to 0 at this point
|
// b.retries would be set to 0 at this point
|
||||||
return time.Second, true
|
return time.Second, true
|
||||||
}
|
}
|
||||||
|
@ -53,7 +58,7 @@ func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duratio
|
||||||
// BackoffTimer returns a channel that sends the current time when the exponential backoff timeout expires.
|
// BackoffTimer returns a channel that sends the current time when the exponential backoff timeout expires.
|
||||||
// Returns nil if the maximum number of retries have been used.
|
// Returns nil if the maximum number of retries have been used.
|
||||||
func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
|
func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
|
||||||
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
|
if !b.resetDeadline.IsZero() && Clock.Now().After(b.resetDeadline) {
|
||||||
b.retries = 0
|
b.retries = 0
|
||||||
b.resetDeadline = time.Time{}
|
b.resetDeadline = time.Time{}
|
||||||
}
|
}
|
||||||
|
@ -66,7 +71,7 @@ func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
|
||||||
}
|
}
|
||||||
maxTimeToWait := time.Duration(b.GetBaseTime() * 1 << (b.retries))
|
maxTimeToWait := time.Duration(b.GetBaseTime() * 1 << (b.retries))
|
||||||
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
|
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
|
||||||
return timeAfter(timeToWait)
|
return Clock.After(timeToWait)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backoff is used to wait according to exponential backoff. Returns false if the
|
// Backoff is used to wait according to exponential backoff. Returns false if the
|
||||||
|
@ -89,7 +94,7 @@ func (b *BackoffHandler) Backoff(ctx context.Context) bool {
|
||||||
func (b *BackoffHandler) SetGracePeriod() {
|
func (b *BackoffHandler) SetGracePeriod() {
|
||||||
maxTimeToWait := b.GetBaseTime() * 2 << (b.retries + 1)
|
maxTimeToWait := b.GetBaseTime() * 2 << (b.retries + 1)
|
||||||
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
|
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
|
||||||
b.resetDeadline = timeNow().Add(timeToWait)
|
b.resetDeadline = Clock.Now().Add(timeToWait)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b BackoffHandler) GetBaseTime() time.Duration {
|
func (b BackoffHandler) GetBaseTime() time.Duration {
|
||||||
|
@ -108,6 +113,6 @@ func (b *BackoffHandler) ReachedMaxRetries() bool {
|
||||||
return b.retries == b.MaxRetries
|
return b.retries == b.MaxRetries
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BackoffHandler) resetNow() {
|
func (b *BackoffHandler) ResetNow() {
|
||||||
b.resetDeadline = time.Now()
|
b.resetDeadline = time.Now()
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package origin
|
package retry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -14,7 +14,7 @@ func immediateTimeAfter(time.Duration) <-chan time.Time {
|
||||||
|
|
||||||
func TestBackoffRetries(t *testing.T) {
|
func TestBackoffRetries(t *testing.T) {
|
||||||
// make backoff return immediately
|
// make backoff return immediately
|
||||||
timeAfter = immediateTimeAfter
|
Clock.After = immediateTimeAfter
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
backoff := BackoffHandler{MaxRetries: 3}
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
if !backoff.Backoff(ctx) {
|
if !backoff.Backoff(ctx) {
|
||||||
|
@ -33,7 +33,7 @@ func TestBackoffRetries(t *testing.T) {
|
||||||
|
|
||||||
func TestBackoffCancel(t *testing.T) {
|
func TestBackoffCancel(t *testing.T) {
|
||||||
// prevent backoff from returning normally
|
// prevent backoff from returning normally
|
||||||
timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
|
Clock.After = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
backoff := BackoffHandler{MaxRetries: 3}
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
|
@ -47,10 +47,10 @@ func TestBackoffCancel(t *testing.T) {
|
||||||
|
|
||||||
func TestBackoffGracePeriod(t *testing.T) {
|
func TestBackoffGracePeriod(t *testing.T) {
|
||||||
currentTime := time.Now()
|
currentTime := time.Now()
|
||||||
// make timeNow return whatever we like
|
// make Clock.Now return whatever we like
|
||||||
timeNow = func() time.Time { return currentTime }
|
Clock.Now = func() time.Time { return currentTime }
|
||||||
// make backoff return immediately
|
// make backoff return immediately
|
||||||
timeAfter = immediateTimeAfter
|
Clock.After = immediateTimeAfter
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
backoff := BackoffHandler{MaxRetries: 1}
|
backoff := BackoffHandler{MaxRetries: 1}
|
||||||
if !backoff.Backoff(ctx) {
|
if !backoff.Backoff(ctx) {
|
||||||
|
@ -71,7 +71,7 @@ func TestBackoffGracePeriod(t *testing.T) {
|
||||||
|
|
||||||
func TestGetMaxBackoffDurationRetries(t *testing.T) {
|
func TestGetMaxBackoffDurationRetries(t *testing.T) {
|
||||||
// make backoff return immediately
|
// make backoff return immediately
|
||||||
timeAfter = immediateTimeAfter
|
Clock.After = immediateTimeAfter
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
backoff := BackoffHandler{MaxRetries: 3}
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
if _, ok := backoff.GetMaxBackoffDuration(ctx); !ok {
|
if _, ok := backoff.GetMaxBackoffDuration(ctx); !ok {
|
||||||
|
@ -96,7 +96,7 @@ func TestGetMaxBackoffDurationRetries(t *testing.T) {
|
||||||
|
|
||||||
func TestGetMaxBackoffDuration(t *testing.T) {
|
func TestGetMaxBackoffDuration(t *testing.T) {
|
||||||
// make backoff return immediately
|
// make backoff return immediately
|
||||||
timeAfter = immediateTimeAfter
|
Clock.After = immediateTimeAfter
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
backoff := BackoffHandler{MaxRetries: 3}
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
|
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
|
||||||
|
@ -118,7 +118,7 @@ func TestGetMaxBackoffDuration(t *testing.T) {
|
||||||
|
|
||||||
func TestBackoffRetryForever(t *testing.T) {
|
func TestBackoffRetryForever(t *testing.T) {
|
||||||
// make backoff return immediately
|
// make backoff return immediately
|
||||||
timeAfter = immediateTimeAfter
|
Clock.After = immediateTimeAfter
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
backoff := BackoffHandler{MaxRetries: 3, RetryForever: true}
|
backoff := BackoffHandler{MaxRetries: 3, RetryForever: true}
|
||||||
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
|
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
|
|
@ -18,7 +18,7 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/origin"
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -36,7 +36,7 @@ type AppInfo struct {
|
||||||
|
|
||||||
type lock struct {
|
type lock struct {
|
||||||
lockFilePath string
|
lockFilePath string
|
||||||
backoff *origin.BackoffHandler
|
backoff *retry.BackoffHandler
|
||||||
sigHandler *signalHandler
|
sigHandler *signalHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ func newLock(path string) *lock {
|
||||||
lockPath := path + ".lock"
|
lockPath := path + ".lock"
|
||||||
return &lock{
|
return &lock{
|
||||||
lockFilePath: lockPath,
|
lockFilePath: lockPath,
|
||||||
backoff: &origin.BackoffHandler{MaxRetries: 7},
|
backoff: &retry.BackoffHandler{MaxRetries: 7},
|
||||||
sigHandler: &signalHandler{
|
sigHandler: &signalHandler{
|
||||||
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
|
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
|
||||||
},
|
},
|
||||||
|
|
|
@ -4,15 +4,11 @@ import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var stripWebsocketHeaders = []string{
|
var stripWebsocketHeaders = []string{
|
||||||
|
@ -47,80 +43,6 @@ func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Con
|
||||||
return conn, response, nil
|
return conn, response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartProxyServer will start a websocket server that will decode
|
|
||||||
// the websocket data and write the resulting data to the provided
|
|
||||||
func StartProxyServer(
|
|
||||||
log *zerolog.Logger,
|
|
||||||
listener net.Listener,
|
|
||||||
staticHost string,
|
|
||||||
shutdownC <-chan struct{},
|
|
||||||
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger),
|
|
||||||
) error {
|
|
||||||
upgrader := websocket.Upgrader{
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
}
|
|
||||||
h := handler{
|
|
||||||
upgrader: upgrader,
|
|
||||||
log: log,
|
|
||||||
staticHost: staticHost,
|
|
||||||
streamHandler: streamHandler,
|
|
||||||
}
|
|
||||||
|
|
||||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: &h}
|
|
||||||
go func() {
|
|
||||||
<-shutdownC
|
|
||||||
_ = httpServer.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return httpServer.Serve(listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTP handler for the websocket proxy.
|
|
||||||
type handler struct {
|
|
||||||
log *zerolog.Logger
|
|
||||||
staticHost string
|
|
||||||
upgrader websocket.Upgrader
|
|
||||||
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// If remote is an empty string, get the destination from the client.
|
|
||||||
finalDestination := h.staticHost
|
|
||||||
if finalDestination == "" {
|
|
||||||
if jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader); jumpDestination == "" {
|
|
||||||
h.log.Error().Msg("Did not receive final destination from client. The --destination flag is likely not set")
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
finalDestination = jumpDestination
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := net.Dial("tcp", finalDestination)
|
|
||||||
if err != nil {
|
|
||||||
h.log.Err(err).Msg("Cannot connect to remote")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
if !websocket.IsWebSocketUpgrade(r) {
|
|
||||||
_, _ = w.Write(nonWebSocketRequestPage())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conn, err := h.upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
h.log.Err(err).Msg("failed to upgrade")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
||||||
conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
|
|
||||||
gorillaConn := &GorillaConn{Conn: conn, log: h.log}
|
|
||||||
go gorillaConn.pinger(r.Context())
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
h.streamHandler(gorillaConn, stream, h.log)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
||||||
func NewResponseHeader(req *http.Request) http.Header {
|
func NewResponseHeader(req *http.Request) http.Header {
|
||||||
header := http.Header{}
|
header := http.Header{}
|
||||||
|
@ -174,3 +96,27 @@ func ChangeRequestScheme(reqURL *url.URL) string {
|
||||||
return reqURL.Scheme
|
return reqURL.Scheme
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stream copies copy data to & from provided io.ReadWriters.
|
||||||
|
func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
|
proxyDone := make(chan struct{}, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(conn, backendConn)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Msgf("conn to backendConn copy: %v", err)
|
||||||
|
}
|
||||||
|
proxyDone <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(backendConn, conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Msgf("backendConn to conn copy: %v", err)
|
||||||
|
}
|
||||||
|
proxyDone <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// If one side is done, we are done.
|
||||||
|
<-proxyDone
|
||||||
|
}
|
||||||
|
|
|
@ -151,41 +151,3 @@ func TestWebsocketWrapper(t *testing.T) {
|
||||||
require.Equal(t, n, 2)
|
require.Equal(t, n, 2)
|
||||||
require.Equal(t, "bc", string(buf[:n]))
|
require.Equal(t, "bc", string(buf[:n]))
|
||||||
}
|
}
|
||||||
|
|
||||||
// func TestStartProxyServer(t *testing.T) {
|
|
||||||
// var wg sync.WaitGroup
|
|
||||||
// remoteAddress := "localhost:1113"
|
|
||||||
// listenerAddress := "localhost:1112"
|
|
||||||
// message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
|
||||||
// logger := zerolog.Nop()
|
|
||||||
// shutdownC := make(chan struct{})
|
|
||||||
|
|
||||||
// listener, err := net.Listen("tcp", listenerAddress)
|
|
||||||
// assert.NoError(t, err)
|
|
||||||
// defer listener.Close()
|
|
||||||
|
|
||||||
// remoteListener, err := net.Listen("tcp", remoteAddress)
|
|
||||||
// assert.NoError(t, err)
|
|
||||||
// defer remoteListener.Close()
|
|
||||||
|
|
||||||
// wg.Add(1)
|
|
||||||
// go func() {
|
|
||||||
// defer wg.Done()
|
|
||||||
// conn, err := remoteListener.Accept()
|
|
||||||
// assert.NoError(t, err)
|
|
||||||
// buf := make([]byte, len(message))
|
|
||||||
// conn.Read(buf)
|
|
||||||
// assert.Equal(t, string(buf), message)
|
|
||||||
// }()
|
|
||||||
|
|
||||||
// go func() {
|
|
||||||
// StartProxyServer(logger, listener, remoteAddress, shutdownC)
|
|
||||||
// }()
|
|
||||||
|
|
||||||
// req := testRequest(t, fmt.Sprintf("http://%s/", listenerAddress), nil)
|
|
||||||
// conn, _, err := ClientConnect(req, nil)
|
|
||||||
// assert.NoError(t, err)
|
|
||||||
// err = conn.WriteMessage(1, []byte(message))
|
|
||||||
// assert.NoError(t, err)
|
|
||||||
// wg.Wait()
|
|
||||||
// }
|
|
||||||
|
|
Loading…
Reference in New Issue