TUN-3617: Separate service from client, and implement different client for http vs. tcp origins

- extracted ResponseWriter from proxyConnection
 - added bastion tests over websocket
 - removed HTTPResp()
 - added some docstrings
 - Renamed some ingress clients as proxies
 - renamed instances of client to proxy in connection and origin
 - Stream no longer takes a context and logger.Service
This commit is contained in:
cthuang 2020-12-09 21:46:53 +00:00 committed by Nuno Diegues
parent 5e2b43adb5
commit e2262085e5
23 changed files with 839 additions and 354 deletions

View File

@ -23,7 +23,7 @@ type Websocket struct {
}
type wsdialer struct {
conn *cfwebsocket.Conn
conn *cfwebsocket.GorillaConn
}
func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, error) {
@ -75,7 +75,7 @@ func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC
// createWebsocketStream will create a WebSocket connection to stream data over
// It also handles redirects from Access and will present that flow if
// the token is not present on the request
func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.Conn, error) {
func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.GorillaConn, error) {
req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
if err != nil {
return nil, err
@ -97,7 +97,7 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
return nil, err
}
return &cfwebsocket.Conn{Conn: wsConn}, nil
return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
}
// createAccessAuthenticatedStream will try load a token from storage and make

View File

@ -245,9 +245,9 @@ func prepareTunnelConfig(
edgeTLSConfigs[p] = edgeTLSConfig
}
originClient := origin.NewClient(ingressRules, tags, log)
originProxy := origin.NewOriginProxy(ingressRules, tags, log)
connectionConfig := &connection.Config{
OriginClient: originClient,
OriginProxy: originProxy,
GracePeriod: c.Duration("grace-period"),
ReplaceExisting: c.Bool("force"),
}

View File

