TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input
This change extracts the need for EstablishConnection to know about a request's entire context. It also removes the concern of populating the http.Response from EstablishConnection's responsibilities.
This commit is contained in:
parent
f1b57526b3
commit
d678584d89
|
@ -103,7 +103,7 @@ func NewWarpRoutingService() *WarpRoutingService {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a single origin service from the CLI/config.
|
// Get a single origin service from the CLI/config.
|
||||||
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) {
|
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
|
||||||
if c.IsSet("hello-world") {
|
if c.IsSet("hello-world") {
|
||||||
return new(helloWorld), nil
|
return new(helloWorld), nil
|
||||||
}
|
}
|
||||||
|
@ -167,7 +167,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
rules := make([]Rule, len(ingress))
|
rules := make([]Rule, len(ingress))
|
||||||
for i, r := range ingress {
|
for i, r := range ingress {
|
||||||
cfg := setConfig(defaults, r.OriginRequest)
|
cfg := setConfig(defaults, r.OriginRequest)
|
||||||
var service originService
|
var service OriginService
|
||||||
|
|
||||||
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
|
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
|
||||||
// No validation necessary for unix socket filepath services
|
// No validation necessary for unix socket filepath services
|
||||||
|
|
|
@ -6,9 +6,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -24,7 +21,7 @@ type HTTPOriginProxy interface {
|
||||||
|
|
||||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
||||||
type StreamBasedOriginProxy interface {
|
type StreamBasedOriginProxy interface {
|
||||||
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
|
EstablishConnection(dest string) (OriginConnection, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||||
return o.resp, nil
|
return o.resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||||
dest, err := getRequestHost(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
conn, err := net.Dial("tcp", dest)
|
conn, err := net.Dial("tcp", dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
originConn := &tcpConnection{
|
originConn := &tcpConnection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
}
|
}
|
||||||
resp := &http.Response{
|
return originConn, nil
|
||||||
Status: switchingProtocolText,
|
|
||||||
StatusCode: http.StatusSwitchingProtocols,
|
|
||||||
ContentLength: -1,
|
|
||||||
}
|
|
||||||
return originConn, resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRequestHost returns the host of the http.Request.
|
func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||||
func getRequestHost(r *http.Request) (string, error) {
|
|
||||||
if r.Host != "" {
|
|
||||||
return r.Host, nil
|
|
||||||
}
|
|
||||||
if r.URL != nil {
|
|
||||||
return r.URL.Host, nil
|
|
||||||
}
|
|
||||||
return "", errors.New("host not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
|
||||||
var err error
|
var err error
|
||||||
dest := o.dest
|
if !o.isBastion {
|
||||||
if o.isBastion {
|
dest = o.dest
|
||||||
dest, err = carrier.ResolveBastionDest(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", dest)
|
conn, err := net.Dial("tcp", dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
originConn := &tcpOverWSConnection{
|
originConn := &tcpOverWSConnection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
streamHandler: o.streamHandler,
|
streamHandler: o.streamHandler,
|
||||||
}
|
}
|
||||||
resp := &http.Response{
|
return originConn, nil
|
||||||
Status: switchingProtocolText,
|
|
||||||
StatusCode: http.StatusSwitchingProtocols,
|
|
||||||
Header: websocket.NewResponseHeader(r),
|
|
||||||
ContentLength: -1,
|
|
||||||
}
|
|
||||||
return originConn, resp, nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||||
originConn := o.conn
|
return o.conn, nil
|
||||||
resp := &http.Response{
|
|
||||||
Status: switchingProtocolText,
|
|
||||||
StatusCode: http.StatusSwitchingProtocols,
|
|
||||||
Header: websocket.NewResponseHeader(r),
|
|
||||||
ContentLength: -1,
|
|
||||||
}
|
|
||||||
return originConn, resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,20 +17,6 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
|
|
||||||
// the expected response
|
|
||||||
func assertEstablishConnectionResponse(t *testing.T,
|
|
||||||
originProxy StreamBasedOriginProxy,
|
|
||||||
req *http.Request,
|
|
||||||
expectHeader http.Header,
|
|
||||||
) {
|
|
||||||
_, resp, err := originProxy.EstablishConnection(req)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, switchingProtocolText, resp.Status)
|
|
||||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
||||||
assert.Equal(t, expectHeader, resp.Header)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||||
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -43,8 +29,6 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
|
|
||||||
|
|
||||||
originListener.Close()
|
originListener.Close()
|
||||||
<-listenerClosed
|
<-listenerClosed
|
||||||
|
|
||||||
|
@ -52,9 +36,8 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Origin not listening for new connection, should return an error
|
// Origin not listening for new connection, should return an error
|
||||||
_, resp, err := rawTCPService.EstablishConnection(req)
|
_, err = rawTCPService.EstablishConnection(req.URL.String())
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
|
@ -76,12 +59,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
bastionReq := baseReq.Clone(context.Background())
|
bastionReq := baseReq.Clone(context.Background())
|
||||||
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
||||||
|
|
||||||
expectHeader := http.Header{
|
|
||||||
"Connection": {"Upgrade"},
|
|
||||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
|
||||||
"Upgrade": {"websocket"},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
testCase string
|
testCase string
|
||||||
service *tcpOverWSService
|
service *tcpOverWSService
|
||||||
|
@ -109,11 +86,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.testCase, func(t *testing.T) {
|
t.Run(test.testCase, func(t *testing.T) {
|
||||||
if test.expectErr {
|
if test.expectErr {
|
||||||
_, resp, err := test.service.EstablishConnection(test.req)
|
bastionHost, _ := carrier.ResolveBastionDest(test.req)
|
||||||
|
_, err := test.service.EstablishConnection(bastionHost)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, resp)
|
|
||||||
} else {
|
|
||||||
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -123,9 +98,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
|
|
||||||
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
||||||
// Origin not listening for new connection, should return an error
|
// Origin not listening for new connection, should return an error
|
||||||
_, resp, err := service.EstablishConnection(bastionReq)
|
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
|
||||||
|
_, err := service.EstablishConnection(bastionHost)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, resp)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,8 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
// originService is something a tunnel can proxy traffic to.
|
// OriginService is something a tunnel can proxy traffic to.
|
||||||
type originService interface {
|
type OriginService interface {
|
||||||
String() string
|
String() string
|
||||||
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
||||||
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
||||||
|
@ -238,7 +238,7 @@ func (nrc *NopReadCloser) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||||
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
|
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||||
|
|
|
@ -17,7 +17,7 @@ type Rule struct {
|
||||||
// A (probably local) address. Requests for a hostname which matches this
|
// A (probably local) address. Requests for a hostname which matches this
|
||||||
// rule's hostname pattern will be proxied to the service running on this
|
// rule's hostname pattern will be proxied to the service running on this
|
||||||
// address.
|
// address.
|
||||||
Service originService
|
Service OriginService
|
||||||
|
|
||||||
// Configure the request cloudflared sends to this specific origin.
|
// Configure the request cloudflared sends to this specific origin.
|
||||||
Config OriginRequestConfig
|
Config OriginRequestConfig
|
||||||
|
|
|
@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Hostname string
|
Hostname string
|
||||||
Path *regexp.Regexp
|
Path *regexp.Regexp
|
||||||
Service originService
|
Service OriginService
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
requestURL *url.URL
|
requestURL *url.URL
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -33,6 +34,8 @@ type proxy struct {
|
||||||
bufferPool *bufferPool
|
bufferPool *bufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||||
|
|
||||||
func NewOriginProxy(
|
func NewOriginProxy(
|
||||||
ingressRules ingress.Ingress,
|
ingressRules ingress.Ingress,
|
||||||
warpRouting *ingress.WarpRoutingService,
|
warpRouting *ingress.WarpRoutingService,
|
||||||
|
@ -71,7 +74,13 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
lbProbe: lbProbe,
|
lbProbe: lbProbe,
|
||||||
rule: ingress.ServiceWarpRouting,
|
rule: ingress.ServiceWarpRouting,
|
||||||
}
|
}
|
||||||
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
|
|
||||||
|
host, err := getRequestHost(req)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil {
|
||||||
p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting)
|
p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -97,7 +106,11 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case ingress.StreamBasedOriginProxy:
|
case ingress.StreamBasedOriginProxy:
|
||||||
if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil {
|
dest, err := getDestFromRule(rule, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil {
|
||||||
rule, srv := ruleField(p.ingressRules, ruleNum)
|
rule, srv := ruleField(p.ingressRules, ruleNum)
|
||||||
p.logRequestError(err, cfRay, rule, srv)
|
p.logRequestError(err, cfRay, rule, srv)
|
||||||
return err
|
return err
|
||||||
|
@ -105,10 +118,29 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
|
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
|
||||||
|
switch rule.Service.String() {
|
||||||
|
case ingress.ServiceBastion:
|
||||||
|
return carrier.ResolveBastionDest(req)
|
||||||
|
default:
|
||||||
|
return rule.Service.String(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRequestHost returns the host of the http.Request.
|
||||||
|
func getRequestHost(r *http.Request) (string, error) {
|
||||||
|
if r.Host != "" {
|
||||||
|
return r.Host, nil
|
||||||
|
}
|
||||||
|
if r.URL != nil {
|
||||||
|
return r.URL.Host, nil
|
||||||
|
}
|
||||||
|
return "", errors.New("host not set in incoming request")
|
||||||
|
}
|
||||||
|
|
||||||
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
||||||
srv = ing.Rules[ruleNum].Service.String()
|
srv = ing.Rules[ruleNum].Service.String()
|
||||||
if ing.IsSingleRule() {
|
if ing.IsSingleRule() {
|
||||||
|
@ -191,16 +223,24 @@ func (p *proxy) proxyHTTPRequest(
|
||||||
func (p *proxy) proxyStreamRequest(
|
func (p *proxy) proxyStreamRequest(
|
||||||
serveCtx context.Context,
|
serveCtx context.Context,
|
||||||
w connection.ResponseWriter,
|
w connection.ResponseWriter,
|
||||||
|
dest string,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
connectionProxy ingress.StreamBasedOriginProxy,
|
connectionProxy ingress.StreamBasedOriginProxy,
|
||||||
fields logFields,
|
fields logFields,
|
||||||
) error {
|
) error {
|
||||||
originConn, resp, err := connectionProxy.EstablishConnection(req)
|
originConn, err := connectionProxy.EstablishConnection(dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if resp.Body != nil {
|
|
||||||
defer resp.Body.Close()
|
resp := &http.Response{
|
||||||
|
Status: switchingProtocolText,
|
||||||
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
|
ContentLength: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" {
|
||||||
|
resp.Header = websocket.NewResponseHeader(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
||||||
|
|
Loading…
Reference in New Issue