TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input
This change extracts the need for EstablishConnection to know about a request's entire context. It also removes the concern of populating the http.Response from EstablishConnection's responsibilities.
This commit is contained in:
parent
f1b57526b3
commit
d678584d89
|
@ -103,7 +103,7 @@ func NewWarpRoutingService() *WarpRoutingService {
|
|||
}
|
||||
|
||||
// Get a single origin service from the CLI/config.
|
||||
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) {
|
||||
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
|
||||
if c.IsSet("hello-world") {
|
||||
return new(helloWorld), nil
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
|||
rules := make([]Rule, len(ingress))
|
||||
for i, r := range ingress {
|
||||
cfg := setConfig(defaults, r.OriginRequest)
|
||||
var service originService
|
||||
var service OriginService
|
||||
|
||||
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
|
||||
// No validation necessary for unix socket filepath services
|
||||
|
|
|
@ -6,9 +6,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -24,7 +21,7 @@ type HTTPOriginProxy interface {
|
|||
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
||||
type StreamBasedOriginProxy interface {
|
||||
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
|
||||
EstablishConnection(dest string) (OriginConnection, error)
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
|||
return o.resp, nil
|
||||
}
|
||||
|
||||
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
dest, err := getRequestHost(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
originConn := &tcpConnection{
|
||||
conn: conn,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
return originConn, nil
|
||||
}
|
||||
|
||||
// getRequestHost returns the host of the http.Request.
|
||||
func getRequestHost(r *http.Request) (string, error) {
|
||||
if r.Host != "" {
|
||||
return r.Host, nil
|
||||
}
|
||||
if r.URL != nil {
|
||||
return r.URL.Host, nil
|
||||
}
|
||||
return "", errors.New("host not found")
|
||||
}
|
||||
|
||||
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||
var err error
|
||||
dest := o.dest
|
||||
if o.isBastion {
|
||||
dest, err = carrier.ResolveBastionDest(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if !o.isBastion {
|
||||
dest = o.dest
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
originConn := &tcpOverWSConnection{
|
||||
conn: conn,
|
||||
streamHandler: o.streamHandler,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: websocket.NewResponseHeader(r),
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
return originConn, nil
|
||||
|
||||
}
|
||||
|
||||
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
originConn := o.conn
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: websocket.NewResponseHeader(r),
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
|
||||
return o.conn, nil
|
||||
}
|
||||
|
|
|
@ -17,20 +17,6 @@ import (
|
|||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
|
||||
// the expected response
|
||||
func assertEstablishConnectionResponse(t *testing.T,
|
||||
originProxy StreamBasedOriginProxy,
|
||||
req *http.Request,
|
||||
expectHeader http.Header,
|
||||
) {
|
||||
_, resp, err := originProxy.EstablishConnection(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, switchingProtocolText, resp.Status)
|
||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
assert.Equal(t, expectHeader, resp.Header)
|
||||
}
|
||||
|
||||
func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
@ -43,8 +29,6 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
|||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
|
||||
|
||||
originListener.Close()
|
||||
<-listenerClosed
|
||||
|
||||
|
@ -52,9 +36,8 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// Origin not listening for new connection, should return an error
|
||||
_, resp, err := rawTCPService.EstablishConnection(req)
|
||||
_, err = rawTCPService.EstablishConnection(req.URL.String())
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
}
|
||||
|
||||
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
|
@ -76,12 +59,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
|||
bastionReq := baseReq.Clone(context.Background())
|
||||
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
testCase string
|
||||
service *tcpOverWSService
|
||||
|
@ -109,11 +86,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
t.Run(test.testCase, func(t *testing.T) {
|
||||
if test.expectErr {
|
||||
_, resp, err := test.service.EstablishConnection(test.req)
|
||||
bastionHost, _ := carrier.ResolveBastionDest(test.req)
|
||||
_, err := test.service.EstablishConnection(bastionHost)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
} else {
|
||||
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -123,9 +98,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
|||
|
||||
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
||||
// Origin not listening for new connection, should return an error
|
||||
_, resp, err := service.EstablishConnection(bastionReq)
|
||||
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
|
||||
_, err := service.EstablishConnection(bastionHost)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,8 @@ import (
|
|||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
)
|
||||
|
||||
// originService is something a tunnel can proxy traffic to.
|
||||
type originService interface {
|
||||
// OriginService is something a tunnel can proxy traffic to.
|
||||
type OriginService interface {
|
||||
String() string
|
||||
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
||||
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
||||
|
@ -238,7 +238,7 @@ func (nrc *NopReadCloser) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||
|
|
|
@ -17,7 +17,7 @@ type Rule struct {
|
|||
// A (probably local) address. Requests for a hostname which matches this
|
||||
// rule's hostname pattern will be proxied to the service running on this
|
||||
// address.
|
||||
Service originService
|
||||
Service OriginService
|
||||
|
||||
// Configure the request cloudflared sends to this specific origin.
|
||||
Config OriginRequestConfig
|
||||
|
|
|
@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
|
|||
type fields struct {
|
||||
Hostname string
|
||||
Path *regexp.Regexp
|
||||
Service originService
|
||||
Service OriginService
|
||||
}
|
||||
type args struct {
|
||||
requestURL *url.URL
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
|
@ -33,6 +34,8 @@ type proxy struct {
|
|||
bufferPool *bufferPool
|
||||
}
|
||||
|
||||
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
|
||||
func NewOriginProxy(
|
||||
ingressRules ingress.Ingress,
|
||||
warpRouting *ingress.WarpRoutingService,
|
||||
|
@ -71,7 +74,13 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
|||
lbProbe: lbProbe,
|
||||
rule: ingress.ServiceWarpRouting,
|
||||
}
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
|
||||
|
||||
host, err := getRequestHost(req)
|
||||
if err != nil {
|
||||
err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err)
|
||||
return err
|
||||
}
|
||||
if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil {
|
||||
p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting)
|
||||
return err
|
||||
}
|
||||
|
@ -97,7 +106,11 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
|||
return nil
|
||||
|
||||
case ingress.StreamBasedOriginProxy:
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil {
|
||||
dest, err := getDestFromRule(rule, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil {
|
||||
rule, srv := ruleField(p.ingressRules, ruleNum)
|
||||
p.logRequestError(err, cfRay, rule, srv)
|
||||
return err
|
||||
|
@ -105,10 +118,29 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
|||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
|
||||
switch rule.Service.String() {
|
||||
case ingress.ServiceBastion:
|
||||
return carrier.ResolveBastionDest(req)
|
||||
default:
|
||||
return rule.Service.String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestHost returns the host of the http.Request.
|
||||
func getRequestHost(r *http.Request) (string, error) {
|
||||
if r.Host != "" {
|
||||
return r.Host, nil
|
||||
}
|
||||
if r.URL != nil {
|
||||
return r.URL.Host, nil
|
||||
}
|
||||
return "", errors.New("host not set in incoming request")
|
||||
}
|
||||
|
||||
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
||||
srv = ing.Rules[ruleNum].Service.String()
|
||||
if ing.IsSingleRule() {
|
||||
|
@ -191,16 +223,24 @@ func (p *proxy) proxyHTTPRequest(
|
|||
func (p *proxy) proxyStreamRequest(
|
||||
serveCtx context.Context,
|
||||
w connection.ResponseWriter,
|
||||
dest string,
|
||||
req *http.Request,
|
||||
connectionProxy ingress.StreamBasedOriginProxy,
|
||||
fields logFields,
|
||||
) error {
|
||||
originConn, resp, err := connectionProxy.EstablishConnection(req)
|
||||
originConn, err := connectionProxy.EstablishConnection(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
ContentLength: -1,
|
||||
}
|
||||
|
||||
if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" {
|
||||
resp.Header = websocket.NewResponseHeader(req)
|
||||
}
|
||||
|
||||
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
||||
|
|
Loading…
Reference in New Issue