@ -14,7 +14,7 @@ import (
const LogFieldConnIndex = "connIndex"
type Config struct {
OriginClient OriginClient
OriginProxy OriginProxy
GracePeriod time.Duration
ReplaceExisting bool
}
@ -50,12 +50,12 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool {
return c.Hostname == ""
}
type OriginClient interface {
type OriginProxy interface {
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
}
type ResponseWriter interface {
WriteRespHeaders(*http.Response) error
WriteRespHeaders(status int, header http.Header) error
WriteErrorResponse()
io.ReadWriter
}

View File

@ -19,8 +19,8 @@ const (
var (
testConfig = &Config{
OriginClient: &mockOriginClient{},
GracePeriod: time.Millisecond * 100,
OriginProxy: &mockOriginProxy{},
GracePeriod: time.Millisecond * 100,
}
log = zerolog.Nop()
testOriginURL = &url.URL{
@ -38,10 +38,10 @@ type testRequest struct {
isProxyError bool
}
type mockOriginClient struct {
type mockOriginProxy struct {
}
func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
if isWebsocket {
return wsEndpoint(w, r)
}
@ -74,7 +74,7 @@ func wsEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}
_ = w.WriteRespHeaders(resp)
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
clientReader := nowriter{r.Body}
go func() {
for {
@ -95,7 +95,7 @@ func originRespEndpoint(w ResponseWriter, status int, data []byte) {
resp := &http.Response{
StatusCode: status,
}
_ = w.WriteRespHeaders(resp)
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
_, _ = w.Write(data)
}

View File

@ -216,7 +216,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
return reqErr
}
err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
err := h.config.OriginProxy.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
if err != nil {
respWriter.WriteErrorResponse()
return err
@ -240,8 +240,8 @@ type h2muxRespWriter struct {
*h2mux.MuxedStream
}
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
headers := h2mux.H1ResponseToH2ResponseHeaders(status, header)
headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin})
return rp.WriteHeaders(headers)
}

View File

@ -115,9 +115,9 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else if isWebsocketUpgrade(r) {
respWriter.shouldFlush = true
stripWebsocketUpgradeHeader(r)
err = c.config.OriginClient.Proxy(respWriter, r, true)
err = c.config.OriginProxy.Proxy(respWriter, r, true)
} else {
err = c.config.OriginClient.Proxy(respWriter, r, false)
err = c.config.OriginProxy.Proxy(respWriter, r, false)
}
if err != nil {
@ -161,10 +161,10 @@ type http2RespWriter struct {
shouldFlush bool
}
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
dest := rp.w.Header()
userHeaders := make(http.Header, len(resp.Header))
for header, values := range resp.Header {
userHeaders := make(http.Header, len(header))
for header, values := range header {
// Since these are http2 headers, they're required to be lowercase
h2name := strings.ToLower(header)
for _, v := range values {
@ -184,13 +184,12 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
// Perform user header serialization and set them in the single header
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
status := resp.StatusCode
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
if status == http.StatusSwitchingProtocols {
status = http.StatusOK
}
rp.w.WriteHeader(status)
if IsServerSentEvent(resp.Header) {
if IsServerSentEvent(header) {
rp.shouldFlush = true
}
if rp.shouldFlush {

View File

@ -125,12 +125,12 @@ func IsWebsocketClientHeader(headerName string) bool {
headerName == "upgrade"
}
func H1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) {
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []Header) {
h2 = []Header{
{Name: ":status", Value: strconv.Itoa(h1.StatusCode)},
{Name: ":status", Value: strconv.Itoa(status)},
}
userHeaders := make(http.Header, len(h1.Header))
for header, values := range h1.Header {
userHeaders := make(http.Header, len(h1))
for header, values := range h1 {
h2name := strings.ToLower(header)
if h2name == "content-length" {
// This header has meaning in HTTP/2 and will be used by the edge,

View File

@ -579,7 +579,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
Header: mockHeaders,
}
headers := H1ResponseToH2ResponseHeaders(&mockResponse)
headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
serializedHeadersIndex := -1
for i, header := range headers {
@ -622,7 +622,7 @@ func TestHeaderSize(t *testing.T) {
Header: largeHeaders,
}
serializedHeaders := H1ResponseToH2ResponseHeaders(&mockResponse)
serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
assert.NoError(t, err)
for _, header := range serializedHeaders {
@ -669,6 +669,6 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = H1ResponseToH2ResponseHeaders(h1resp)
_ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header)
}
}

View File

@ -85,16 +85,24 @@ func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool) (Ingress, error) {
}
// Get a single origin service from the CLI/config.
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) {
if c.IsSet("hello-world") {
return new(helloWorld), nil
}
if c.IsSet("url") || c.IsSet(config.BastionFlag) {
if c.IsSet(config.BastionFlag) {
return newBridgeService(), nil
}
if c.IsSet("url") {
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
if err != nil {
return nil, errors.Wrap(err, "Error validating origin URL")
}
return &localService{URL: originURL, RootURL: originURL}, nil
if isHTTPService(originURL) {
return &httpService{
url: originURL,
}, nil
}
return newSingleTCPService(originURL), nil
}
if c.IsSet("unix-socket") {
path, err := config.ValidateUnixSocket(c)
@ -104,7 +112,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginServ
return &unixSocketPath{path: path}, nil
}
u, err := url.Parse("http://localhost:8080")
return &localService{URL: u, RootURL: u}, err
return &httpService{url: u}, err
}
// IsEmpty checks if there are any ingress rules.
@ -136,7 +144,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
rules := make([]Rule, len(ingress))
for i, r := range ingress {
cfg := setConfig(defaults, r.OriginRequest)
var service OriginService
var service originService
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
// No validation necessary for unix socket filepath services
@ -156,7 +164,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
// overwrite the localService.URL field when `start` is called. So,
// leave the URL field empty for now.
cfg.BastionMode = true
service = new(localService)
service = newBridgeService()
} else {
// Validate URL services
u, err := url.Parse(r.Service)
@ -171,8 +179,11 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
if u.Path != "" {
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
}
serviceURL := localService{URL: u}
service = &serviceURL
if isHTTPService(u) {
service = &httpService{url: u}
} else {
service = newSingleTCPService(u)
}
}
if err := validateHostname(r, i, len(ingress)); err != nil {
@ -241,3 +252,7 @@ func ParseIngress(conf *config.Configuration) (Ingress, error) {
}
return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest))
}
func isHTTPService(url *url.URL) bool {
return url.Scheme == "http" || url.Scheme == "https" || url.Scheme == "ws" || url.Scheme == "wss"
}

View File

@ -61,12 +61,12 @@ ingress:
want: []Rule{
{
Hostname: "tunnel1.example.com",
Service: &localService{URL: localhost8000},
Service: &httpService{url: localhost8000},
Config: defaultConfig,
},
{
Hostname: "*",
Service: &localService{URL: localhost8001},
Service: &httpService{url: localhost8001},
Config: defaultConfig,
},
},
@ -82,7 +82,22 @@ extraKey: extraValue
want: []Rule{
{
Hostname: "*",
Service: &localService{URL: localhost8000},
Service: &httpService{url: localhost8000},
Config: defaultConfig,
},
},
},
{
name: "ws service",
args: args{rawYAML: `
ingress:
- hostname: "*"
service: wss://localhost:8000
`},
want: []Rule{
{
Hostname: "*",
Service: &httpService{url: MustParseURL(t, "wss://localhost:8000")},
Config: defaultConfig,
},
},
@ -95,7 +110,7 @@ ingress:
`},
want: []Rule{
{
Service: &localService{URL: localhost8000},
Service: &httpService{url: localhost8000},
Config: defaultConfig,
},
},
@ -209,6 +224,85 @@ ingress:
},
},
},
{
name: "TCP services",
args: args{rawYAML: `
ingress:
- hostname: tcp.foo.com
service: tcp://127.0.0.1
- hostname: tcp2.foo.com
service: tcp://localhost:8000
- service: http_status:404
`},
want: []Rule{
{
Hostname: "tcp.foo.com",
Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")),
Config: defaultConfig,
},
{
Hostname: "tcp2.foo.com",
Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")),
Config: defaultConfig,
},
{
Service: &fourOhFour,
Config: defaultConfig,
},
},
},
{
name: "SSH services",
args: args{rawYAML: `
ingress:
- service: ssh://127.0.0.1
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")),
Config: defaultConfig,
},
},
},
{
name: "RDP services",
args: args{rawYAML: `
ingress:
- service: rdp://127.0.0.1
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")),
Config: defaultConfig,
},
},
},
{
name: "SMB services",
args: args{rawYAML: `
ingress:
- service: smb://127.0.0.1
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")),
Config: defaultConfig,
},
},
},
{
name: "Other TCP services",
args: args{rawYAML: `
ingress:
- service: ftp://127.0.0.1
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")),
Config: defaultConfig,
},
},
},
{
name: "URL isn't necessary if using bastion",
args: args{rawYAML: `
@ -221,7 +315,7 @@ ingress:
want: []Rule{
{
Hostname: "bastion.foo.com",
Service: &localService{},
Service: newBridgeService(),
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
},
{
@ -241,7 +335,7 @@ ingress:
want: []Rule{
{
Hostname: "bastion.foo.com",
Service: &localService{},
Service: newBridgeService(),
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
},
{
@ -409,6 +503,37 @@ func TestFindMatchingRule(t *testing.T) {
}
}
func TestIsHTTPService(t *testing.T) {
tests := []struct {
url *url.URL
isHTTP bool
}{
{
url: MustParseURL(t, "http://localhost"),
isHTTP: true,
},
{
url: MustParseURL(t, "https://127.0.0.1:8000"),
isHTTP: true,
},
{
url: MustParseURL(t, "ws://localhost"),
isHTTP: true,
},
{
url: MustParseURL(t, "wss://localhost:8000"),
isHTTP: true,
},
{
url: MustParseURL(t, "tcp://localhost:9000"),
isHTTP: false,
},
}
for _, test := range tests {
assert.Equal(t, test.isHTTP, isHTTPService(test.url))
}
}
func mustParsePath(t *testing.T, path string) *regexp.Regexp {
regexp, err := regexp.Compile(path)
assert.NoError(t, err)

View File

@ -0,0 +1,62 @@
package ingress
import (
"io"
"net"
"net/http"
"github.com/cloudflare/cloudflared/websocket"
gws "github.com/gorilla/websocket"
)
// OriginConnection is a way to stream to a service running on the user's origin.
// Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
type OriginConnection interface {
// Stream should generally be implemented as a bidirectional io.Copy.
Stream(tunnelConn io.ReadWriter)
Close()
}
// tcpConnection is an OriginConnection that directly streams to raw TCP.
type tcpConnection struct {
conn net.Conn
streamHandler func(tunnelConn io.ReadWriter, originConn net.Conn)
}
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter) {
tc.streamHandler(tunnelConn, tc.conn)
}
func (tc *tcpConnection) Close() {
tc.conn.Close()
}
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
type wsConnection struct {
wsConn *gws.Conn
resp *http.Response
}
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter) {
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn())
}
func (wsc *wsConnection) Close() {
wsc.resp.Body.Close()
wsc.wsConn.Close()
}
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
d := &gws.Dialer{
TLSClientConfig: transport.TLSClientConfig,
}
wsConn, resp, err := websocket.ClientConnect(r, d)
if err != nil {
return nil, err
}
return &wsConnection{
wsConn,
resp,
}, nil
}

100
ingress/origin_proxy.go Normal file
View File

@ -0,0 +1,100 @@
package ingress
import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/h2mux"
)
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
type HTTPOriginProxy interface {
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
http.RoundTripper
}
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, error)
}
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
// TODO: TUN-3636: establish connection to origins over UDS
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) {
return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
}
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.url.Host
req.URL.Scheme = o.url.Scheme
return o.transport.RoundTrip(req)
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the Hello World server.
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "https"
return o.transport.RoundTrip(req)
}
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) {
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "wss"
return newWSConnection(o.transport, req)
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) {
dest, err := o.destination(r)
if err != nil {
return nil, err
}
return o.client.connect(r, dest)
}
func (o *bridgeService) destination(r *http.Request) (string, error) {
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
if jumpDestination == "" {
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
}
// Strip scheme and path set by client. Without a scheme
// Parsing a hostname and path without scheme might not return an error due to parsing ambiguities
if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" {
return removePath(jumpURL.Host), nil
}
return removePath(jumpDestination), nil
}
func removePath(dest string) string {
return strings.SplitN(dest, "/", 2)[0]
}
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) {
return o.client.connect(r, o.dest)
}
type tcpClient struct {
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
}
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &tcpConnection{
conn: conn,
streamHandler: c.streamHandler,
}, nil
}

