TUN-5621: Correctly manage QUIC stream closing
Until this PR, we were naively closing the quic.Stream whenever the callstack for handling the request (HTTP or TCP) finished. However, our proxy handler may still be reading or writing from the quic.Stream at that point, because we return the callstack if either side finishes, but not necessarily both. This is a problem for quic-go library because quic.Stream#Close cannot be called concurrently with quic.Stream#Write Furthermore, we also noticed that quic.Stream#Close does nothing to do receiving stream (since, underneath, quic.Stream has 2 streams, 1 for each direction), thus leaking memory, as explained in: https://github.com/lucas-clemente/quic-go/issues/3322 This PR addresses both problems by wrapping the quic.Stream that is passed down to the proxying logic and handle all these concerns.
This commit is contained in:
parent
e09dcf6d60
commit
ed2bac026d
|
@ -122,7 +122,7 @@ func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream q
|
||||||
func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||||
defer q.Close()
|
defer q.Close()
|
||||||
for {
|
for {
|
||||||
stream, err := q.session.AcceptStream(ctx)
|
quicStream, err := q.session.AcceptStream(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
|
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
|
||||||
if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
|
if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
|
||||||
|
@ -131,7 +131,9 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||||
return fmt.Errorf("failed to accept QUIC stream: %w", err)
|
return fmt.Errorf("failed to accept QUIC stream: %w", err)
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
|
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
if err = q.handleStream(stream); err != nil {
|
if err = q.handleStream(stream); err != nil {
|
||||||
q.logger.Err(err).Msg("Failed to handle QUIC stream")
|
q.logger.Err(err).Msg("Failed to handle QUIC stream")
|
||||||
}
|
}
|
||||||
|
@ -144,7 +146,7 @@ func (q *QUICConnection) Close() {
|
||||||
q.session.CloseWithError(0, "")
|
q.session.CloseWithError(0, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) handleStream(stream quic.Stream) error {
|
func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
|
||||||
signature, err := quicpogs.DetermineProtocol(stream)
|
signature, err := quicpogs.DetermineProtocol(stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -3,14 +3,9 @@ package connection
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -33,7 +28,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testTLSServerConfig = generateTLSConfig()
|
testTLSServerConfig = quicpogs.GenerateTLSConfig()
|
||||||
testQUICConfig = &quic.Config{
|
testQUICConfig = &quic.Config{
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
EnableDatagrams: true,
|
EnableDatagrams: true,
|
||||||
|
@ -84,7 +79,7 @@ func TestQUICServer(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "test http body request streaming",
|
desc: "test http body request streaming",
|
||||||
dest: "/echo_body",
|
dest: "/slow_echo_body",
|
||||||
connectionType: quicpogs.ConnectionTypeHTTP,
|
connectionType: quicpogs.ConnectionTypeHTTP,
|
||||||
metadata: []quicpogs.Metadata{
|
metadata: []quicpogs.Metadata{
|
||||||
{
|
{
|
||||||
|
@ -195,8 +190,9 @@ func quicServer(
|
||||||
session, err := earlyListener.Accept(ctx)
|
session, err := earlyListener.Accept(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
stream, err := session.OpenStreamSync(context.Background())
|
quicStream, err := session.OpenStreamSync(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
||||||
|
|
||||||
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
|
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
|
||||||
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
|
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
|
||||||
|
@ -207,42 +203,20 @@ func quicServer(
|
||||||
|
|
||||||
if message != nil {
|
if message != nil {
|
||||||
// ALPN successful. Write data.
|
// ALPN successful. Write data.
|
||||||
_, err := stream.Write([]byte(message))
|
_, err := stream.Write(message)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, len(expectedResponse))
|
response := make([]byte, len(expectedResponse))
|
||||||
stream.Read(response)
|
_, err = stream.Read(response)
|
||||||
require.NoError(t, err)
|
if err != io.EOF {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
// For now it is an echo server. Verify if the same data is returned.
|
// For now it is an echo server. Verify if the same data is returned.
|
||||||
assert.Equal(t, expectedResponse, response)
|
assert.Equal(t, expectedResponse, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup a bare-bones TLS config for the server
|
|
||||||
func generateTLSConfig() *tls.Config {
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
template := x509.Certificate{SerialNumber: big.NewInt(1)}
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
|
||||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
|
||||||
|
|
||||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{tlsCert},
|
|
||||||
NextProtos: []string{"argotunnel"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockOriginProxyWithRequest struct{}
|
type mockOriginProxyWithRequest struct{}
|
||||||
|
|
||||||
func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Request, isWebsocket bool) error {
|
func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Request, isWebsocket bool) error {
|
||||||
|
@ -264,6 +238,9 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/ok":
|
case "/ok":
|
||||||
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
|
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
|
||||||
|
case "/slow_echo_body":
|
||||||
|
time.Sleep(5)
|
||||||
|
fallthrough
|
||||||
case "/echo_body":
|
case "/echo_body":
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
StatusCode: http.StatusOK,
|
StatusCode: http.StatusOK,
|
||||||
|
|
|
@ -31,8 +31,6 @@ const (
|
||||||
dialTimeout = 15 * time.Second
|
dialTimeout = 15 * time.Second
|
||||||
FeatureSerializedHeaders = "serialized_headers"
|
FeatureSerializedHeaders = "serialized_headers"
|
||||||
FeatureQuickReconnects = "quick_reconnects"
|
FeatureQuickReconnects = "quick_reconnects"
|
||||||
quicHandshakeIdleTimeout = 5 * time.Second
|
|
||||||
quicMaxIdleTimeout = 15 * time.Second
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelConfig struct {
|
type TunnelConfig struct {
|
||||||
|
@ -523,8 +521,8 @@ func ServeQUIC(
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
|
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
|
||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
HandshakeIdleTimeout: quicHandshakeIdleTimeout,
|
HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout,
|
||||||
MaxIdleTimeout: quicMaxIdleTimeout,
|
MaxIdleTimeout: quicpogs.MaxIdleTimeout,
|
||||||
MaxIncomingStreams: connection.MaxConcurrentStreams,
|
MaxIncomingStreams: connection.MaxConcurrentStreams,
|
||||||
MaxIncomingUniStreams: connection.MaxConcurrentStreams,
|
MaxIncomingUniStreams: connection.MaxConcurrentStreams,
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
|
|
|
@ -17,8 +17,8 @@ import (
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The first 6 bytes of the stream is used to distinguish the type of stream. It ensures whoever performs a handshake does
|
// ProtocolSignature defines the first 6 bytes of the stream, which is used to distinguish the type of stream. It
|
||||||
// not write data before writing the metadata.
|
// ensures whoever performs a handshake does not write data before writing the metadata.
|
||||||
type ProtocolSignature [6]byte
|
type ProtocolSignature [6]byte
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -29,12 +29,15 @@ var (
|
||||||
RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
|
RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
|
||||||
)
|
)
|
||||||
|
|
||||||
const protocolVersionLength = 2
|
|
||||||
|
|
||||||
type protocolVersion string
|
type protocolVersion string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
protocolV1 protocolVersion = "01"
|
protocolV1 protocolVersion = "01"
|
||||||
|
|
||||||
|
protocolVersionLength = 2
|
||||||
|
|
||||||
|
HandshakeIdleTimeout = 5 * time.Second
|
||||||
|
MaxIdleTimeout = 15 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestServerStream is a stream to serve requests
|
// RequestServerStream is a stream to serve requests
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SafeStreamCloser struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
stream quic.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser {
|
||||||
|
return &SafeStreamCloser{
|
||||||
|
stream: stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SafeStreamCloser) Read(p []byte) (n int, err error) {
|
||||||
|
return s.stream.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
return s.stream.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SafeStreamCloser) Close() error {
|
||||||
|
// Make sure a possible writer does not block the lock forever. We need it, so we can close the writer
|
||||||
|
// side of the stream safely.
|
||||||
|
_ = s.stream.SetWriteDeadline(time.Now())
|
||||||
|
|
||||||
|
// This lock is eventually acquired despite Write also acquiring it, because we set a deadline to writes.
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
|
// We have to clean up the receiving stream ourselves since the Close in the bottom does not handle that.
|
||||||
|
s.stream.CancelRead(0)
|
||||||
|
return s.stream.Close()
|
||||||
|
}
|
|
@ -0,0 +1,142 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testTLSServerConfig = GenerateTLSConfig()
|
||||||
|
testQUICConfig = &quic.Config{
|
||||||
|
KeepAlive: true,
|
||||||
|
EnableDatagrams: true,
|
||||||
|
}
|
||||||
|
exchanges = 1000
|
||||||
|
msgsPerExchange = 10
|
||||||
|
testMsg = "Ok message"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSafeStreamClose(t *testing.T) {
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer udpListener.Close()
|
||||||
|
|
||||||
|
var serverReady sync.WaitGroup
|
||||||
|
serverReady.Add(1)
|
||||||
|
|
||||||
|
var done sync.WaitGroup
|
||||||
|
done.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer done.Done()
|
||||||
|
quicServer(t, &serverReady, udpListener)
|
||||||
|
}()
|
||||||
|
|
||||||
|
done.Add(1)
|
||||||
|
go func() {
|
||||||
|
serverReady.Wait()
|
||||||
|
defer done.Done()
|
||||||
|
quicClient(t, udpListener.LocalAddr())
|
||||||
|
}()
|
||||||
|
|
||||||
|
done.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func quicClient(t *testing.T, addr net.Addr) {
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"argotunnel"},
|
||||||
|
}
|
||||||
|
session, err := quic.DialAddr(addr.String(), tlsConf, testQUICConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for exchange := 0; exchange < exchanges; exchange++ {
|
||||||
|
quicStream, err := session.AcceptStream(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(iter int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
stream := NewSafeStreamCloser(quicStream)
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
// Do a bunch of round trips over this stream that should work.
|
||||||
|
for msg := 0; msg < msgsPerExchange; msg++ {
|
||||||
|
clientRoundTrip(t, stream, true)
|
||||||
|
}
|
||||||
|
// And one that won't work necessarily, but shouldn't break other streams in the session.
|
||||||
|
if iter%2 == 0 {
|
||||||
|
clientRoundTrip(t, stream, false)
|
||||||
|
}
|
||||||
|
}(exchange)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
earlyListener, err := quic.Listen(conn, testTLSServerConfig, testQUICConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverReady.Done()
|
||||||
|
session, err := earlyListener.Accept(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for exchange := 0; exchange < exchanges; exchange++ {
|
||||||
|
quicStream, err := session.OpenStreamSync(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(iter int) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
stream := NewSafeStreamCloser(quicStream)
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
// Do a bunch of round trips over this stream that should work.
|
||||||
|
for msg := 0; msg < msgsPerExchange; msg++ {
|
||||||
|
serverRoundTrip(t, stream, true)
|
||||||
|
}
|
||||||
|
// And one that won't work necessarily, but shouldn't break other streams in the session.
|
||||||
|
if iter%2 == 1 {
|
||||||
|
serverRoundTrip(t, stream, false)
|
||||||
|
}
|
||||||
|
}(exchange)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) {
|
||||||
|
response := make([]byte, len(testMsg))
|
||||||
|
_, err := stream.Read(response)
|
||||||
|
if !mustWork {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != io.EOF {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
require.Equal(t, testMsg, string(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) {
|
||||||
|
_, err := stream.Write([]byte(testMsg))
|
||||||
|
if !mustWork {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
|
@ -0,0 +1,34 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server
|
||||||
|
func GenerateTLSConfig() *tls.Config {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
template := x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
|
|
||||||
|
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{tlsCert},
|
||||||
|
NextProtos: []string{"argotunnel"},
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue