diff --git a/connection/quic.go b/connection/quic.go index a6d6f6f0..9905e90b 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -10,6 +10,7 @@ import ( "net/netip" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -42,6 +43,11 @@ const ( demuxChanCapacity = 16 ) +var ( + portForConnIndex = make(map[uint8]int, 0) + portMapMutex sync.Mutex +) + // QUICConnection represents the type that facilitates Proxying via QUIC streams. type QUICConnection struct { session quic.Connection @@ -60,6 +66,7 @@ type QUICConnection struct { func NewQUICConnection( quicConfig *quic.Config, edgeAddr net.Addr, + connIndex uint8, tlsConfig *tls.Config, orchestrator Orchestrator, connOptions *tunnelpogs.ConnectionOptions, @@ -67,11 +74,22 @@ func NewQUICConnection( logger *zerolog.Logger, packetRouterConfig *packet.GlobalRouterConfig, ) (*QUICConnection, error) { - session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) + udpConn, err := createUDPConnForConnIndex(connIndex, logger) + if err != nil { + return nil, err + } + + session, err := quic.Dial(udpConn, edgeAddr, edgeAddr.String(), tlsConfig, quicConfig) if err != nil { return nil, &EdgeQuicDialError{Cause: err} } + // wrap the session, so that the UDPConn is closed after session is closed. + session = &wrapCloseableConnQuicConnection{ + session, + udpConn, + } + sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) datagramMuxer := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) @@ -492,3 +510,45 @@ func (rp *returnPipe) SendPacket(dst netip.Addr, pk packet.RawPacket) error { func (rp *returnPipe) Close() error { return nil } + +func createUDPConnForConnIndex(connIndex uint8, logger *zerolog.Logger) (*net.UDPConn, error) { + portMapMutex.Lock() + defer portMapMutex.Unlock() + + // if port was not set yet, it will be zero, so bind will randomly allocate one. + if port, ok := portForConnIndex[connIndex]; ok { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: port}) + // if there wasn't an error, or if port was 0 (independently of error or not, just return) + if err == nil { + return udpConn, nil + } else { + logger.Debug().Err(err).Msgf("Unable to reuse port %d for connIndex %d. Falling back to random allocation.", port, connIndex) + } + } + + // if we reached here, then there was an error or port as not been allocated it. + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err == nil { + udpAddr, ok := (udpConn.LocalAddr()).(*net.UDPAddr) + if !ok { + return nil, fmt.Errorf("unable to cast to udpConn") + } + portForConnIndex[connIndex] = udpAddr.Port + } else { + delete(portForConnIndex, connIndex) + } + + return udpConn, err +} + +type wrapCloseableConnQuicConnection struct { + quic.Connection + udpConn *net.UDPConn +} + +func (w *wrapCloseableConnQuicConnection) CloseWithError(errorCode quic.ApplicationErrorCode, reason string) error { + err := w.Connection.CloseWithError(errorCode, reason) + w.udpConn.Close() + + return err +} diff --git a/connection/quic_test.go b/connection/quic_test.go index c2990878..ebdc07a2 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -567,6 +567,37 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) { require.Equal(t, err, io.EOF) } +func TestCreateUDPConnReuseSourcePort(t *testing.T) { + logger := zerolog.Nop() + conn, err := createUDPConnForConnIndex(0, &logger) + require.NoError(t, err) + + getPortFunc := func(conn *net.UDPConn) int { + addr := conn.LocalAddr().(*net.UDPAddr) + return addr.Port + } + + initialPort := getPortFunc(conn) + + // close conn + conn.Close() + + // should get the same port as before. + conn, err = createUDPConnForConnIndex(0, &logger) + require.NoError(t, err) + require.Equal(t, initialPort, getPortFunc(conn)) + + // new index, should get a different port + conn1, err := createUDPConnForConnIndex(1, &logger) + require.NoError(t, err) + require.NotEqual(t, initialPort, getPortFunc(conn1)) + + // not closing the conn and trying to obtain a new conn for same index should give a different random port + conn, err = createUDPConnForConnIndex(0, &logger) + require.NoError(t, err) + require.NotEqual(t, initialPort, getPortFunc(conn)) +} + func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { var ( payload = []byte(t.Name()) @@ -682,6 +713,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection qc, err := NewQUICConnection( testQUICConfig, udpListenerAddr, + 0, tlsClientConfig, &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, &tunnelpogs.ConnectionOptions{}, diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index d0a78b04..6c0c4cd7 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -657,6 +657,7 @@ func (e *EdgeTunnelServer) serveQUIC( quicConn, err := connection.NewQUICConnection( quicConfig, edgeAddr, + connIndex, tlsConfig, e.orchestrator, connOptions,