View File

@ -0,0 +1,107 @@
package ingress
import (
"net/http"
"testing"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/stretchr/testify/assert"
)
func TestBridgeServiceDestination(t *testing.T) {
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
tests := []struct {
name string
header http.Header
expectedDest string
wantErr bool
}{
{
name: "hostname destination",
header: http.Header{
canonicalJumpDestHeader: []string{"localhost"},
},
expectedDest: "localhost",
},
{
name: "hostname destination with port",
header: http.Header{
canonicalJumpDestHeader: []string{"localhost:9000"},
},
expectedDest: "localhost:9000",
},
{
name: "hostname destination with scheme and port",
header: http.Header{
canonicalJumpDestHeader: []string{"ssh://localhost:9000"},
},
expectedDest: "localhost:9000",
},
{
name: "full hostname url",
header: http.Header{
canonicalJumpDestHeader: []string{"ssh://localhost:9000/metrics"},
},
expectedDest: "localhost:9000",
},
{
name: "hostname destination with port and path",
header: http.Header{
canonicalJumpDestHeader: []string{"localhost:9000/metrics"},
},
expectedDest: "localhost:9000",
},
{
name: "ip destination",
header: http.Header{
canonicalJumpDestHeader: []string{"127.0.0.1"},
},
expectedDest: "127.0.0.1",
},
{
name: "ip destination with port",
header: http.Header{
canonicalJumpDestHeader: []string{"127.0.0.1:9000"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "ip destination with port and path",
header: http.Header{
canonicalJumpDestHeader: []string{"127.0.0.1:9000/metrics"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "ip destination with schem and port",
header: http.Header{
canonicalJumpDestHeader: []string{"tcp://127.0.0.1:9000"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "full ip url",
header: http.Header{
canonicalJumpDestHeader: []string{"ssh://127.0.0.1:9000/metrics"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "no destination",
wantErr: true,
},
}
s := newBridgeService()
for _, test := range tests {
r := &http.Request{
Header: test.header,
}
dest, err := s.destination(r)
if test.wantErr {
assert.Error(t, err, "Test %s expects error", test.name)
} else {
assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err)
assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest)
}
}
}

View File

@ -8,7 +8,6 @@ import (
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
@ -21,10 +20,8 @@ import (
"github.com/rs/zerolog"
)
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
http.RoundTripper
// originService is something a tunnel can proxy traffic to.
type originService interface {
String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
@ -51,10 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
return nil
}
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
NetDial: o.transport.Dial,
@ -65,130 +58,87 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
return d.Dial(reqURL.String(), headers)
}
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
type localService struct {
// The URL for the user's origin service
RootURL *url.URL
// The URL that cloudflared should send requests to.
// If this origin requires starting a proxy, this is the proxy's address,
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
URL *url.URL
type httpService struct {
url *url.URL
transport *http.Transport
}
func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
// Rewrite the request URL so that it goes to the origin service.
reqURL.Host = o.URL.Host
reqURL.Scheme = websocket.ChangeRequestScheme(o.URL)
return d.Dial(reqURL.String(), headers)
}
func (o *localService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg, log)
if err != nil {
return err
}
o.transport = transport
// Start a proxy if one is needed
if staticHost := o.staticHost(); originRequiresProxy(staticHost, cfg) {
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
return err
}
}
return nil
}
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (o *httpService) String() string {
return o.url.String()
}
// Start a listener for the proxy
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
listener, err := net.Listen("tcp", proxyAddress)
if err != nil {
log.Error().Msgf("Cannot start Websocket Proxy Server: %s", err)
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
// bridgeService is like a jump host, the destination is specified by the client
type bridgeService struct {
client *tcpClient
}
func newBridgeService() *bridgeService {
return &bridgeService{
client: &tcpClient{},
}
}
// Start the proxy itself
wg.Add(1)
go func() {
defer wg.Done()
streamHandler := websocket.DefaultStreamHandler
// This origin's config specifies what type of proxy to start.
switch cfg.ProxyType {
case socksProxy:
log.Info().Msg("SOCKS5 server started")
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
dialer := socks.NewConnDialer(remoteConn)
requestHandler := socks.NewRequestHandler(dialer)
socksServer := socks.NewConnectionHandler(requestHandler)
func (o *bridgeService) String() string {
return "bridge service"
}
_ = socksServer.Serve(wsConn)
}
case "":
log.Debug().Msg("Not starting any websocket proxy")
default:
log.Error().Msgf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy)
}
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
}()
// Modify this origin, so that it no longer points at the origin service directly.
// Instead, it points at the proxy to the origin service.
newURL, err := url.Parse("http://" + listener.Addr().String())
if err != nil {
return err
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler
} else {
o.client.streamHandler = websocket.DefaultStreamHandler
}
o.URL = newURL
return nil
}
func (o *localService) String() string {
if o.isBastion() {
return "Bastion"
}
return o.URL.String()
type singleTCPService struct {
dest string
client *tcpClient
}
func (o *localService) isBastion() bool {
return o.URL == nil
}
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.URL.Host
req.URL.Scheme = o.URL.Scheme
return o.transport.RoundTrip(req)
}
func (o *localService) staticHost() string {
if o.URL == nil {
return ""
}
addPortIfMissing := func(uri *url.URL, port int) string {
if uri.Port() != "" {
return uri.Host
}
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
switch o.URL.Scheme {
func newSingleTCPService(url *url.URL) *singleTCPService {
switch url.Scheme {
case "ssh":
return addPortIfMissing(o.URL, 22)
addPortIfMissing(url, 22)
case "rdp":
return addPortIfMissing(o.URL, 3389)
addPortIfMissing(url, 3389)
case "smb":
return addPortIfMissing(o.URL, 445)
addPortIfMissing(url, 445)
case "tcp":
return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
}
return ""
return &singleTCPService{
dest: url.Host,
client: &tcpClient{},
}
}
func addPortIfMissing(uri *url.URL, port int) {
if uri.Port() == "" {
uri.Host = fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
}
func (o *singleTCPService) String() string {
return o.dest
}
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler
} else {
o.client.streamHandler = websocket.DefaultStreamHandler
}
return nil
}
// HelloWorld is an OriginService for the built-in Hello World server.
@ -228,26 +178,6 @@ func (o *helloWorld) start(
return nil
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the Hello World server.
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "https"
return o.transport.RoundTrip(req)
}
func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Host = o.server.Addr().String()
reqURL.Scheme = "wss"
return d.Dial(reqURL.String(), headers)
}
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
return staticHost != "" || cfg.BastionMode
}
// statusCode is an OriginService that just responds with a given HTTP status.
// Typical use-case is "user wants the catch-all rule to just respond 404".
type statusCode struct {
@ -277,10 +207,6 @@ func (o *statusCode) start(
return nil
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
type NopReadCloser struct{}
// Read always returns EOF to signal end of input
@ -292,7 +218,7 @@ func (nrc *NopReadCloser) Close() error {
return nil
}
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
if err != nil {
return nil, errors.Wrap(err, "Error loading cert pool")
@ -337,19 +263,19 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerol
return &httpTransport, nil
}
// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
type MockOriginService struct {
// MockOriginHTTPService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
type MockOriginHTTPService struct {
Transport http.RoundTripper
}
func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) {
func (mos MockOriginHTTPService) RoundTrip(req *http.Request) (*http.Response, error) {
return mos.Transport.RoundTrip(req)
}
func (mos MockOriginService) String() string {
func (mos MockOriginHTTPService) String() string {
return "MockOriginService"
}
func (mos MockOriginService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
return nil
}

View File

@ -17,7 +17,7 @@ type Rule struct {
// A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this
// address.
Service OriginService
Service originService
// Configure the request cloudflared sends to this specific origin.
Config OriginRequestConfig

View File

@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
type fields struct {
Hostname string
Path *regexp.Regexp
Service OriginService
Service originService
}
type args struct {
requestURL *url.URL

0
origin/cloudflared.log Normal file
View File

View File

@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
@ -15,7 +14,6 @@ import (
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
@ -24,15 +22,15 @@ const (
TagHeaderNamePrefix = "Cf-Warp-Tag-"
)
type client struct {
type proxy struct {
ingressRules ingress.Ingress
tags []tunnelpogs.Tag
log *zerolog.Logger
bufferPool *buffer.Pool
}
func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog.Logger) connection.OriginClient {
return &client{
func NewOriginProxy(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog.Logger) connection.OriginProxy {
return &proxy{
ingressRules: ingressRules,
tags: tags,
log: log,
@ -40,36 +38,55 @@ func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog
}
}
func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
incrementRequests()
defer decrementConcurrentRequests()
cfRay := findCfRayHeader(req)
lbProbe := isLBProbeRequest(req)
c.appendTagHeaders(req)
rule, ruleNum := c.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
c.logRequest(req, cfRay, lbProbe, ruleNum)
p.appendTagHeaders(req)
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
p.logRequest(req, cfRay, lbProbe, ruleNum)
var (
resp *http.Response
err error
)
if isWebsocket {
resp, err = c.proxyWebsocket(w, req, rule)
go websocket.NewConn(w, p.log).Pinger(req.Context())
connClosedChan := make(chan struct{})
err = p.proxyConnection(connClosedChan, w, req, rule)
if err == nil {
respHeader := websocket.NewResponseHeader(req)
status := http.StatusSwitchingProtocols
resp = &http.Response{
Status: http.StatusText(status),
StatusCode: status,
Header: respHeader,
ContentLength: -1,
}
w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader)
<-connClosedChan
}
} else {
resp, err = c.proxyHTTP(w, req, rule)
resp, err = p.proxyHTTP(w, req, rule)
}
if err != nil {
c.logRequestError(err, cfRay, ruleNum)
p.logRequestError(err, cfRay, ruleNum)
w.WriteErrorResponse()
return err
}
c.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
return nil
}
func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
@ -87,73 +104,69 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule
req.Host = hostHeader
}
resp, err := rule.Service.RoundTrip(req)
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
if !ok {
p.log.Error().Msgf("%s is not a http service", rule.Service)
return nil, fmt.Errorf("Not a http service")
}
resp, err := httpService.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin")
}
defer resp.Body.Close()
err = w.WriteRespHeaders(resp)
err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
if connection.IsServerSentEvent(resp.Header) {
c.log.Debug().Msg("Detected Server-Side Events from Origin")
c.writeEventStream(w, resp.Body)
p.log.Debug().Msg("Detected Server-Side Events from Origin")
p.writeEventStream(w, resp.Body)
} else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
buf := c.bufferPool.Get()
defer c.bufferPool.Put(buf)
buf := p.bufferPool.Get()
defer p.bufferPool.Put(buf)
_, _ = io.CopyBuffer(w, resp.Body, buf)
}
return resp, nil
}
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
func (p *proxy) proxyConnection(connClosedChan chan struct{},
conn io.ReadWriter, req *http.Request, rule *ingress.Rule) error {
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", hostHeader)
req.Host = hostHeader
}
dialler, ok := rule.Service.(websocket.Dialler)
connectionService, ok := rule.Service.(ingress.StreamBasedOriginProxy)
if !ok {
return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service)
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
return fmt.Errorf("Not a connection-oriented service")
}
conn, resp, err := websocket.ClientConnect(req, dialler)
originConn, err := connectionService.EstablishConnection(req)
if err != nil {
return nil, err
return err
}
serveCtx, cancel := context.WithCancel(req.Context())
connClosedChan := make(chan struct{})
go func() {
// serveCtx is done if req is cancelled, or streamWebsocket returns
<-serveCtx.Done()
_ = conn.Close()
originConn.Close()
close(connClosedChan)
}()
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
err = c.streamWebsocket(w, conn.UnderlyingConn(), resp)
cancel()
go func() {
originConn.Stream(conn)
cancel()
}()
// We need to make sure conn is closed before returning, otherwise we might write to conn after Proxy returns
<-connClosedChan
return resp, err
}
func (c *client) streamWebsocket(w connection.ResponseWriter, conn net.Conn, resp *http.Response) error {
err := w.WriteRespHeaders(resp)
if err != nil {
return errors.Wrap(err, "Error writing websocket response header")
}
websocket.Stream(conn, w)
return nil
}
func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
reader := bufio.NewReader(respBody)
for {
line, err := reader.ReadBytes('\n')
@ -164,54 +177,54 @@ func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadC
}
}
func (c *client) appendTagHeaders(r *http.Request) {
for _, tag := range c.tags {
func (p *proxy) appendTagHeaders(r *http.Request) {
for _, tag := range p.tags {
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
}
}
func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) {
func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) {
if cfRay != "" {
c.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else if lbProbe {
c.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else {
c.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
}
c.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
c.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum)
p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum)
if contentLen := r.ContentLength; contentLen == -1 {
c.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay)
p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay)
} else {
c.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen)
p.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen)
}
}
func (c *client) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) {
func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) {
responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc()
if cfRay != "" {
c.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum)
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum)
} else if lbProbe {
c.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status)
p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status)
} else {
c.log.Debug().Msgf("Status: %s served by ingress %d", r.Status, ruleNum)
p.log.Debug().Msgf("Status: %s served by ingress %d", r.Status, ruleNum)
}
c.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
c.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay)
p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay)
} else {
c.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen)
p.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen)
}
}
func (c *client) logRequestError(err error, cfRay string, ruleNum int) {
func (p *proxy) logRequestError(err error, cfRay string, ruleNum int) {
requestErrors.Inc()
if cfRay != "" {
c.log.Error().Msgf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err)
p.log.Error().Msgf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err)
} else {
c.log.Error().Msgf("Proxying to ingress %d error: %v", ruleNum, err)
p.log.Error().Msgf("Proxying to ingress %d error: %v", ruleNum, err)
}
}

