TUN-5593: Read full packet from UDP connection, even if it exceeds MTU of the transport. When packet length is greater than the MTU of the transport, we will silently drop packets (for now).

This commit is contained in:
Igor Postelnik 2021-12-22 17:18:22 -06:00
parent 7a55208c61
commit 8445b88d3c
6 changed files with 28 additions and 19 deletions

View File

@ -95,7 +95,7 @@ func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, orig
} }
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) { func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
session := newSession(registration.sessionID, m.transport, registration.originProxy) session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log)
m.sessions[registration.sessionID] = session m.sessions[registration.sessionID] = session
registration.resultChan <- session registration.resultChan <- session
} }

View File

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog"
) )
const ( const (
@ -17,7 +18,7 @@ func SessionIdleErr(timeout time.Duration) error {
return fmt.Errorf("session idle for %v", timeout) return fmt.Errorf("session idle for %v", timeout)
} }
// Each Session is a bidirectional pipe of datagrams between transport and dstConn // Session is a bidirectional pipe of datagrams between transport and dstConn
// Currently the only implementation of transport is quic DatagramMuxer // Currently the only implementation of transport is quic DatagramMuxer
// Destination can be a connection with origin or with eyeball // Destination can be a connection with origin or with eyeball
// When the destination is origin: // When the destination is origin:
@ -35,9 +36,10 @@ type Session struct {
// activeAtChan is used to communicate the last read/write time // activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time activeAtChan chan time.Time
closeChan chan error closeChan chan error
log *zerolog.Logger
} }
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session { func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser, log *zerolog.Logger) *Session {
return &Session{ return &Session{
ID: id, ID: id,
transport: transport, transport: transport,
@ -47,6 +49,7 @@ func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *
activeAtChan: make(chan time.Time, 2), activeAtChan: make(chan time.Time, 2),
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel // capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
closeChan: make(chan error, 2), closeChan: make(chan error, 2),
log: log,
} }
} }
@ -54,7 +57,8 @@ func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (clos
go func() { go func() {
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 // QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
// This makes it safe to share readBuffer between iterations // This makes it safe to share readBuffer between iterations
readBuffer := make([]byte, s.transport.ReceiveMTU()) const maxPacketSize = 1500
readBuffer := make([]byte, maxPacketSize)
for { for {
if err := s.dstToTransport(readBuffer); err != nil { if err := s.dstToTransport(readBuffer); err != nil {
s.closeChan <- err s.closeChan <- err
@ -103,8 +107,15 @@ func (s *Session) dstToTransport(buffer []byte) error {
n, err := s.dstConn.Read(buffer) n, err := s.dstConn.Read(buffer)
s.markActive() s.markActive()
if n > 0 { if n > 0 {
if err := s.transport.SendTo(s.ID, buffer[:n]); err != nil { if n <= int(s.transport.MTU()) {
return err err = s.transport.SendTo(s.ID, buffer[:n])
} else {
// drop packet for now, eventually reply with ICMP for PMTUD
s.log.Debug().
Str("session", s.ID.String()).
Int("len", n).
Uint("mtu", s.transport.MTU()).
Msg("dropped packet exceeding MTU")
} }
} }
return err return err

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -44,7 +45,8 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
reqChan: newDatagramChannel(1), reqChan: newDatagramChannel(1),
respChan: newDatagramChannel(1), respChan: newDatagramChannel(1),
} }
session := newSession(sessionID, transport, cfdConn) log := zerolog.Nop()
session := newSession(sessionID, transport, cfdConn, &log)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
sessionDone := make(chan struct{}) sessionDone := make(chan struct{})
@ -119,7 +121,8 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
reqChan: newDatagramChannel(100), reqChan: newDatagramChannel(100),
respChan: newDatagramChannel(100), respChan: newDatagramChannel(100),
} }
session := newSession(sessionID, transport, cfdConn) log := zerolog.Nop()
session := newSession(sessionID, transport, cfdConn, &log)
startTime := time.Now() startTime := time.Now()
activeUntil := startTime.Add(activeTime) activeUntil := startTime.Add(activeTime)
@ -181,7 +184,7 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
func TestMarkActiveNotBlocking(t *testing.T) { func TestMarkActiveNotBlocking(t *testing.T) {
const concurrentCalls = 50 const concurrentCalls = 50
session := newSession(uuid.New(), nil, nil) session := newSession(uuid.New(), nil, nil, nil)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(concurrentCalls) wg.Add(concurrentCalls)
for i := 0; i < concurrentCalls; i++ { for i := 0; i < concurrentCalls; i++ {

View File

@ -9,5 +9,5 @@ type transport interface {
// ReceiveFrom reads the next datagram from the transport // ReceiveFrom reads the next datagram from the transport
ReceiveFrom() (uuid.UUID, []byte, error) ReceiveFrom() (uuid.UUID, []byte, error)
// Max transmission unit to receive from the transport // Max transmission unit to receive from the transport
ReceiveMTU() uint MTU() uint
} }

View File

@ -22,7 +22,7 @@ func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
return mt.reqChan.Receive(context.Background()) return mt.reqChan.Receive(context.Background())
} }
func (mt *mockQUICTransport) ReceiveMTU() uint { func (mt *mockQUICTransport) MTU() uint {
return 1217 return 1217
} }

View File

@ -36,7 +36,7 @@ func NewDatagramMuxer(quicSession quic.Session) (*DatagramMuxer, error) {
func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error { func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error {
if len(payload) > MaxDatagramFrameSize-sessionIDLen { if len(payload) > MaxDatagramFrameSize-sessionIDLen {
// TODO: TUN-5302 return ICMP packet too big message // TODO: TUN-5302 return ICMP packet too big message
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.SendMTU()) return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.MTU())
} }
msgWithID, err := SuffixSessionID(sessionID, payload) msgWithID, err := SuffixSessionID(sessionID, payload)
if err != nil { if err != nil {
@ -59,16 +59,11 @@ func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
return ExtractSessionID(msg) return ExtractSessionID(msg)
} }
// Maximum application payload to send through QUIC datagram frame // Maximum application payload to send to / receive from QUIC datagram frame
func (dm *DatagramMuxer) SendMTU() uint { func (dm *DatagramMuxer) MTU() uint {
return uint(MaxDatagramFrameSize - sessionIDLen) return uint(MaxDatagramFrameSize - sessionIDLen)
} }
// Maximum expected bytes to read from QUIC datagram frame
func (dm *DatagramMuxer) ReceiveMTU() uint {
return MaxDatagramFrameSize
}
// Each QUIC datagram should be suffixed with session ID. // Each QUIC datagram should be suffixed with session ID.
// ExtractSessionID extracts the session ID and a slice with only the payload // ExtractSessionID extracts the session ID and a slice with only the payload
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) { func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {