diff --git a/connection/quic.go b/connection/quic.go index c1b4ff9d..48537710 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -122,7 +122,7 @@ func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream q func (q *QUICConnection) acceptStream(ctx context.Context) error { defer q.Close() for { - stream, err := q.session.AcceptStream(ctx) + quicStream, err := q.session.AcceptStream(ctx) if err != nil { // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() { @@ -131,7 +131,9 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error { return fmt.Errorf("failed to accept QUIC stream: %w", err) } go func() { + stream := quicpogs.NewSafeStreamCloser(quicStream) defer stream.Close() + if err = q.handleStream(stream); err != nil { q.logger.Err(err).Msg("Failed to handle QUIC stream") } @@ -144,7 +146,7 @@ func (q *QUICConnection) Close() { q.session.CloseWithError(0, "") } -func (q *QUICConnection) handleStream(stream quic.Stream) error { +func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error { signature, err := quicpogs.DetermineProtocol(stream) if err != nil { return err diff --git a/connection/quic_test.go b/connection/quic_test.go index ac945400..a430466d 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -3,14 +3,9 @@ package connection import ( "bytes" "context" - "crypto/rand" - "crypto/rsa" "crypto/tls" - "crypto/x509" - "encoding/pem" "fmt" "io" - "math/big" "net" "net/http" "net/url" @@ -33,7 +28,7 @@ import ( ) var ( - testTLSServerConfig = generateTLSConfig() + testTLSServerConfig = quicpogs.GenerateTLSConfig() testQUICConfig = &quic.Config{ KeepAlive: true, EnableDatagrams: true, @@ -84,7 +79,7 @@ func TestQUICServer(t *testing.T) { }, { desc: "test http body request streaming", - dest: "/echo_body", + dest: "/slow_echo_body", connectionType: quicpogs.ConnectionTypeHTTP, metadata: []quicpogs.Metadata{ { @@ -195,8 +190,9 @@ func quicServer( session, err := earlyListener.Accept(ctx) require.NoError(t, err) - stream, err := session.OpenStreamSync(context.Background()) + quicStream, err := session.OpenStreamSync(context.Background()) require.NoError(t, err) + stream := quicpogs.NewSafeStreamCloser(quicStream) reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream} err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...) @@ -207,42 +203,20 @@ func quicServer( if message != nil { // ALPN successful. Write data. - _, err := stream.Write([]byte(message)) + _, err := stream.Write(message) require.NoError(t, err) } response := make([]byte, len(expectedResponse)) - stream.Read(response) - require.NoError(t, err) + _, err = stream.Read(response) + if err != io.EOF { + require.NoError(t, err) + } // For now it is an echo server. Verify if the same data is returned. assert.Equal(t, expectedResponse, response) } -// Setup a bare-bones TLS config for the server -func generateTLSConfig() *tls.Config { - key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - panic(err) - } - template := x509.Certificate{SerialNumber: big.NewInt(1)} - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - if err != nil { - panic(err) - } - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - panic(err) - } - return &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - NextProtos: []string{"argotunnel"}, - } -} - type mockOriginProxyWithRequest struct{} func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Request, isWebsocket bool) error { @@ -264,6 +238,9 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque switch r.URL.Path { case "/ok": originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK))) + case "/slow_echo_body": + time.Sleep(5) + fallthrough case "/echo_body": resp := &http.Response{ StatusCode: http.StatusOK, diff --git a/origin/tunnel.go b/origin/tunnel.go index bcfea720..99cb5987 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -31,8 +31,6 @@ const ( dialTimeout = 15 * time.Second FeatureSerializedHeaders = "serialized_headers" FeatureQuickReconnects = "quick_reconnects" - quicHandshakeIdleTimeout = 5 * time.Second - quicMaxIdleTimeout = 15 * time.Second ) type TunnelConfig struct { @@ -523,8 +521,8 @@ func ServeQUIC( ) (err error, recoverable bool) { tlsConfig := config.EdgeTLSConfigs[connection.QUIC] quicConfig := &quic.Config{ - HandshakeIdleTimeout: quicHandshakeIdleTimeout, - MaxIdleTimeout: quicMaxIdleTimeout, + HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout, + MaxIdleTimeout: quicpogs.MaxIdleTimeout, MaxIncomingStreams: connection.MaxConcurrentStreams, MaxIncomingUniStreams: connection.MaxConcurrentStreams, KeepAlive: true, diff --git a/quic/quic_protocol.go b/quic/quic_protocol.go index bba808ee..f64689bc 100644 --- a/quic/quic_protocol.go +++ b/quic/quic_protocol.go @@ -17,8 +17,8 @@ import ( tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) -// The first 6 bytes of the stream is used to distinguish the type of stream. It ensures whoever performs a handshake does -// not write data before writing the metadata. +// ProtocolSignature defines the first 6 bytes of the stream, which is used to distinguish the type of stream. It +// ensures whoever performs a handshake does not write data before writing the metadata. type ProtocolSignature [6]byte var ( @@ -29,12 +29,15 @@ var ( RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65} ) -const protocolVersionLength = 2 - type protocolVersion string const ( protocolV1 protocolVersion = "01" + + protocolVersionLength = 2 + + HandshakeIdleTimeout = 5 * time.Second + MaxIdleTimeout = 15 * time.Second ) // RequestServerStream is a stream to serve requests diff --git a/quic/safe_stream.go b/quic/safe_stream.go new file mode 100644 index 00000000..12ba76f4 --- /dev/null +++ b/quic/safe_stream.go @@ -0,0 +1,43 @@ +package quic + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go" +) + +type SafeStreamCloser struct { + lock sync.Mutex + stream quic.Stream +} + +func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser { + return &SafeStreamCloser{ + stream: stream, + } +} + +func (s *SafeStreamCloser) Read(p []byte) (n int, err error) { + return s.stream.Read(p) +} + +func (s *SafeStreamCloser) Write(p []byte) (n int, err error) { + s.lock.Lock() + defer s.lock.Unlock() + return s.stream.Write(p) +} + +func (s *SafeStreamCloser) Close() error { + // Make sure a possible writer does not block the lock forever. We need it, so we can close the writer + // side of the stream safely. + _ = s.stream.SetWriteDeadline(time.Now()) + + // This lock is eventually acquired despite Write also acquiring it, because we set a deadline to writes. + s.lock.Lock() + defer s.lock.Unlock() + + // We have to clean up the receiving stream ourselves since the Close in the bottom does not handle that. + s.stream.CancelRead(0) + return s.stream.Close() +} diff --git a/quic/safe_stream_test.go b/quic/safe_stream_test.go new file mode 100644 index 00000000..48ffb559 --- /dev/null +++ b/quic/safe_stream_test.go @@ -0,0 +1,142 @@ +package quic + +import ( + "context" + "crypto/tls" + "io" + "net" + "sync" + "testing" + + "github.com/lucas-clemente/quic-go" + "github.com/stretchr/testify/require" +) + +var ( + testTLSServerConfig = GenerateTLSConfig() + testQUICConfig = &quic.Config{ + KeepAlive: true, + EnableDatagrams: true, + } + exchanges = 1000 + msgsPerExchange = 10 + testMsg = "Ok message" +) + +func TestSafeStreamClose(t *testing.T) { + udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) + require.NoError(t, err) + defer udpListener.Close() + + var serverReady sync.WaitGroup + serverReady.Add(1) + + var done sync.WaitGroup + done.Add(1) + go func() { + defer done.Done() + quicServer(t, &serverReady, udpListener) + }() + + done.Add(1) + go func() { + serverReady.Wait() + defer done.Done() + quicClient(t, udpListener.LocalAddr()) + }() + + done.Wait() +} + +func quicClient(t *testing.T, addr net.Addr) { + tlsConf := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"argotunnel"}, + } + session, err := quic.DialAddr(addr.String(), tlsConf, testQUICConfig) + require.NoError(t, err) + + var wg sync.WaitGroup + for exchange := 0; exchange < exchanges; exchange++ { + quicStream, err := session.AcceptStream(context.Background()) + require.NoError(t, err) + wg.Add(1) + + go func(iter int) { + defer wg.Done() + + stream := NewSafeStreamCloser(quicStream) + defer stream.Close() + + // Do a bunch of round trips over this stream that should work. + for msg := 0; msg < msgsPerExchange; msg++ { + clientRoundTrip(t, stream, true) + } + // And one that won't work necessarily, but shouldn't break other streams in the session. + if iter%2 == 0 { + clientRoundTrip(t, stream, false) + } + }(exchange) + } + + wg.Wait() +} + +func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + earlyListener, err := quic.Listen(conn, testTLSServerConfig, testQUICConfig) + require.NoError(t, err) + + serverReady.Done() + session, err := earlyListener.Accept(ctx) + require.NoError(t, err) + + var wg sync.WaitGroup + for exchange := 0; exchange < exchanges; exchange++ { + quicStream, err := session.OpenStreamSync(context.Background()) + require.NoError(t, err) + wg.Add(1) + + go func(iter int) { + defer wg.Done() + + stream := NewSafeStreamCloser(quicStream) + defer stream.Close() + + // Do a bunch of round trips over this stream that should work. + for msg := 0; msg < msgsPerExchange; msg++ { + serverRoundTrip(t, stream, true) + } + // And one that won't work necessarily, but shouldn't break other streams in the session. + if iter%2 == 1 { + serverRoundTrip(t, stream, false) + } + }(exchange) + } + + wg.Wait() +} + +func clientRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) { + response := make([]byte, len(testMsg)) + _, err := stream.Read(response) + if !mustWork { + return + } + if err != io.EOF { + require.NoError(t, err) + } + require.Equal(t, testMsg, string(response)) +} + +func serverRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) { + _, err := stream.Write([]byte(testMsg)) + if !mustWork { + return + } + require.NoError(t, err) +} diff --git a/quic/test_utils.go b/quic/test_utils.go new file mode 100644 index 00000000..56c342f6 --- /dev/null +++ b/quic/test_utils.go @@ -0,0 +1,34 @@ +package quic + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" +) + +// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server +func GenerateTLSConfig() *tls.Config { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{"argotunnel"}, + } +}