View File

@ -5,7 +5,9 @@ import (
"context"
"flag"
"fmt"
"github.com/cloudflare/cloudflared/logger"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
@ -14,9 +16,11 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/urfave/cli/v2"
"github.com/gobwas/ws/wsutil"
@ -39,9 +43,9 @@ func newMockHTTPRespWriter() *mockHTTPRespWriter {
}
}
func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
w.WriteHeader(resp.StatusCode)
for header, val := range resp.Header {
func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error {
w.WriteHeader(status)
for header, val := range header {
w.Header()[header] = val
}
return nil
@ -125,28 +129,28 @@ func TestProxySingleOrigin(t *testing.T) {
errC := make(chan error)
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
client := NewClient(ingressRule, testTags, &log)
t.Run("testProxyHTTP", testProxyHTTP(t, client))
t.Run("testProxyWebsocket", testProxyWebsocket(t, client))
t.Run("testProxySSE", testProxySSE(t, client))
proxy := NewOriginProxy(ingressRule, testTags, &log)
t.Run("testProxyHTTP", testProxyHTTP(t, proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy))
t.Run("testProxySSE", testProxySSE(t, proxy))
cancel()
wg.Wait()
}
func testProxyHTTP(t *testing.T, client connection.OriginClient) func(t *testing.T) {
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
respWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err)
err = client.Proxy(respWriter, req, false)
err = proxy.Proxy(respWriter, req, false)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, respWriter.Code)
}
}
func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *testing.T) {
func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
// WSRoute is a websocket echo handler
ctx, cancel := context.WithCancel(context.Background())
@ -159,7 +163,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te
wg.Add(1)
go func() {
defer wg.Done()
err = client.Proxy(respWriter, req, true)
err = proxy.Proxy(respWriter, req, true)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
@ -169,7 +173,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te
err = wsutil.WriteClientText(writePipe, msg)
require.NoError(t, err)
// ReadServerText reads next data message from rw, considering that caller represents client side.
// ReadServerText reads next data message from rw, considering that caller represents proxy side.
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
require.NoError(t, err)
require.Equal(t, msg, returnedMsg)
@ -186,7 +190,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te
}
}
func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.T) {
func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
var (
pushCount = 50
@ -201,7 +205,7 @@ func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.
wg.Add(1)
go func() {
defer wg.Done()
err = client.Proxy(respWriter, req, false)
err = proxy.Proxy(respWriter, req, false)
require.NoError(t, err)
require.Equal(t, http.StatusOK, respWriter.Code)
@ -258,7 +262,7 @@ func TestProxyMultipleOrigins(t *testing.T) {
var wg sync.WaitGroup
require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC))
client := NewClient(ingress, testTags, &log)
proxy := NewOriginProxy(ingress, testTags, &log)
tests := []struct {
url string
@ -294,7 +298,7 @@ func TestProxyMultipleOrigins(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, test.url, nil)
require.NoError(t, err)
err = client.Proxy(respWriter, req, false)
err = proxy.Proxy(respWriter, req, false)
require.NoError(t, err)
assert.Equal(t, test.expectedStatus, respWriter.Code)
@ -327,7 +331,7 @@ func TestProxyError(t *testing.T) {
{
Hostname: "*",
Path: nil,
Service: ingress.MockOriginService{
Service: ingress.MockOriginHTTPService{
Transport: errorOriginTransport{},
},
},
@ -336,14 +340,85 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop()
client := NewClient(ingress, testTags, &log)
proxy := NewOriginProxy(ingress, testTags, &log)
respWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err)
err = client.Proxy(respWriter, req, false)
err = proxy.Proxy(respWriter, req, false)
assert.Error(t, err)
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
assert.Equal(t, "http response error", respWriter.Body.String())
}
func TestProxyBastionMode(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
flagSet.Bool("bastion", true, "")
cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
err := cliCtx.Set(config.BastionFlag, "true")
require.NoError(t, err)
allowURLFromArgs := false
ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs)
require.NoError(t, err)
var wg sync.WaitGroup
errC := make(chan error)
log := logger.Create(nil)
ingressRule.StartOrigins(&wg, log, ctx.Done(), errC)
proxy := NewOriginProxy(ingressRule, testTags, log)
t.Run("testBastionWebsocket", testBastionWebsocket(proxy))
cancel()
}
func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
// WSRoute is a websocket echo handler
ctx, cancel := context.WithCancel(context.Background())
readPipe, _ := io.Pipe()
respWriter := newMockWSRespWriter(readPipe)
var wg sync.WaitGroup
msgFromConn := []byte("data from websocket proxy")
ln, err := net.Listen("tcp", "127.0.0.1:0")
wg.Add(1)
go func() {
defer wg.Done()
defer ln.Close()
server, err := ln.Accept()
require.NoError(t, err)
conn := websocket.NewConn(server, nil)
conn.Write(msgFromConn)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil)
req.Header.Set(h2mux.CFJumpDestinationHeader, ln.Addr().String())
wg.Add(1)
go func() {
defer wg.Done()
err = proxy.Proxy(respWriter, req, true)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
}()
// ReadServerText reads next data message from rw, considering that caller represents proxy side.
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
if err != io.EOF {
require.NoError(t, err)
require.Equal(t, msgFromConn, returnedMsg)
}
cancel()
wg.Wait()
}
}

