diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index f3638ea7..5afcb8d9 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -249,7 +249,10 @@ func TestConcurrentUpdateAndRead(t *testing.T) { } ) - orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + orchestrator, err := NewOrchestrator(ctx, initConfig, testTags, &testLogger) require.NoError(t, err) updateWithValidation(t, orchestrator, 1, configJSONV1) @@ -265,8 +268,9 @@ func TestConcurrentUpdateAndRead(t *testing.T) { wg.Add(1) go func(i int, originProxy connection.OriginProxy) { defer wg.Done() - resp, err := proxyHTTP(t, originProxy, hostname) - require.NoError(t, err) + resp, err := proxyHTTP(originProxy, hostname) + require.NoError(t, err, "proxyHTTP %d failed %v", i, err) + defer resp.Body.Close() var warpRoutingDisabled bool // The response can be from initOrigin, http_status:204 or http_status:418 @@ -290,7 +294,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { // Once we have originProxy, it won't be changed by configuration updates. // We can infer the version by the ProxyHTTP response code pr, pw := io.Pipe() - // concurrentRespWriter makes sure ResponseRecorder is not read/write concurrently, and read waits for the first write + w := newRespReadWriteFlusher() // Write TCP message and make sure it's echo back. This has to be done in a go routune since ProxyTCP doesn't @@ -303,7 +307,14 @@ func TestConcurrentUpdateAndRead(t *testing.T) { tcpEyeball(t, pw, tcpBody, w) }() } - proxyTCP(t, originProxy, tcpOrigin.Addr().String(), w, pr, warpRoutingDisabled) + + err = proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), w, pr) + if warpRoutingDisabled { + require.Error(t, err, "expect proxyTCP %d to return error", i) + } else { + require.NoError(t, err, "proxyTCP %d failed %v", i, err) + } + }(i, originProxy) if i == concurrentRequests/4 { @@ -319,6 +330,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { wg.Add(1) go func() { defer wg.Done() + // Makes sure v2 is applied before v3 <-appliedV2 updateWithValidation(t, orchestrator, 3, configJSONV3) }() @@ -328,14 +340,18 @@ func TestConcurrentUpdateAndRead(t *testing.T) { wg.Wait() } -func proxyHTTP(t *testing.T, originProxy connection.OriginProxy, hostname string) (*http.Response, error) { +func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", hostname), nil) - require.NoError(t, err) + if err != nil { + return nil, err + } w := httptest.NewRecorder() log := zerolog.Nop() respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeHTTP, &log) - require.NoError(t, err) + if err != nil { + return nil, err + } err = originProxy.ProxyHTTP(respWriter, req, false) if err != nil { @@ -356,13 +372,17 @@ func tcpEyeball(t *testing.T, reqWriter io.WriteCloser, body string, respReadWri require.Equal(t, writeN, n) } -func proxyTCP(t *testing.T, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser, expectErr bool) { +func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser) error { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originAddr), reqBody) - require.NoError(t, err) + if err != nil { + return err + } log := zerolog.Nop() respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeTCP, &log) - require.NoError(t, err) + if err != nil { + return err + } tcpReq := &connection.TCPRequest{ Dest: originAddr, @@ -370,12 +390,8 @@ func proxyTCP(t *testing.T, originProxy connection.OriginProxy, originAddr strin LBProbe: false, } rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req) - if expectErr { - require.Error(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq)) - return - } - require.NoError(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq)) + return originProxy.ProxyTCP(ctx, rws, tcpReq) } func serveTCPOrigin(t *testing.T, tcpOrigin net.Listener, wg *sync.WaitGroup) { @@ -471,7 +487,7 @@ func TestClosePreviousProxies(t *testing.T) { originProxyV1, err := orchestrator.GetOriginProxy() require.NoError(t, err) - resp, err := proxyHTTP(t, originProxyV1, hostname) + resp, err := proxyHTTP(originProxyV1, hostname) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -479,12 +495,12 @@ func TestClosePreviousProxies(t *testing.T) { originProxyV2, err := orchestrator.GetOriginProxy() require.NoError(t, err) - resp, err = proxyHTTP(t, originProxyV2, hostname) + resp, err = proxyHTTP(originProxyV2, hostname) require.NoError(t, err) require.Equal(t, http.StatusTeapot, resp.StatusCode) // The hello-world server in config v1 should have been stopped - resp, err = proxyHTTP(t, originProxyV1, hostname) + resp, err = proxyHTTP(originProxyV1, hostname) require.Error(t, err) require.Nil(t, resp) @@ -495,7 +511,7 @@ func TestClosePreviousProxies(t *testing.T) { require.NoError(t, err) require.NotEqual(t, originProxyV1, originProxyV3) - resp, err = proxyHTTP(t, originProxyV3, hostname) + resp, err = proxyHTTP(originProxyV3, hostname) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -504,7 +520,7 @@ func TestClosePreviousProxies(t *testing.T) { // Wait for proxies to shutdown time.Sleep(time.Millisecond * 10) - resp, err = proxyHTTP(t, originProxyV3, hostname) + resp, err = proxyHTTP(originProxyV3, hostname) require.Error(t, err) require.Nil(t, resp) } @@ -553,6 +569,9 @@ func TestPersistentConnection(t *testing.T) { tcpReqReader, tcpReqWriter := io.Pipe() tcpRespReadWriter := newRespReadWriteFlusher() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var wg sync.WaitGroup wg.Add(3) // Start TCP origin @@ -570,7 +589,7 @@ func TestPersistentConnection(t *testing.T) { // Simulate cloudflared recieving a TCP connection go func() { defer wg.Done() - proxyTCP(t, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader, false) + require.NoError(t, proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader)) }() // Simulate cloudflared recieving a WS connection go func() {