diff --git a/connection/quic.go b/connection/quic.go index 87c3f00a..d6c738aa 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -29,8 +29,7 @@ const ( // HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP. HTTPMethodKey = "HttpMethod" // HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPHostKey = "HttpHost" - MaxDatagramFrameSize = 1220 + HTTPHostKey = "HttpHost" ) // QUICConnection represents the type that facilitates Proxying via QUIC streams. diff --git a/datagramsession/session.go b/datagramsession/session.go index 97b44111..59f6c3e2 100644 --- a/datagramsession/session.go +++ b/datagramsession/session.go @@ -54,7 +54,7 @@ func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (clos go func() { // QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 // This makes it safe to share readBuffer between iterations - readBuffer := make([]byte, s.transport.MTU()) + readBuffer := make([]byte, s.transport.ReceiveMTU()) for { if err := s.dstToTransport(readBuffer); err != nil { s.closeChan <- err diff --git a/datagramsession/transport.go b/datagramsession/transport.go index e2b73bf0..aad1b475 100644 --- a/datagramsession/transport.go +++ b/datagramsession/transport.go @@ -8,6 +8,6 @@ type transport interface { SendTo(sessionID uuid.UUID, payload []byte) error // ReceiveFrom reads the next datagram from the transport ReceiveFrom() (uuid.UUID, []byte, error) - // Max transmission unit of the transport - MTU() uint + // Max transmission unit to receive from the transport + ReceiveMTU() uint } diff --git a/datagramsession/transport_test.go b/datagramsession/transport_test.go index f8c67895..2b18197d 100644 --- a/datagramsession/transport_test.go +++ b/datagramsession/transport_test.go @@ -22,8 +22,8 @@ func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) { return mt.reqChan.Receive(context.Background()) } -func (mt *mockQUICTransport) MTU() uint { - return 1220 +func (mt *mockQUICTransport) ReceiveMTU() uint { + return 1217 } func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error { diff --git a/quic/datagram.go b/quic/datagram.go index be3d52e3..ad273913 100644 --- a/quic/datagram.go +++ b/quic/datagram.go @@ -9,7 +9,9 @@ import ( ) const ( - MaxDatagramFrameSize = 1220 + // Max datagram frame size is limited to 1220 https://github.com/lucas-clemente/quic-go/blob/v0.24.0/internal/protocol/params.go#L138 + // However, 3 more bytes are reserved https://github.com/lucas-clemente/quic-go/blob/v0.24.0/internal/wire/datagram_frame.go#L61 + MaxDatagramFrameSize = 1217 sessionIDLen = len(uuid.UUID{}) ) @@ -34,7 +36,7 @@ func NewDatagramMuxer(quicSession quic.Session) (*DatagramMuxer, error) { func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error { if len(payload) > MaxDatagramFrameSize-sessionIDLen { // TODO: TUN-5302 return ICMP packet too big message - return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), MaxDatagramFrameSize-sessionIDLen) + return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.SendMTU()) } msgWithID, err := SuffixSessionID(sessionID, payload) if err != nil { @@ -57,7 +59,13 @@ func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) { return ExtractSessionID(msg) } -func (dm *DatagramMuxer) MTU() uint { +// Maximum application payload to send through QUIC datagram frame +func (dm *DatagramMuxer) SendMTU() uint { + return uint(MaxDatagramFrameSize - sessionIDLen) +} + +// Maximum expected bytes to read from QUIC datagram frame +func (dm *DatagramMuxer) ReceiveMTU() uint { return MaxDatagramFrameSize } diff --git a/quic/datagram_test.go b/quic/datagram_test.go index 8c36b777..b761f1bd 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -1,10 +1,20 @@ package quic import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" "testing" "github.com/google/uuid" + "github.com/lucas-clemente/quic-go" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) var ( @@ -39,3 +49,91 @@ func TestSuffixSessionIDError(t *testing.T) { _, err = SuffixSessionID(testSessionID, msg) require.Error(t, err) } + +func TestMaxDatagramPayload(t *testing.T) { + payload := make([]byte, MaxDatagramFrameSize-sessionIDLen) + + quicConfig := &quic.Config{ + KeepAlive: true, + EnableDatagrams: true, + } + quicListener := newQUICListener(t, quicConfig) + defer quicListener.Close() + + errGroup, ctx := errgroup.WithContext(context.Background()) + // Run edge side of datagram muxer + errGroup.Go(func() error { + // Accept quic connection + quicSession, err := quicListener.Accept(ctx) + require.NoError(t, err) + + muxer, err := NewDatagramMuxer(quicSession) + require.NoError(t, err) + + sessionID, receivedPayload, err := muxer.ReceiveFrom() + require.NoError(t, err) + require.Equal(t, testSessionID, sessionID) + require.True(t, bytes.Equal(payload, receivedPayload)) + + return nil + }) + + // Run cloudflared side of datagram muxer + errGroup.Go(func() error { + tlsClientConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"argotunnel"}, + } + // Establish quic connection + quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig) + require.NoError(t, err) + + muxer, err := NewDatagramMuxer(quicSession) + require.NoError(t, err) + + err = muxer.SendTo(testSessionID, payload) + require.NoError(t, err) + + // Payload larger than transport MTU, should return an error + largePayload := append(payload, byte(1)) + err = muxer.SendTo(testSessionID, largePayload) + require.Error(t, err) + + return nil + }) + + require.NoError(t, errGroup.Wait()) +} + +func newQUICListener(t *testing.T, config *quic.Config) quic.Listener { + // Create a simple tls config. + tlsConfig := generateTLSConfig() + + listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, config) + require.NoError(t, err) + + return listener +} + +func generateTLSConfig() *tls.Config { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + NextProtos: []string{"argotunnel"}, + } +}