View File

@ -3,6 +3,7 @@ package socks
import (
"fmt"
"io"
"net"
"strings"
)
@ -104,3 +105,11 @@ func (h *StandardRequestHandler) handleAssociate(conn io.ReadWriter, req *Reques
}
return nil
}
func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn) {
dialer := NewConnDialer(originConn)
requestHandler := NewRequestHandler(dialer)
socksServer := NewConnectionHandler(requestHandler)
socksServer.Serve(tunnelConn)
}

118
websocket/connection.go Normal file
View File

@ -0,0 +1,118 @@
package websocket
import (
"context"
"github.com/rs/zerolog"
"io"
"time"
gobwas "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
)
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
// This is still used by access carrier
type GorillaConn struct {
*websocket.Conn
log *zerolog.Logger
}
// Read will read messages from the websocket connection
func (c *GorillaConn) Read(p []byte) (int, error) {
_, message, err := c.Conn.ReadMessage()
if err != nil {
return 0, err
}
return copy(p, message), nil
}
// Write will write messages to the websocket connection
func (c *GorillaConn) Write(p []byte) (int, error) {
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
return 0, err
}
return len(p), nil
}
// pinger simulates the websocket connection to keep it alive
func (c *GorillaConn) pinger(ctx context.Context) {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
c.log.Debug().Msgf("failed to send ping message: %s", err)
}
case <-ctx.Done():
return
}
}
}
type Conn struct {
rw io.ReadWriter
log *zerolog.Logger
}
func NewConn(rw io.ReadWriter, log *zerolog.Logger) *Conn {
return &Conn{
rw: rw,
log: log,
}
}
// Read will read messages from the websocket connection
func (c *Conn) Read(reader []byte) (int, error) {
data, err := wsutil.ReadClientBinary(c.rw)
if err != nil {
return 0, err
}
return copy(reader, data), nil
}
// Write will write messages to the websocket connection
func (c *Conn) Write(p []byte) (int, error) {
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
return 0, err
}
return len(p), nil
}
func (c *Conn) Pinger(ctx context.Context) {
pongMessge := wsutil.Message{
OpCode: gobwas.OpPong,
Payload: []byte{},
}
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
c.log.Err(err).Msgf("failed to write ping message")
}
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
c.log.Err(err).Msgf("failed to write pong message")
}
case <-ctx.Done():
return
}
}
}

