diff --git a/quic/datagram_test.go b/quic/datagram_test.go index 69bb0b71..61801034 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -10,6 +10,7 @@ import ( "fmt" "math/big" "net/netip" + "sync" "testing" "time" @@ -121,9 +122,15 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi logger := zerolog.Nop() - errGroup, ctx := errgroup.WithContext(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + errGroup, _ := errgroup.WithContext(ctx) + var receivedMessages sync.WaitGroup + receivedMessages.Add(1) // Run edge side of datagram muxer errGroup.Go(func() error { + defer receivedMessages.Done() + defer cancel() + // Accept quic connection quicSession, err := quicListener.Accept(ctx) if err != nil { @@ -135,10 +142,10 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi switch version { case 1: muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan) - muxer.ServeReceive(ctx) + go muxer.ServeReceive(ctx) case 2: muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan) - muxer.ServeReceive(ctx) + go muxer.ServeReceive(ctx) icmpDecoder := packet.NewICMPDecoder() for _, pk := range packets { @@ -157,8 +164,12 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi } for _, expectedPayload := range sessionToPayloads { - actualPayload := <-sessionDemuxChan - require.Equal(t, expectedPayload, actualPayload) + select { + case actualPayload := <-sessionDemuxChan: + require.Equal(t, expectedPayload, actualPayload) + case <-ctx.Done(): + t.Fatal("edge side got context cancelled before receiving all expected payloads") + } } return nil }) @@ -166,6 +177,8 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi largePayload := make([]byte, MaxDatagramFrameSize) // Run cloudflared side of datagram muxer errGroup.Go(func() error { + defer cancel() + tlsClientConfig := &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"argotunnel"}, @@ -209,7 +222,7 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi })) // Wait for edge to finish receiving the messages - time.Sleep(time.Millisecond * 100) + receivedMessages.Wait() return nil })