TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input

This change extracts the need for EstablishConnection to know about a
request's entire context. It also removes the concern of populating the
http.Response from EstablishConnection's responsibilities.
This commit is contained in:
Sudarsan Reddy 2021-07-01 19:30:26 +01:00
parent f1b57526b3
commit d678584d89
7 changed files with 69 additions and 94 deletions

View File

@ -103,7 +103,7 @@ func NewWarpRoutingService() *WarpRoutingService {
}
// Get a single origin service from the CLI/config.
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) {
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
if c.IsSet("hello-world") {
return new(helloWorld), nil
}
@ -167,7 +167,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
rules := make([]Rule, len(ingress))
for i, r := range ingress {
cfg := setConfig(defaults, r.OriginRequest)
var service originService
var service OriginService
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
// No validation necessary for unix socket filepath services

View File

@ -6,9 +6,6 @@ import (
"net/http"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/websocket"
)
var (
@ -24,7 +21,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
EstablishConnection(dest string) (OriginConnection, error)
}
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := getRequestHost(r)
if err != nil {
return nil, nil, err
}
func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) {
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
return nil, err
}
originConn := &tcpConnection{
conn: conn,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
return originConn, resp, nil
return originConn, nil
}
// getRequestHost returns the host of the http.Request.
func getRequestHost(r *http.Request) (string, error) {
if r.Host != "" {
return r.Host, nil
}
if r.URL != nil {
return r.URL.Host, nil
}
return "", errors.New("host not found")
}
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
var err error
dest := o.dest
if o.isBastion {
dest, err = carrier.ResolveBastionDest(r)
if err != nil {
return nil, nil, err
}
if !o.isBastion {
dest = o.dest
}
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
return nil, err
}
originConn := &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
return originConn, nil
}
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
originConn := o.conn
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
return o.conn, nil
}

View File

@ -17,20 +17,6 @@ import (
"github.com/cloudflare/cloudflared/websocket"
)
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
// the expected response
func assertEstablishConnectionResponse(t *testing.T,
originProxy StreamBasedOriginProxy,
req *http.Request,
expectHeader http.Header,
) {
_, resp, err := originProxy.EstablishConnection(req)
assert.NoError(t, err)
assert.Equal(t, switchingProtocolText, resp.Status)
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
assert.Equal(t, expectHeader, resp.Header)
}
func TestRawTCPServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@ -43,8 +29,6 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
require.NoError(t, err)
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
originListener.Close()
<-listenerClosed
@ -52,9 +36,8 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
require.NoError(t, err)
// Origin not listening for new connection, should return an error
_, resp, err := rawTCPService.EstablishConnection(req)
_, err = rawTCPService.EstablishConnection(req.URL.String())
require.Error(t, err)
require.Nil(t, resp)
}
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
@ -76,12 +59,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
bastionReq := baseReq.Clone(context.Background())
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
tests := []struct {
testCase string
service *tcpOverWSService
@ -109,11 +86,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, test := range tests {
t.Run(test.testCase, func(t *testing.T) {
if test.expectErr {
_, resp, err := test.service.EstablishConnection(test.req)
bastionHost, _ := carrier.ResolveBastionDest(test.req)
_, err := test.service.EstablishConnection(bastionHost)
assert.Error(t, err)
assert.Nil(t, resp)
} else {
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
}
})
}
@ -123,9 +98,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error
_, resp, err := service.EstablishConnection(bastionReq)
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
_, err := service.EstablishConnection(bastionHost)
assert.Error(t, err)
assert.Nil(t, resp)
}
}

View File

@ -20,8 +20,8 @@ import (
"github.com/cloudflare/cloudflared/tlsconfig"
)
// originService is something a tunnel can proxy traffic to.
type originService interface {
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
@ -238,7 +238,7 @@ func (nrc *NopReadCloser) Close() error {
return nil
}
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
if err != nil {
return nil, errors.Wrap(err, "Error loading cert pool")

View File

@ -17,7 +17,7 @@ type Rule struct {
// A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this
// address.
Service originService
Service OriginService
// Configure the request cloudflared sends to this specific origin.
Config OriginRequestConfig

View File

@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
type fields struct {
Hostname string
Path *regexp.Regexp
Service originService
Service OriginService
}
type args struct {
requestURL *url.URL

View File

@ -12,6 +12,7 @@ import (
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -33,6 +34,8 @@ type proxy struct {
bufferPool *bufferPool
}
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
func NewOriginProxy(
ingressRules ingress.Ingress,
warpRouting *ingress.WarpRoutingService,
@ -71,7 +74,13 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
lbProbe: lbProbe,
rule: ingress.ServiceWarpRouting,
}
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
host, err := getRequestHost(req)
if err != nil {
err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err)
return err
}
if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting)
return err
}
@ -97,7 +106,11 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return nil
case ingress.StreamBasedOriginProxy:
if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil {
dest, err := getDestFromRule(rule, req)
if err != nil {
return err
}
if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
@ -105,10 +118,29 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return nil
default:
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
}
}
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
switch rule.Service.String() {
case ingress.ServiceBastion:
return carrier.ResolveBastionDest(req)
default:
return rule.Service.String(), nil
}
}
// getRequestHost returns the host of the http.Request.
func getRequestHost(r *http.Request) (string, error) {
if r.Host != "" {
return r.Host, nil
}
if r.URL != nil {
return r.URL.Host, nil
}
return "", errors.New("host not set in incoming request")
}
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
srv = ing.Rules[ruleNum].Service.String()
if ing.IsSingleRule() {
@ -191,16 +223,24 @@ func (p *proxy) proxyHTTPRequest(
func (p *proxy) proxyStreamRequest(
serveCtx context.Context,
w connection.ResponseWriter,
dest string,
req *http.Request,
connectionProxy ingress.StreamBasedOriginProxy,
fields logFields,
) error {
originConn, resp, err := connectionProxy.EstablishConnection(req)
originConn, err := connectionProxy.EstablishConnection(dest)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" {
resp.Header = websocket.NewResponseHeader(req)
}
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {