From b06fe0fc5f34718b0cd62a2ef45675fddf1a8782 Mon Sep 17 00:00:00 2001 From: Nuno Diegues Date: Fri, 18 Jun 2021 12:21:11 +0100 Subject: [PATCH] TUN-4571: Fix proxying to unix sockets when using HTTP2 transport to Cloudflare Edge --- component-tests/test_reconnect.py | 3 ++ connection/http2.go | 16 ++++++ origin/proxy_test.go | 87 ++++++++++++++++++++++++------- 3 files changed, 86 insertions(+), 20 deletions(-) diff --git a/component-tests/test_reconnect.py b/component-tests/test_reconnect.py index 5a0020f1..15af88ca 100644 --- a/component-tests/test_reconnect.py +++ b/component-tests/test_reconnect.py @@ -1,7 +1,9 @@ #!/usr/bin/env python import copy +import platform from time import sleep +import pytest from flaky import flaky from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected @@ -15,6 +17,7 @@ class TestReconnect: "stdin-control": True, } + @pytest.mark.skipif(platform.system() == "Windows", reason=f"Currently buggy on Windows TUN-4584") def test_named_reconnect(self, tmp_path, component_tests_config): config = component_tests_config(self.extra_config) with start_cloudflared(tmp_path, config, new_process=True, allow_input=True, capture_output=False) as cloudflared: diff --git a/connection/http2.go b/connection/http2.go index 54f3c667..ea81fdbf 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -98,6 +98,8 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer c.activeRequestsWG.Done() connType := determineHTTP2Type(r) + handleMissingRequestParts(connType, r) + respWriter, err := newHTTP2RespWriter(r, w, connType) if err != nil { c.observer.log.Error().Msg(err.Error()) @@ -255,6 +257,20 @@ func determineHTTP2Type(r *http.Request) Type { } } +func handleMissingRequestParts(connType Type, r *http.Request) { + if connType == TypeHTTP { + // http library has no guarantees that we receive a filled URL. If not, then we fill it, as we reuse the request + // for proxying. We use the same values as we used to in h2mux. For proxying they should not matter since we + // control the dialer on every egress proxied. + if len(r.URL.Scheme) == 0 { + r.URL.Scheme = "http" + } + if len(r.URL.Host) == 0 { + r.URL.Host = "localhost:8080" + } + } +} + func isControlStreamUpgrade(r *http.Request) bool { return r.Header.Get(InternalUpgradeHeader) == ControlStreamUpgrade } diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 1af9add4..b7fecbed 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -6,9 +6,11 @@ import ( "flag" "fmt" "io" + "io/ioutil" "net" "net/http" "net/http/httptest" + "os" "sync" "testing" "time" @@ -273,26 +275,7 @@ func TestProxyMultipleOrigins(t *testing.T) { }, } - ingress, err := ingress.ParseIngress(&config.Configuration{ - TunnelID: t.Name(), - Ingress: unvalidatedIngress, - }) - require.NoError(t, err) - - log := zerolog.Nop() - - ctx, cancel := context.WithCancel(context.Background()) - errC := make(chan error) - var wg sync.WaitGroup - require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) - - proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) - - tests := []struct { - url string - expectedStatus int - expectedBody []byte - }{ + tests := []MultipleIngressTest{ { url: "http://api.example.com", expectedStatus: http.StatusCreated, @@ -317,6 +300,31 @@ func TestProxyMultipleOrigins(t *testing.T) { }, } + runIngressTestScenarios(t, unvalidatedIngress, tests) +} + +type MultipleIngressTest struct { + url string + expectedStatus int + expectedBody []byte +} + +func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.UnvalidatedIngressRule, tests []MultipleIngressTest) { + ingress, err := ingress.ParseIngress(&config.Configuration{ + TunnelID: t.Name(), + Ingress: unvalidatedIngress, + }) + require.NoError(t, err) + + log := zerolog.Nop() + + ctx, cancel := context.WithCancel(context.Background()) + errC := make(chan error) + var wg sync.WaitGroup + require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) + + proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) + for _, test := range tests { responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, test.url, nil) @@ -633,6 +641,45 @@ func TestConnections(t *testing.T) { } } +func TestUnixSocketOrigin(t *testing.T) { + file, err := ioutil.TempFile("", "unix.sock") + require.NoError(t, err) + os.Remove(file.Name()) // remove the file since binding the socket expects to create it + + l, err := net.Listen("unix", file.Name()) + require.NoError(t, err) + defer l.Close() + defer os.Remove(file.Name()) + + api := &httptest.Server{ + Listener: l, + Config: &http.Server{Handler: mockAPI{}}, + } + api.Start() + defer api.Close() + + unvalidatedIngress := []config.UnvalidatedIngressRule{ + { + Hostname: "unix.example.com", + Service: "unix:" + file.Name(), + }, + { + Hostname: "*", + Service: "http_status:404", + }, + } + + tests := []MultipleIngressTest{ + { + url: "http://unix.example.com", + expectedStatus: http.StatusCreated, + expectedBody: []byte("Created"), + }, + } + + runIngressTestScenarios(t, unvalidatedIngress, tests) +} + type requestBody struct { pw *io.PipeWriter pr *io.PipeReader