View File

@ -2,7 +2,6 @@ package websocket
import (
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"io"
"net"
@ -16,17 +15,6 @@ import (
"github.com/rs/zerolog"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
)
var stripWebsocketHeaders = []string{
"Upgrade",
"Connection",
@ -35,70 +23,28 @@ var stripWebsocketHeaders = []string{
"Sec-Websocket-Extensions",
}
// Conn is a wrapper around the standard gorilla websocket
// but implements a ReadWriter
type Conn struct {
*websocket.Conn
}
// Read will read messages from the websocket connection
func (c *Conn) Read(p []byte) (int, error) {
_, message, err := c.Conn.ReadMessage()
if err != nil {
return 0, err
}
return copy(p, message), nil
}
// Write will write messages to the websocket connection
func (c *Conn) Write(p []byte) (int, error) {
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
return 0, err
}
return len(p), nil
}
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
func IsWebSocketUpgrade(req *http.Request) bool {
return websocket.IsWebSocketUpgrade(req)
}
// Dialler is something that can proxy websocket requests.
type Dialler interface {
Dial(url *url.URL, headers http.Header) (*websocket.Conn, *http.Response, error)
}
type defaultDialler struct {
tlsConfig *tls.Config
}
func (dd *defaultDialler) Dial(url *url.URL, header http.Header) (*websocket.Conn, *http.Response, error) {
d := &websocket.Dialer{
TLSClientConfig: dd.tlsConfig,
Proxy: http.ProxyFromEnvironment,
}
return d.Dial(url.String(), header)
}
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
// the connection. The response body may not contain the entire response and does
// not need to be closed by the application.
func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) {
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = ChangeRequestScheme(req.URL)
wsHeaders := websocketHeaders(req)
if dialler == nil {
dialler = new(defaultDialler)
dialler = &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
conn, response, err := dialler.Dial(req.URL, wsHeaders)
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
if err != nil {
return nil, response, err
}
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
return conn, response, err
return conn, response, nil
}
// Stream copies copy data to & from provided io.ReadWriters.
@ -121,8 +67,8 @@ func Stream(conn, backendConn io.ReadWriter) {
// DefaultStreamHandler is provided to the the standard websocket to origin stream
// This exist to allow SOCKS to deframe data before it gets to the origin
func DefaultStreamHandler(wsConn *Conn, remoteConn net.Conn, _ http.Header) {
Stream(wsConn, remoteConn)
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) {
Stream(originConn, remoteConn)
}
// StartProxyServer will start a websocket server that will decode
@ -132,7 +78,7 @@ func StartProxyServer(
listener net.Listener,
staticHost string,
shutdownC <-chan struct{},
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header),
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn),
) error {
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
@ -159,7 +105,7 @@ type handler struct {
log *zerolog.Logger
staticHost string
upgrader websocket.Upgrader
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header)
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -192,14 +138,20 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
done := make(chan struct{})
go pinger(h.log, conn, done)
defer func() {
done <- struct{}{}
_ = conn.Close()
}()
gorillaConn := &GorillaConn{conn, h.log}
go gorillaConn.pinger(r.Context())
defer conn.Close()
h.streamHandler(&Conn{conn}, stream, r.Header)
h.streamHandler(gorillaConn, stream)
}
// NewResponseHeader returns headers needed to return to origin for completing handshake
func NewResponseHeader(req *http.Request) http.Header {
header := http.Header{}
header.Add("Connection", "Upgrade")
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
header.Add("Upgrade", "websocket")
return header
}
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
@ -246,19 +198,3 @@ func ChangeRequestScheme(reqURL *url.URL) string {
return reqURL.Scheme
}
}
// pinger simulates the websocket connection to keep it alive
func pinger(logger *zerolog.Logger, ws *websocket.Conn, done chan struct{}) {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
logger.Debug().Msgf("failed to send ping message: %s", err)
}
case <-done:
return
}
}
}

View File

@ -11,7 +11,7 @@ import (
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/tlsconfig"
gws "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)
@ -78,7 +78,7 @@ func TestServe(t *testing.T) {
tlsConfig := websocketClientTLSConfig(t)
assert.NotNil(t, tlsConfig)
d := defaultDialler{tlsConfig: tlsConfig}
d := gws.Dialer{TLSClientConfig: tlsConfig}
conn, resp, err := ClientConnect(req, &d)
assert.NoError(t, err)
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))