From c196679bc7f178e60aeaa0891dfd6688a2c0327f Mon Sep 17 00:00:00 2001 From: cthuang Date: Wed, 19 Jan 2022 18:24:16 +0000 Subject: [PATCH] TUN-5659: Proxy UDP with zero-byte payload --- datagramsession/session.go | 5 +++++ datagramsession/session_test.go | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/datagramsession/session.go b/datagramsession/session.go index fc46aa0f..0a1199e8 100644 --- a/datagramsession/session.go +++ b/datagramsession/session.go @@ -106,6 +106,7 @@ func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time func (s *Session) dstToTransport(buffer []byte) error { n, err := s.dstConn.Read(buffer) s.markActive() + // https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes if n > 0 { if n <= int(s.transport.MTU()) { err = s.transport.SendTo(s.ID, buffer[:n]) @@ -118,6 +119,10 @@ func (s *Session) dstToTransport(buffer []byte) error { Msg("dropped packet exceeding MTU") } } + // Some UDP application might send 0-size payload. + if err == nil && n == 0 { + err = s.transport.SendTo(s.ID, []byte{}) + } return err } diff --git a/datagramsession/session_test.go b/datagramsession/session_test.go index 591fa3c6..bf1b1e20 100644 --- a/datagramsession/session_test.go +++ b/datagramsession/session_test.go @@ -195,3 +195,40 @@ func TestMarkActiveNotBlocking(t *testing.T) { } wg.Wait() } + +func TestZeroBytePayload(t *testing.T) { + sessionID := uuid.New() + cfdConn, originConn := net.Pipe() + transport := &mockQUICTransport{ + reqChan: newDatagramChannel(1), + respChan: newDatagramChannel(1), + } + log := zerolog.Nop() + session := newSession(sessionID, transport, cfdConn, &log) + + ctx, cancel := context.WithCancel(context.Background()) + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + // Read from underlying conn and send to transport + closedByRemote, err := session.Serve(ctx, time.Minute*2) + require.Equal(t, context.Canceled, err) + require.False(t, closedByRemote) + return nil + }) + + errGroup.Go(func() error { + // Write to underlying connection + n, err := originConn.Write([]byte{}) + require.NoError(t, err) + require.Equal(t, 0, n) + return nil + }) + + receivedSessionID, payload, err := transport.respChan.Receive(ctx) + require.NoError(t, err) + require.Len(t, payload, 0) + require.Equal(t, sessionID, receivedSessionID) + + cancel() + require.NoError(t, errGroup.Wait()) +}