TUN-3617: Separate service from client, and implement different client for http vs. tcp origins
- extracted ResponseWriter from proxyConnection - added bastion tests over websocket - removed HTTPResp() - added some docstrings - Renamed some ingress clients as proxies - renamed instances of client to proxy in connection and origin - Stream no longer takes a context and logger.Service
This commit is contained in:
parent
5e2b43adb5
commit
e2262085e5
|
@ -23,7 +23,7 @@ type Websocket struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type wsdialer struct {
|
type wsdialer struct {
|
||||||
conn *cfwebsocket.Conn
|
conn *cfwebsocket.GorillaConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, error) {
|
func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, error) {
|
||||||
|
@ -75,7 +75,7 @@ func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC
|
||||||
// createWebsocketStream will create a WebSocket connection to stream data over
|
// createWebsocketStream will create a WebSocket connection to stream data over
|
||||||
// It also handles redirects from Access and will present that flow if
|
// It also handles redirects from Access and will present that flow if
|
||||||
// the token is not present on the request
|
// the token is not present on the request
|
||||||
func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.Conn, error) {
|
func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.GorillaConn, error) {
|
||||||
req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
|
req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -97,7 +97,7 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &cfwebsocket.Conn{Conn: wsConn}, nil
|
return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createAccessAuthenticatedStream will try load a token from storage and make
|
// createAccessAuthenticatedStream will try load a token from storage and make
|
||||||
|
|
|
@ -245,9 +245,9 @@ func prepareTunnelConfig(
|
||||||
edgeTLSConfigs[p] = edgeTLSConfig
|
edgeTLSConfigs[p] = edgeTLSConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
originClient := origin.NewClient(ingressRules, tags, log)
|
originProxy := origin.NewOriginProxy(ingressRules, tags, log)
|
||||||
connectionConfig := &connection.Config{
|
connectionConfig := &connection.Config{
|
||||||
OriginClient: originClient,
|
OriginProxy: originProxy,
|
||||||
GracePeriod: c.Duration("grace-period"),
|
GracePeriod: c.Duration("grace-period"),
|
||||||
ReplaceExisting: c.Bool("force"),
|
ReplaceExisting: c.Bool("force"),
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
const LogFieldConnIndex = "connIndex"
|
const LogFieldConnIndex = "connIndex"
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
OriginClient OriginClient
|
OriginProxy OriginProxy
|
||||||
GracePeriod time.Duration
|
GracePeriod time.Duration
|
||||||
ReplaceExisting bool
|
ReplaceExisting bool
|
||||||
}
|
}
|
||||||
|
@ -50,12 +50,12 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool {
|
||||||
return c.Hostname == ""
|
return c.Hostname == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type OriginClient interface {
|
type OriginProxy interface {
|
||||||
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseWriter interface {
|
type ResponseWriter interface {
|
||||||
WriteRespHeaders(*http.Response) error
|
WriteRespHeaders(status int, header http.Header) error
|
||||||
WriteErrorResponse()
|
WriteErrorResponse()
|
||||||
io.ReadWriter
|
io.ReadWriter
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,8 +19,8 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testConfig = &Config{
|
testConfig = &Config{
|
||||||
OriginClient: &mockOriginClient{},
|
OriginProxy: &mockOriginProxy{},
|
||||||
GracePeriod: time.Millisecond * 100,
|
GracePeriod: time.Millisecond * 100,
|
||||||
}
|
}
|
||||||
log = zerolog.Nop()
|
log = zerolog.Nop()
|
||||||
testOriginURL = &url.URL{
|
testOriginURL = &url.URL{
|
||||||
|
@ -38,10 +38,10 @@ type testRequest struct {
|
||||||
isProxyError bool
|
isProxyError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockOriginClient struct {
|
type mockOriginProxy struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
|
func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
|
||||||
if isWebsocket {
|
if isWebsocket {
|
||||||
return wsEndpoint(w, r)
|
return wsEndpoint(w, r)
|
||||||
}
|
}
|
||||||
|
@ -74,7 +74,7 @@ func wsEndpoint(w ResponseWriter, r *http.Request) error {
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
StatusCode: http.StatusSwitchingProtocols,
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
}
|
}
|
||||||
_ = w.WriteRespHeaders(resp)
|
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
||||||
clientReader := nowriter{r.Body}
|
clientReader := nowriter{r.Body}
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
@ -95,7 +95,7 @@ func originRespEndpoint(w ResponseWriter, status int, data []byte) {
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
StatusCode: status,
|
StatusCode: status,
|
||||||
}
|
}
|
||||||
_ = w.WriteRespHeaders(resp)
|
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
||||||
_, _ = w.Write(data)
|
_, _ = w.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -216,7 +216,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
return reqErr
|
return reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
|
err := h.config.OriginProxy.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respWriter.WriteErrorResponse()
|
respWriter.WriteErrorResponse()
|
||||||
return err
|
return err
|
||||||
|
@ -240,8 +240,8 @@ type h2muxRespWriter struct {
|
||||||
*h2mux.MuxedStream
|
*h2mux.MuxedStream
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
|
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
|
headers := h2mux.H1ResponseToH2ResponseHeaders(status, header)
|
||||||
headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin})
|
headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin})
|
||||||
return rp.WriteHeaders(headers)
|
return rp.WriteHeaders(headers)
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,9 +115,9 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
} else if isWebsocketUpgrade(r) {
|
} else if isWebsocketUpgrade(r) {
|
||||||
respWriter.shouldFlush = true
|
respWriter.shouldFlush = true
|
||||||
stripWebsocketUpgradeHeader(r)
|
stripWebsocketUpgradeHeader(r)
|
||||||
err = c.config.OriginClient.Proxy(respWriter, r, true)
|
err = c.config.OriginProxy.Proxy(respWriter, r, true)
|
||||||
} else {
|
} else {
|
||||||
err = c.config.OriginClient.Proxy(respWriter, r, false)
|
err = c.config.OriginProxy.Proxy(respWriter, r, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -161,10 +161,10 @@ type http2RespWriter struct {
|
||||||
shouldFlush bool
|
shouldFlush bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
dest := rp.w.Header()
|
dest := rp.w.Header()
|
||||||
userHeaders := make(http.Header, len(resp.Header))
|
userHeaders := make(http.Header, len(header))
|
||||||
for header, values := range resp.Header {
|
for header, values := range header {
|
||||||
// Since these are http2 headers, they're required to be lowercase
|
// Since these are http2 headers, they're required to be lowercase
|
||||||
h2name := strings.ToLower(header)
|
h2name := strings.ToLower(header)
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
|
@ -184,13 +184,12 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
// Perform user header serialization and set them in the single header
|
// Perform user header serialization and set them in the single header
|
||||||
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
||||||
status := resp.StatusCode
|
|
||||||
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
||||||
if status == http.StatusSwitchingProtocols {
|
if status == http.StatusSwitchingProtocols {
|
||||||
status = http.StatusOK
|
status = http.StatusOK
|
||||||
}
|
}
|
||||||
rp.w.WriteHeader(status)
|
rp.w.WriteHeader(status)
|
||||||
if IsServerSentEvent(resp.Header) {
|
if IsServerSentEvent(header) {
|
||||||
rp.shouldFlush = true
|
rp.shouldFlush = true
|
||||||
}
|
}
|
||||||
if rp.shouldFlush {
|
if rp.shouldFlush {
|
||||||
|
|
|
@ -125,12 +125,12 @@ func IsWebsocketClientHeader(headerName string) bool {
|
||||||
headerName == "upgrade"
|
headerName == "upgrade"
|
||||||
}
|
}
|
||||||
|
|
||||||
func H1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) {
|
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []Header) {
|
||||||
h2 = []Header{
|
h2 = []Header{
|
||||||
{Name: ":status", Value: strconv.Itoa(h1.StatusCode)},
|
{Name: ":status", Value: strconv.Itoa(status)},
|
||||||
}
|
}
|
||||||
userHeaders := make(http.Header, len(h1.Header))
|
userHeaders := make(http.Header, len(h1))
|
||||||
for header, values := range h1.Header {
|
for header, values := range h1 {
|
||||||
h2name := strings.ToLower(header)
|
h2name := strings.ToLower(header)
|
||||||
if h2name == "content-length" {
|
if h2name == "content-length" {
|
||||||
// This header has meaning in HTTP/2 and will be used by the edge,
|
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||||
|
|
|
@ -579,7 +579,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
|
||||||
Header: mockHeaders,
|
Header: mockHeaders,
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := H1ResponseToH2ResponseHeaders(&mockResponse)
|
headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
|
||||||
|
|
||||||
serializedHeadersIndex := -1
|
serializedHeadersIndex := -1
|
||||||
for i, header := range headers {
|
for i, header := range headers {
|
||||||
|
@ -622,7 +622,7 @@ func TestHeaderSize(t *testing.T) {
|
||||||
Header: largeHeaders,
|
Header: largeHeaders,
|
||||||
}
|
}
|
||||||
|
|
||||||
serializedHeaders := H1ResponseToH2ResponseHeaders(&mockResponse)
|
serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
|
||||||
request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
|
request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
for _, header := range serializedHeaders {
|
for _, header := range serializedHeaders {
|
||||||
|
@ -669,6 +669,6 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = H1ResponseToH2ResponseHeaders(h1resp)
|
_ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,16 +85,24 @@ func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool) (Ingress, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a single origin service from the CLI/config.
|
// Get a single origin service from the CLI/config.
|
||||||
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
|
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) {
|
||||||
if c.IsSet("hello-world") {
|
if c.IsSet("hello-world") {
|
||||||
return new(helloWorld), nil
|
return new(helloWorld), nil
|
||||||
}
|
}
|
||||||
if c.IsSet("url") || c.IsSet(config.BastionFlag) {
|
if c.IsSet(config.BastionFlag) {
|
||||||
|
return newBridgeService(), nil
|
||||||
|
}
|
||||||
|
if c.IsSet("url") {
|
||||||
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
|
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error validating origin URL")
|
return nil, errors.Wrap(err, "Error validating origin URL")
|
||||||
}
|
}
|
||||||
return &localService{URL: originURL, RootURL: originURL}, nil
|
if isHTTPService(originURL) {
|
||||||
|
return &httpService{
|
||||||
|
url: originURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return newSingleTCPService(originURL), nil
|
||||||
}
|
}
|
||||||
if c.IsSet("unix-socket") {
|
if c.IsSet("unix-socket") {
|
||||||
path, err := config.ValidateUnixSocket(c)
|
path, err := config.ValidateUnixSocket(c)
|
||||||
|
@ -104,7 +112,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginServ
|
||||||
return &unixSocketPath{path: path}, nil
|
return &unixSocketPath{path: path}, nil
|
||||||
}
|
}
|
||||||
u, err := url.Parse("http://localhost:8080")
|
u, err := url.Parse("http://localhost:8080")
|
||||||
return &localService{URL: u, RootURL: u}, err
|
return &httpService{url: u}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEmpty checks if there are any ingress rules.
|
// IsEmpty checks if there are any ingress rules.
|
||||||
|
@ -136,7 +144,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
rules := make([]Rule, len(ingress))
|
rules := make([]Rule, len(ingress))
|
||||||
for i, r := range ingress {
|
for i, r := range ingress {
|
||||||
cfg := setConfig(defaults, r.OriginRequest)
|
cfg := setConfig(defaults, r.OriginRequest)
|
||||||
var service OriginService
|
var service originService
|
||||||
|
|
||||||
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
|
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
|
||||||
// No validation necessary for unix socket filepath services
|
// No validation necessary for unix socket filepath services
|
||||||
|
@ -156,7 +164,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
// overwrite the localService.URL field when `start` is called. So,
|
// overwrite the localService.URL field when `start` is called. So,
|
||||||
// leave the URL field empty for now.
|
// leave the URL field empty for now.
|
||||||
cfg.BastionMode = true
|
cfg.BastionMode = true
|
||||||
service = new(localService)
|
service = newBridgeService()
|
||||||
} else {
|
} else {
|
||||||
// Validate URL services
|
// Validate URL services
|
||||||
u, err := url.Parse(r.Service)
|
u, err := url.Parse(r.Service)
|
||||||
|
@ -171,8 +179,11 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
if u.Path != "" {
|
if u.Path != "" {
|
||||||
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
|
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
|
||||||
}
|
}
|
||||||
serviceURL := localService{URL: u}
|
if isHTTPService(u) {
|
||||||
service = &serviceURL
|
service = &httpService{url: u}
|
||||||
|
} else {
|
||||||
|
service = newSingleTCPService(u)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateHostname(r, i, len(ingress)); err != nil {
|
if err := validateHostname(r, i, len(ingress)); err != nil {
|
||||||
|
@ -241,3 +252,7 @@ func ParseIngress(conf *config.Configuration) (Ingress, error) {
|
||||||
}
|
}
|
||||||
return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest))
|
return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isHTTPService(url *url.URL) bool {
|
||||||
|
return url.Scheme == "http" || url.Scheme == "https" || url.Scheme == "ws" || url.Scheme == "wss"
|
||||||
|
}
|
||||||
|
|
|
@ -61,12 +61,12 @@ ingress:
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "tunnel1.example.com",
|
Hostname: "tunnel1.example.com",
|
||||||
Service: &localService{URL: localhost8000},
|
Service: &httpService{url: localhost8000},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
Service: &localService{URL: localhost8001},
|
Service: &httpService{url: localhost8001},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -82,7 +82,22 @@ extraKey: extraValue
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
Service: &localService{URL: localhost8000},
|
Service: &httpService{url: localhost8000},
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ws service",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- hostname: "*"
|
||||||
|
service: wss://localhost:8000
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Hostname: "*",
|
||||||
|
Service: &httpService{url: MustParseURL(t, "wss://localhost:8000")},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -95,7 +110,7 @@ ingress:
|
||||||
`},
|
`},
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Service: &localService{URL: localhost8000},
|
Service: &httpService{url: localhost8000},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -209,6 +224,85 @@ ingress:
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "TCP services",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- hostname: tcp.foo.com
|
||||||
|
service: tcp://127.0.0.1
|
||||||
|
- hostname: tcp2.foo.com
|
||||||
|
service: tcp://localhost:8000
|
||||||
|
- service: http_status:404
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Hostname: "tcp.foo.com",
|
||||||
|
Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Hostname: "tcp2.foo.com",
|
||||||
|
Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Service: &fourOhFour,
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SSH services",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- service: ssh://127.0.0.1
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RDP services",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- service: rdp://127.0.0.1
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SMB services",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- service: smb://127.0.0.1
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Other TCP services",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- service: ftp://127.0.0.1
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "URL isn't necessary if using bastion",
|
name: "URL isn't necessary if using bastion",
|
||||||
args: args{rawYAML: `
|
args: args{rawYAML: `
|
||||||
|
@ -221,7 +315,7 @@ ingress:
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "bastion.foo.com",
|
Hostname: "bastion.foo.com",
|
||||||
Service: &localService{},
|
Service: newBridgeService(),
|
||||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -241,7 +335,7 @@ ingress:
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "bastion.foo.com",
|
Hostname: "bastion.foo.com",
|
||||||
Service: &localService{},
|
Service: newBridgeService(),
|
||||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -409,6 +503,37 @@ func TestFindMatchingRule(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsHTTPService(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
url *url.URL
|
||||||
|
isHTTP bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
url: MustParseURL(t, "http://localhost"),
|
||||||
|
isHTTP: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: MustParseURL(t, "https://127.0.0.1:8000"),
|
||||||
|
isHTTP: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: MustParseURL(t, "ws://localhost"),
|
||||||
|
isHTTP: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: MustParseURL(t, "wss://localhost:8000"),
|
||||||
|
isHTTP: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: MustParseURL(t, "tcp://localhost:9000"),
|
||||||
|
isHTTP: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
assert.Equal(t, test.isHTTP, isHTTPService(test.url))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func mustParsePath(t *testing.T, path string) *regexp.Regexp {
|
func mustParsePath(t *testing.T, path string) *regexp.Regexp {
|
||||||
regexp, err := regexp.Compile(path)
|
regexp, err := regexp.Compile(path)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
gws "github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OriginConnection is a way to stream to a service running on the user's origin.
|
||||||
|
// Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
|
||||||
|
type OriginConnection interface {
|
||||||
|
// Stream should generally be implemented as a bidirectional io.Copy.
|
||||||
|
Stream(tunnelConn io.ReadWriter)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||||
|
type tcpConnection struct {
|
||||||
|
conn net.Conn
|
||||||
|
streamHandler func(tunnelConn io.ReadWriter, originConn net.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter) {
|
||||||
|
tc.streamHandler(tunnelConn, tc.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *tcpConnection) Close() {
|
||||||
|
tc.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
|
||||||
|
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
|
||||||
|
type wsConnection struct {
|
||||||
|
wsConn *gws.Conn
|
||||||
|
resp *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter) {
|
||||||
|
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsc *wsConnection) Close() {
|
||||||
|
wsc.resp.Body.Close()
|
||||||
|
wsc.wsConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
|
||||||
|
d := &gws.Dialer{
|
||||||
|
TLSClientConfig: transport.TLSClientConfig,
|
||||||
|
}
|
||||||
|
wsConn, resp, err := websocket.ClientConnect(r, d)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &wsConnection{
|
||||||
|
wsConn,
|
||||||
|
resp,
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||||
|
type HTTPOriginProxy interface {
|
||||||
|
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||||
|
http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
|
||||||
|
type StreamBasedOriginProxy interface {
|
||||||
|
EstablishConnection(r *http.Request) (OriginConnection, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return o.transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: TUN-3636: establish connection to origins over UDS
|
||||||
|
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||||
|
return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// Rewrite the request URL so that it goes to the origin service.
|
||||||
|
req.URL.Host = o.url.Host
|
||||||
|
req.URL.Scheme = o.url.Scheme
|
||||||
|
return o.transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// Rewrite the request URL so that it goes to the Hello World server.
|
||||||
|
req.URL.Host = o.server.Addr().String()
|
||||||
|
req.URL.Scheme = "https"
|
||||||
|
return o.transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) {
|
||||||
|
req.URL.Host = o.server.Addr().String()
|
||||||
|
req.URL.Scheme = "wss"
|
||||||
|
return newWSConnection(o.transport, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||||
|
return o.resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||||
|
dest, err := o.destination(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return o.client.connect(r, dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *bridgeService) destination(r *http.Request) (string, error) {
|
||||||
|
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
||||||
|
if jumpDestination == "" {
|
||||||
|
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||||
|
}
|
||||||
|
// Strip scheme and path set by client. Without a scheme
|
||||||
|
// Parsing a hostname and path without scheme might not return an error due to parsing ambiguities
|
||||||
|
if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" {
|
||||||
|
return removePath(jumpURL.Host), nil
|
||||||
|
}
|
||||||
|
return removePath(jumpDestination), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removePath(dest string) string {
|
||||||
|
return strings.SplitN(dest, "/", 2)[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||||
|
return o.client.connect(r, o.dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpClient struct {
|
||||||
|
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
|
||||||
|
conn, err := net.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &tcpConnection{
|
||||||
|
conn: conn,
|
||||||
|
streamHandler: c.streamHandler,
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -0,0 +1,107 @@
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBridgeServiceDestination(t *testing.T) {
|
||||||
|
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
header http.Header
|
||||||
|
expectedDest string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "hostname destination",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"localhost"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname destination with port",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"localhost:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname destination with scheme and port",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"ssh://localhost:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full hostname url",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"ssh://localhost:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname destination with port and path",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"localhost:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "localhost:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"127.0.0.1"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with port",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"127.0.0.1:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with port and path",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"127.0.0.1:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with schem and port",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"tcp://127.0.0.1:9000"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full ip url",
|
||||||
|
header: http.Header{
|
||||||
|
canonicalJumpDestHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
||||||
|
},
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no destination",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := newBridgeService()
|
||||||
|
for _, test := range tests {
|
||||||
|
r := &http.Request{
|
||||||
|
Header: test.header,
|
||||||
|
}
|
||||||
|
dest, err := s.destination(r)
|
||||||
|
if test.wantErr {
|
||||||
|
assert.Error(t, err, "Test %s expects error", test.name)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err)
|
||||||
|
assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -21,10 +20,8 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OriginService is something a tunnel can proxy traffic to.
|
// originService is something a tunnel can proxy traffic to.
|
||||||
type OriginService interface {
|
type originService interface {
|
||||||
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
|
||||||
http.RoundTripper
|
|
||||||
String() string
|
String() string
|
||||||
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
||||||
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
||||||
|
@ -51,10 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
return o.transport.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||||
d := &gws.Dialer{
|
d := &gws.Dialer{
|
||||||
NetDial: o.transport.Dial,
|
NetDial: o.transport.Dial,
|
||||||
|
@ -65,130 +58,87 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
|
||||||
return d.Dial(reqURL.String(), headers)
|
return d.Dial(reqURL.String(), headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
|
type httpService struct {
|
||||||
type localService struct {
|
url *url.URL
|
||||||
// The URL for the user's origin service
|
|
||||||
RootURL *url.URL
|
|
||||||
// The URL that cloudflared should send requests to.
|
|
||||||
// If this origin requires starting a proxy, this is the proxy's address,
|
|
||||||
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
|
|
||||||
URL *url.URL
|
|
||||||
transport *http.Transport
|
transport *http.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
|
||||||
// Rewrite the request URL so that it goes to the origin service.
|
|
||||||
reqURL.Host = o.URL.Host
|
|
||||||
reqURL.Scheme = websocket.ChangeRequestScheme(o.URL)
|
|
||||||
return d.Dial(reqURL.String(), headers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *localService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
|
||||||
transport, err := newHTTPTransport(o, cfg, log)
|
transport, err := newHTTPTransport(o, cfg, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
o.transport = transport
|
o.transport = transport
|
||||||
|
|
||||||
// Start a proxy if one is needed
|
|
||||||
if staticHost := o.staticHost(); originRequiresProxy(staticHost, cfg) {
|
|
||||||
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
func (o *httpService) String() string {
|
||||||
|
return o.url.String()
|
||||||
|
}
|
||||||
|
|
||||||
// Start a listener for the proxy
|
// bridgeService is like a jump host, the destination is specified by the client
|
||||||
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
|
type bridgeService struct {
|
||||||
listener, err := net.Listen("tcp", proxyAddress)
|
client *tcpClient
|
||||||
if err != nil {
|
}
|
||||||
log.Error().Msgf("Cannot start Websocket Proxy Server: %s", err)
|
|
||||||
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
|
func newBridgeService() *bridgeService {
|
||||||
|
return &bridgeService{
|
||||||
|
client: &tcpClient{},
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start the proxy itself
|
func (o *bridgeService) String() string {
|
||||||
wg.Add(1)
|
return "bridge service"
|
||||||
go func() {
|
}
|
||||||
defer wg.Done()
|
|
||||||
streamHandler := websocket.DefaultStreamHandler
|
|
||||||
// This origin's config specifies what type of proxy to start.
|
|
||||||
switch cfg.ProxyType {
|
|
||||||
case socksProxy:
|
|
||||||
log.Info().Msg("SOCKS5 server started")
|
|
||||||
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
|
|
||||||
dialer := socks.NewConnDialer(remoteConn)
|
|
||||||
requestHandler := socks.NewRequestHandler(dialer)
|
|
||||||
socksServer := socks.NewConnectionHandler(requestHandler)
|
|
||||||
|
|
||||||
_ = socksServer.Serve(wsConn)
|
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
}
|
if cfg.ProxyType == socksProxy {
|
||||||
case "":
|
o.client.streamHandler = socks.StreamHandler
|
||||||
log.Debug().Msg("Not starting any websocket proxy")
|
} else {
|
||||||
default:
|
o.client.streamHandler = websocket.DefaultStreamHandler
|
||||||
log.Error().Msgf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy)
|
|
||||||
}
|
|
||||||
|
|
||||||
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Modify this origin, so that it no longer points at the origin service directly.
|
|
||||||
// Instead, it points at the proxy to the origin service.
|
|
||||||
newURL, err := url.Parse("http://" + listener.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
o.URL = newURL
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *localService) String() string {
|
type singleTCPService struct {
|
||||||
if o.isBastion() {
|
dest string
|
||||||
return "Bastion"
|
client *tcpClient
|
||||||
}
|
|
||||||
return o.URL.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *localService) isBastion() bool {
|
func newSingleTCPService(url *url.URL) *singleTCPService {
|
||||||
return o.URL == nil
|
switch url.Scheme {
|
||||||
}
|
|
||||||
|
|
||||||
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
// Rewrite the request URL so that it goes to the origin service.
|
|
||||||
req.URL.Host = o.URL.Host
|
|
||||||
req.URL.Scheme = o.URL.Scheme
|
|
||||||
return o.transport.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *localService) staticHost() string {
|
|
||||||
|
|
||||||
if o.URL == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
addPortIfMissing := func(uri *url.URL, port int) string {
|
|
||||||
if uri.Port() != "" {
|
|
||||||
return uri.Host
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch o.URL.Scheme {
|
|
||||||
case "ssh":
|
case "ssh":
|
||||||
return addPortIfMissing(o.URL, 22)
|
addPortIfMissing(url, 22)
|
||||||
case "rdp":
|
case "rdp":
|
||||||
return addPortIfMissing(o.URL, 3389)
|
addPortIfMissing(url, 3389)
|
||||||
case "smb":
|
case "smb":
|
||||||
return addPortIfMissing(o.URL, 445)
|
addPortIfMissing(url, 445)
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
|
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
|
||||||
}
|
}
|
||||||
return ""
|
return &singleTCPService{
|
||||||
|
dest: url.Host,
|
||||||
|
client: &tcpClient{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addPortIfMissing(uri *url.URL, port int) {
|
||||||
|
if uri.Port() == "" {
|
||||||
|
uri.Host = fmt.Sprintf("%s:%d", uri.Hostname(), port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *singleTCPService) String() string {
|
||||||
|
return o.dest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
|
if cfg.ProxyType == socksProxy {
|
||||||
|
o.client.streamHandler = socks.StreamHandler
|
||||||
|
} else {
|
||||||
|
o.client.streamHandler = websocket.DefaultStreamHandler
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HelloWorld is an OriginService for the built-in Hello World server.
|
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||||
|
@ -228,26 +178,6 @@ func (o *helloWorld) start(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
// Rewrite the request URL so that it goes to the Hello World server.
|
|
||||||
req.URL.Host = o.server.Addr().String()
|
|
||||||
req.URL.Scheme = "https"
|
|
||||||
return o.transport.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
|
||||||
d := &gws.Dialer{
|
|
||||||
TLSClientConfig: o.transport.TLSClientConfig,
|
|
||||||
}
|
|
||||||
reqURL.Host = o.server.Addr().String()
|
|
||||||
reqURL.Scheme = "wss"
|
|
||||||
return d.Dial(reqURL.String(), headers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
|
|
||||||
return staticHost != "" || cfg.BastionMode
|
|
||||||
}
|
|
||||||
|
|
||||||
// statusCode is an OriginService that just responds with a given HTTP status.
|
// statusCode is an OriginService that just responds with a given HTTP status.
|
||||||
// Typical use-case is "user wants the catch-all rule to just respond 404".
|
// Typical use-case is "user wants the catch-all rule to just respond 404".
|
||||||
type statusCode struct {
|
type statusCode struct {
|
||||||
|
@ -277,10 +207,6 @@ func (o *statusCode) start(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
|
||||||
return o.resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type NopReadCloser struct{}
|
type NopReadCloser struct{}
|
||||||
|
|
||||||
// Read always returns EOF to signal end of input
|
// Read always returns EOF to signal end of input
|
||||||
|
@ -292,7 +218,7 @@ func (nrc *NopReadCloser) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||||
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
|
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||||
|
@ -337,19 +263,19 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerol
|
||||||
return &httpTransport, nil
|
return &httpTransport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
|
// MockOriginHTTPService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
|
||||||
type MockOriginService struct {
|
type MockOriginHTTPService struct {
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (mos MockOriginHTTPService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return mos.Transport.RoundTrip(req)
|
return mos.Transport.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mos MockOriginService) String() string {
|
func (mos MockOriginHTTPService) String() string {
|
||||||
return "MockOriginService"
|
return "MockOriginService"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mos MockOriginService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ type Rule struct {
|
||||||
// A (probably local) address. Requests for a hostname which matches this
|
// A (probably local) address. Requests for a hostname which matches this
|
||||||
// rule's hostname pattern will be proxied to the service running on this
|
// rule's hostname pattern will be proxied to the service running on this
|
||||||
// address.
|
// address.
|
||||||
Service OriginService
|
Service originService
|
||||||
|
|
||||||
// Configure the request cloudflared sends to this specific origin.
|
// Configure the request cloudflared sends to this specific origin.
|
||||||
Config OriginRequestConfig
|
Config OriginRequestConfig
|
||||||
|
|
|
@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Hostname string
|
Hostname string
|
||||||
Path *regexp.Regexp
|
Path *regexp.Regexp
|
||||||
Service OriginService
|
Service originService
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
requestURL *url.URL
|
requestURL *url.URL
|
||||||
|
|
139
origin/proxy.go
139
origin/proxy.go
|
@ -5,7 +5,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -15,7 +14,6 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
@ -24,15 +22,15 @@ const (
|
||||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type proxy struct {
|
||||||
ingressRules ingress.Ingress
|
ingressRules ingress.Ingress
|
||||||
tags []tunnelpogs.Tag
|
tags []tunnelpogs.Tag
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
bufferPool *buffer.Pool
|
bufferPool *buffer.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog.Logger) connection.OriginClient {
|
func NewOriginProxy(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog.Logger) connection.OriginProxy {
|
||||||
return &client{
|
return &proxy{
|
||||||
ingressRules: ingressRules,
|
ingressRules: ingressRules,
|
||||||
tags: tags,
|
tags: tags,
|
||||||
log: log,
|
log: log,
|
||||||
|
@ -40,36 +38,55 @@ func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
|
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
|
||||||
incrementRequests()
|
incrementRequests()
|
||||||
defer decrementConcurrentRequests()
|
defer decrementConcurrentRequests()
|
||||||
|
|
||||||
cfRay := findCfRayHeader(req)
|
cfRay := findCfRayHeader(req)
|
||||||
lbProbe := isLBProbeRequest(req)
|
lbProbe := isLBProbeRequest(req)
|
||||||
|
|
||||||
c.appendTagHeaders(req)
|
p.appendTagHeaders(req)
|
||||||
rule, ruleNum := c.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
||||||
c.logRequest(req, cfRay, lbProbe, ruleNum)
|
p.logRequest(req, cfRay, lbProbe, ruleNum)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
resp *http.Response
|
resp *http.Response
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
if isWebsocket {
|
if isWebsocket {
|
||||||
resp, err = c.proxyWebsocket(w, req, rule)
|
go websocket.NewConn(w, p.log).Pinger(req.Context())
|
||||||
|
|
||||||
|
connClosedChan := make(chan struct{})
|
||||||
|
err = p.proxyConnection(connClosedChan, w, req, rule)
|
||||||
|
if err == nil {
|
||||||
|
respHeader := websocket.NewResponseHeader(req)
|
||||||
|
status := http.StatusSwitchingProtocols
|
||||||
|
resp = &http.Response{
|
||||||
|
Status: http.StatusText(status),
|
||||||
|
StatusCode: status,
|
||||||
|
Header: respHeader,
|
||||||
|
ContentLength: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader)
|
||||||
|
<-connClosedChan
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
resp, err = c.proxyHTTP(w, req, rule)
|
resp, err = p.proxyHTTP(w, req, rule)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logRequestError(err, cfRay, ruleNum)
|
p.logRequestError(err, cfRay, ruleNum)
|
||||||
w.WriteErrorResponse()
|
w.WriteErrorResponse()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
|
|
||||||
|
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
||||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||||
if rule.Config.DisableChunkedEncoding {
|
if rule.Config.DisableChunkedEncoding {
|
||||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||||
|
@ -87,73 +104,69 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule
|
||||||
req.Host = hostHeader
|
req.Host = hostHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := rule.Service.RoundTrip(req)
|
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
|
||||||
|
if !ok {
|
||||||
|
p.log.Error().Msgf("%s is not a http service", rule.Service)
|
||||||
|
return nil, fmt.Errorf("Not a http service")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := httpService.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error proxying request to origin")
|
return nil, errors.Wrap(err, "Error proxying request to origin")
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
err = w.WriteRespHeaders(resp)
|
err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error writing response header")
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
}
|
}
|
||||||
if connection.IsServerSentEvent(resp.Header) {
|
if connection.IsServerSentEvent(resp.Header) {
|
||||||
c.log.Debug().Msg("Detected Server-Side Events from Origin")
|
p.log.Debug().Msg("Detected Server-Side Events from Origin")
|
||||||
c.writeEventStream(w, resp.Body)
|
p.writeEventStream(w, resp.Body)
|
||||||
} else {
|
} else {
|
||||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||||
// compression generates dictionary on first write
|
// compression generates dictionary on first write
|
||||||
buf := c.bufferPool.Get()
|
buf := p.bufferPool.Get()
|
||||||
defer c.bufferPool.Put(buf)
|
defer p.bufferPool.Put(buf)
|
||||||
_, _ = io.CopyBuffer(w, resp.Body, buf)
|
_, _ = io.CopyBuffer(w, resp.Body, buf)
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
func (p *proxy) proxyConnection(connClosedChan chan struct{},
|
||||||
|
conn io.ReadWriter, req *http.Request, rule *ingress.Rule) error {
|
||||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
||||||
req.Header.Set("Host", hostHeader)
|
req.Header.Set("Host", hostHeader)
|
||||||
req.Host = hostHeader
|
req.Host = hostHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
dialler, ok := rule.Service.(websocket.Dialler)
|
connectionService, ok := rule.Service.(ingress.StreamBasedOriginProxy)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service)
|
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
|
||||||
|
return fmt.Errorf("Not a connection-oriented service")
|
||||||
}
|
}
|
||||||
conn, resp, err := websocket.ClientConnect(req, dialler)
|
originConn, err := connectionService.EstablishConnection(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
serveCtx, cancel := context.WithCancel(req.Context())
|
serveCtx, cancel := context.WithCancel(req.Context())
|
||||||
connClosedChan := make(chan struct{})
|
|
||||||
go func() {
|
go func() {
|
||||||
// serveCtx is done if req is cancelled, or streamWebsocket returns
|
// serveCtx is done if req is cancelled, or streamWebsocket returns
|
||||||
<-serveCtx.Done()
|
<-serveCtx.Done()
|
||||||
_ = conn.Close()
|
originConn.Close()
|
||||||
close(connClosedChan)
|
close(connClosedChan)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Copy to/from stream to the undelying connection. Use the underlying
|
go func() {
|
||||||
// connection because cloudflared doesn't operate on the message themselves
|
originConn.Stream(conn)
|
||||||
err = c.streamWebsocket(w, conn.UnderlyingConn(), resp)
|
cancel()
|
||||||
cancel()
|
}()
|
||||||
|
|
||||||
// We need to make sure conn is closed before returning, otherwise we might write to conn after Proxy returns
|
|
||||||
<-connClosedChan
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) streamWebsocket(w connection.ResponseWriter, conn net.Conn, resp *http.Response) error {
|
|
||||||
err := w.WriteRespHeaders(resp)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Error writing websocket response header")
|
|
||||||
}
|
|
||||||
websocket.Stream(conn, w)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
||||||
reader := bufio.NewReader(respBody)
|
reader := bufio.NewReader(respBody)
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
line, err := reader.ReadBytes('\n')
|
||||||
|
@ -164,54 +177,54 @@ func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadC
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) appendTagHeaders(r *http.Request) {
|
func (p *proxy) appendTagHeaders(r *http.Request) {
|
||||||
for _, tag := range c.tags {
|
for _, tag := range p.tags {
|
||||||
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) {
|
func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) {
|
||||||
if cfRay != "" {
|
if cfRay != "" {
|
||||||
c.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
||||||
} else if lbProbe {
|
} else if lbProbe {
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
||||||
} else {
|
} else {
|
||||||
c.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
|
p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
|
||||||
}
|
}
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
|
p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum)
|
p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum)
|
||||||
|
|
||||||
if contentLen := r.ContentLength; contentLen == -1 {
|
if contentLen := r.ContentLength; contentLen == -1 {
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay)
|
p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay)
|
||||||
} else {
|
} else {
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen)
|
p.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) {
|
func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) {
|
||||||
responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc()
|
responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc()
|
||||||
if cfRay != "" {
|
if cfRay != "" {
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum)
|
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum)
|
||||||
} else if lbProbe {
|
} else if lbProbe {
|
||||||
c.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status)
|
p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status)
|
||||||
} else {
|
} else {
|
||||||
c.log.Debug().Msgf("Status: %s served by ingress %d", r.Status, ruleNum)
|
p.log.Debug().Msgf("Status: %s served by ingress %d", r.Status, ruleNum)
|
||||||
}
|
}
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
|
p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
|
||||||
|
|
||||||
if contentLen := r.ContentLength; contentLen == -1 {
|
if contentLen := r.ContentLength; contentLen == -1 {
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay)
|
p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay)
|
||||||
} else {
|
} else {
|
||||||
c.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen)
|
p.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) logRequestError(err error, cfRay string, ruleNum int) {
|
func (p *proxy) logRequestError(err error, cfRay string, ruleNum int) {
|
||||||
requestErrors.Inc()
|
requestErrors.Inc()
|
||||||
if cfRay != "" {
|
if cfRay != "" {
|
||||||
c.log.Error().Msgf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err)
|
p.log.Error().Msgf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err)
|
||||||
} else {
|
} else {
|
||||||
c.log.Error().Msgf("Proxying to ingress %d error: %v", ruleNum, err)
|
p.log.Error().Msgf("Proxying to ingress %d error: %v", ruleNum, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -14,9 +16,11 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
@ -39,9 +43,9 @@ func newMockHTTPRespWriter() *mockHTTPRespWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
|
func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
w.WriteHeader(resp.StatusCode)
|
w.WriteHeader(status)
|
||||||
for header, val := range resp.Header {
|
for header, val := range header {
|
||||||
w.Header()[header] = val
|
w.Header()[header] = val
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -125,28 +129,28 @@ func TestProxySingleOrigin(t *testing.T) {
|
||||||
errC := make(chan error)
|
errC := make(chan error)
|
||||||
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
|
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
|
||||||
|
|
||||||
client := NewClient(ingressRule, testTags, &log)
|
proxy := NewOriginProxy(ingressRule, testTags, &log)
|
||||||
t.Run("testProxyHTTP", testProxyHTTP(t, client))
|
t.Run("testProxyHTTP", testProxyHTTP(t, proxy))
|
||||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, client))
|
t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy))
|
||||||
t.Run("testProxySSE", testProxySSE(t, client))
|
t.Run("testProxySSE", testProxySSE(t, proxy))
|
||||||
cancel()
|
cancel()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func testProxyHTTP(t *testing.T, client connection.OriginClient) func(t *testing.T) {
|
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
respWriter := newMockHTTPRespWriter()
|
respWriter := newMockHTTPRespWriter()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = client.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, respWriter.Code)
|
assert.Equal(t, http.StatusOK, respWriter.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *testing.T) {
|
func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
// WSRoute is a websocket echo handler
|
// WSRoute is a websocket echo handler
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -159,7 +163,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err = client.Proxy(respWriter, req, true)
|
err = proxy.Proxy(respWriter, req, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
||||||
|
@ -169,7 +173,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te
|
||||||
err = wsutil.WriteClientText(writePipe, msg)
|
err = wsutil.WriteClientText(writePipe, msg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ReadServerText reads next data message from rw, considering that caller represents client side.
|
// ReadServerText reads next data message from rw, considering that caller represents proxy side.
|
||||||
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
|
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, msg, returnedMsg)
|
require.Equal(t, msg, returnedMsg)
|
||||||
|
@ -186,7 +190,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.T) {
|
func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
pushCount = 50
|
pushCount = 50
|
||||||
|
@ -201,7 +205,7 @@ func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err = client.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, respWriter.Code)
|
require.Equal(t, http.StatusOK, respWriter.Code)
|
||||||
|
@ -258,7 +262,7 @@ func TestProxyMultipleOrigins(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC))
|
require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC))
|
||||||
|
|
||||||
client := NewClient(ingress, testTags, &log)
|
proxy := NewOriginProxy(ingress, testTags, &log)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
url string
|
url string
|
||||||
|
@ -294,7 +298,7 @@ func TestProxyMultipleOrigins(t *testing.T) {
|
||||||
req, err := http.NewRequest(http.MethodGet, test.url, nil)
|
req, err := http.NewRequest(http.MethodGet, test.url, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = client.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, test.expectedStatus, respWriter.Code)
|
assert.Equal(t, test.expectedStatus, respWriter.Code)
|
||||||
|
@ -327,7 +331,7 @@ func TestProxyError(t *testing.T) {
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
Path: nil,
|
Path: nil,
|
||||||
Service: ingress.MockOriginService{
|
Service: ingress.MockOriginHTTPService{
|
||||||
Transport: errorOriginTransport{},
|
Transport: errorOriginTransport{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -336,14 +340,85 @@ func TestProxyError(t *testing.T) {
|
||||||
|
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
|
|
||||||
client := NewClient(ingress, testTags, &log)
|
proxy := NewOriginProxy(ingress, testTags, &log)
|
||||||
|
|
||||||
respWriter := newMockHTTPRespWriter()
|
respWriter := newMockHTTPRespWriter()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = client.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, false)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
|
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
|
||||||
assert.Equal(t, "http response error", respWriter.Body.String())
|
assert.Equal(t, "http response error", respWriter.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyBastionMode(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
|
||||||
|
flagSet.Bool("bastion", true, "")
|
||||||
|
|
||||||
|
cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
|
||||||
|
err := cliCtx.Set(config.BastionFlag, "true")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
allowURLFromArgs := false
|
||||||
|
ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errC := make(chan error)
|
||||||
|
|
||||||
|
log := logger.Create(nil)
|
||||||
|
|
||||||
|
ingressRule.StartOrigins(&wg, log, ctx.Done(), errC)
|
||||||
|
|
||||||
|
proxy := NewOriginProxy(ingressRule, testTags, log)
|
||||||
|
|
||||||
|
t.Run("testBastionWebsocket", testBastionWebsocket(proxy))
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
// WSRoute is a websocket echo handler
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
readPipe, _ := io.Pipe()
|
||||||
|
respWriter := newMockWSRespWriter(readPipe)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
msgFromConn := []byte("data from websocket proxy")
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer ln.Close()
|
||||||
|
server, err := ln.Accept()
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn := websocket.NewConn(server, nil)
|
||||||
|
conn.Write(msgFromConn)
|
||||||
|
}()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil)
|
||||||
|
req.Header.Set(h2mux.CFJumpDestinationHeader, ln.Addr().String())
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err = proxy.Proxy(respWriter, req, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// ReadServerText reads next data message from rw, considering that caller represents proxy side.
|
||||||
|
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
|
||||||
|
if err != io.EOF {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msgFromConn, returnedMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package socks
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -104,3 +105,11 @@ func (h *StandardRequestHandler) handleAssociate(conn io.ReadWriter, req *Reques
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn) {
|
||||||
|
dialer := NewConnDialer(originConn)
|
||||||
|
requestHandler := NewRequestHandler(dialer)
|
||||||
|
socksServer := NewConnectionHandler(requestHandler)
|
||||||
|
|
||||||
|
socksServer.Serve(tunnelConn)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gobwas "github.com/gobwas/ws"
|
||||||
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Time allowed to write a message to the peer.
|
||||||
|
writeWait = 10 * time.Second
|
||||||
|
|
||||||
|
// Time allowed to read the next pong message from the peer.
|
||||||
|
pongWait = 60 * time.Second
|
||||||
|
|
||||||
|
// Send pings to peer with this period. Must be less than pongWait.
|
||||||
|
pingPeriod = (pongWait * 9) / 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
|
||||||
|
// This is still used by access carrier
|
||||||
|
type GorillaConn struct {
|
||||||
|
*websocket.Conn
|
||||||
|
log *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read will read messages from the websocket connection
|
||||||
|
func (c *GorillaConn) Read(p []byte) (int, error) {
|
||||||
|
_, message, err := c.Conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return copy(p, message), nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write will write messages to the websocket connection
|
||||||
|
func (c *GorillaConn) Write(p []byte) (int, error) {
|
||||||
|
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pinger simulates the websocket connection to keep it alive
|
||||||
|
func (c *GorillaConn) pinger(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(pingPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
||||||
|
c.log.Debug().Msgf("failed to send ping message: %s", err)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
rw io.ReadWriter
|
||||||
|
log *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
rw: rw,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read will read messages from the websocket connection
|
||||||
|
func (c *Conn) Read(reader []byte) (int, error) {
|
||||||
|
data, err := wsutil.ReadClientBinary(c.rw)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return copy(reader, data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write will write messages to the websocket connection
|
||||||
|
func (c *Conn) Write(p []byte) (int, error) {
|
||||||
|
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Pinger(ctx context.Context) {
|
||||||
|
pongMessge := wsutil.Message{
|
||||||
|
OpCode: gobwas.OpPong,
|
||||||
|
Payload: []byte{},
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(pingPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
||||||
|
c.log.Err(err).Msgf("failed to write ping message")
|
||||||
|
}
|
||||||
|
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
||||||
|
c.log.Err(err).Msgf("failed to write pong message")
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,7 +2,6 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -16,17 +15,6 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// Time allowed to write a message to the peer.
|
|
||||||
writeWait = 10 * time.Second
|
|
||||||
|
|
||||||
// Time allowed to read the next pong message from the peer.
|
|
||||||
pongWait = 60 * time.Second
|
|
||||||
|
|
||||||
// Send pings to peer with this period. Must be less than pongWait.
|
|
||||||
pingPeriod = (pongWait * 9) / 10
|
|
||||||
)
|
|
||||||
|
|
||||||
var stripWebsocketHeaders = []string{
|
var stripWebsocketHeaders = []string{
|
||||||
"Upgrade",
|
"Upgrade",
|
||||||
"Connection",
|
"Connection",
|
||||||
|
@ -35,70 +23,28 @@ var stripWebsocketHeaders = []string{
|
||||||
"Sec-Websocket-Extensions",
|
"Sec-Websocket-Extensions",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn is a wrapper around the standard gorilla websocket
|
|
||||||
// but implements a ReadWriter
|
|
||||||
type Conn struct {
|
|
||||||
*websocket.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read will read messages from the websocket connection
|
|
||||||
func (c *Conn) Read(p []byte) (int, error) {
|
|
||||||
_, message, err := c.Conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(p, message), nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write will write messages to the websocket connection
|
|
||||||
func (c *Conn) Write(p []byte) (int, error) {
|
|
||||||
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||||
return websocket.IsWebSocketUpgrade(req)
|
return websocket.IsWebSocketUpgrade(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialler is something that can proxy websocket requests.
|
|
||||||
type Dialler interface {
|
|
||||||
Dial(url *url.URL, headers http.Header) (*websocket.Conn, *http.Response, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type defaultDialler struct {
|
|
||||||
tlsConfig *tls.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dd *defaultDialler) Dial(url *url.URL, header http.Header) (*websocket.Conn, *http.Response, error) {
|
|
||||||
d := &websocket.Dialer{
|
|
||||||
TLSClientConfig: dd.tlsConfig,
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
}
|
|
||||||
return d.Dial(url.String(), header)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
||||||
// the connection. The response body may not contain the entire response and does
|
// the connection. The response body may not contain the entire response and does
|
||||||
// not need to be closed by the application.
|
// not need to be closed by the application.
|
||||||
func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) {
|
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
|
||||||
req.URL.Scheme = ChangeRequestScheme(req.URL)
|
req.URL.Scheme = ChangeRequestScheme(req.URL)
|
||||||
wsHeaders := websocketHeaders(req)
|
wsHeaders := websocketHeaders(req)
|
||||||
|
|
||||||
if dialler == nil {
|
if dialler == nil {
|
||||||
dialler = new(defaultDialler)
|
dialler = &websocket.Dialer{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
conn, response, err := dialler.Dial(req.URL, wsHeaders)
|
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, response, err
|
return nil, response, err
|
||||||
}
|
}
|
||||||
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
||||||
return conn, response, err
|
return conn, response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream copies copy data to & from provided io.ReadWriters.
|
// Stream copies copy data to & from provided io.ReadWriters.
|
||||||
|
@ -121,8 +67,8 @@ func Stream(conn, backendConn io.ReadWriter) {
|
||||||
|
|
||||||
// DefaultStreamHandler is provided to the the standard websocket to origin stream
|
// DefaultStreamHandler is provided to the the standard websocket to origin stream
|
||||||
// This exist to allow SOCKS to deframe data before it gets to the origin
|
// This exist to allow SOCKS to deframe data before it gets to the origin
|
||||||
func DefaultStreamHandler(wsConn *Conn, remoteConn net.Conn, _ http.Header) {
|
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) {
|
||||||
Stream(wsConn, remoteConn)
|
Stream(originConn, remoteConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartProxyServer will start a websocket server that will decode
|
// StartProxyServer will start a websocket server that will decode
|
||||||
|
@ -132,7 +78,7 @@ func StartProxyServer(
|
||||||
listener net.Listener,
|
listener net.Listener,
|
||||||
staticHost string,
|
staticHost string,
|
||||||
shutdownC <-chan struct{},
|
shutdownC <-chan struct{},
|
||||||
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header),
|
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn),
|
||||||
) error {
|
) error {
|
||||||
upgrader := websocket.Upgrader{
|
upgrader := websocket.Upgrader{
|
||||||
ReadBufferSize: 1024,
|
ReadBufferSize: 1024,
|
||||||
|
@ -159,7 +105,7 @@ type handler struct {
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
staticHost string
|
staticHost string
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header)
|
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -192,14 +138,20 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||||
conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
|
conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
|
||||||
done := make(chan struct{})
|
gorillaConn := &GorillaConn{conn, h.log}
|
||||||
go pinger(h.log, conn, done)
|
go gorillaConn.pinger(r.Context())
|
||||||
defer func() {
|
defer conn.Close()
|
||||||
done <- struct{}{}
|
|
||||||
_ = conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
h.streamHandler(&Conn{conn}, stream, r.Header)
|
h.streamHandler(gorillaConn, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
||||||
|
func NewResponseHeader(req *http.Request) http.Header {
|
||||||
|
header := http.Header{}
|
||||||
|
header.Add("Connection", "Upgrade")
|
||||||
|
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
|
||||||
|
header.Add("Upgrade", "websocket")
|
||||||
|
return header
|
||||||
}
|
}
|
||||||
|
|
||||||
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
||||||
|
@ -246,19 +198,3 @@ func ChangeRequestScheme(reqURL *url.URL) string {
|
||||||
return reqURL.Scheme
|
return reqURL.Scheme
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pinger simulates the websocket connection to keep it alive
|
|
||||||
func pinger(logger *zerolog.Logger, ws *websocket.Conn, done chan struct{}) {
|
|
||||||
ticker := time.NewTicker(pingPeriod)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
|
||||||
logger.Debug().Msgf("failed to send ping message: %s", err)
|
|
||||||
}
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
|
gws "github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
@ -78,7 +78,7 @@ func TestServe(t *testing.T) {
|
||||||
|
|
||||||
tlsConfig := websocketClientTLSConfig(t)
|
tlsConfig := websocketClientTLSConfig(t)
|
||||||
assert.NotNil(t, tlsConfig)
|
assert.NotNil(t, tlsConfig)
|
||||||
d := defaultDialler{tlsConfig: tlsConfig}
|
d := gws.Dialer{TLSClientConfig: tlsConfig}
|
||||||
conn, resp, err := ClientConnect(req, &d)
|
conn, resp, err := ClientConnect(req, &d)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||||
|
|
Loading…
Reference in New Issue