diff --git a/connection/quic.go b/connection/quic.go index 1a050843..b0590b4f 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -14,6 +14,7 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" quicpogs "github.com/cloudflare/cloudflared/quic" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -25,17 +26,16 @@ const ( // HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP. HTTPMethodKey = "HttpMethod" // HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPHostKey = "HttpHost" + HTTPHostKey = "HttpHost" + MaxDatagramFrameSize = 1220 ) // QUICConnection represents the type that facilitates Proxying via QUIC streams. type QUICConnection struct { - session quic.Session - logger *zerolog.Logger - httpProxy OriginProxy - gracefulShutdownC <-chan struct{} - stoppedGracefully bool - udpSessions *udpSessions + session quic.Session + logger *zerolog.Logger + httpProxy OriginProxy + udpSessions *udpSessions } // NewQUICConnection returns a new instance of QUICConnection. @@ -49,6 +49,12 @@ func NewQUICConnection( controlStreamHandler ControlStreamHandler, observer *Observer, ) (*QUICConnection, error) { + localIP, err := GetLocalIP() + if err != nil { + return nil, err + } + observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP) + udpSessions := newUDPSessions(localIP) session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) if err != nil { return nil, fmt.Errorf("failed to dial to edge: %w", err) @@ -69,15 +75,24 @@ func NewQUICConnection( session: session, httpProxy: httpProxy, logger: observer.log, - udpSessions: newUDPSessions(), + udpSessions: udpSessions, }, nil } // Serve starts a QUIC session that begins accepting streams. func (q *QUICConnection) Serve(ctx context.Context) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + return q.listenEdgeDatagram() + }) + errGroup.Go(func() error { + return q.acceptStream(ctx) + }) + return errGroup.Wait() +} + +func (q *QUICConnection) acceptStream(ctx context.Context) error { for { stream, err := q.session.AcceptStream(ctx) if err != nil { @@ -96,6 +111,26 @@ func (q *QUICConnection) Serve(ctx context.Context) error { } } +// listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload +func (q *QUICConnection) listenEdgeDatagram() error { + for { + msg, err := q.session.ReceiveMessage() + if err != nil { + return err + } + go func(msg []byte) { + sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg) + if err != nil { + q.logger.Err(err).Msg("Failed to parse session ID from datagram") + return + } + if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil { + q.logger.Err(err).Msg("Failed to send UDP to origin") + } + }(msg) + } +} + // Close closes the session with no errors specified. func (q *QUICConnection) Close() { q.session.CloseWithError(0, "") @@ -120,7 +155,7 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error { } return q.handleRPCStream(rpcStream) default: - return fmt.Errorf("Unknown protocol %v", signature) + return fmt.Errorf("unknown protocol %v", signature) } } @@ -151,7 +186,45 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er } func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error { - return q.udpSessions.register(sessionID, dstIP, dstPort) + udpConn, err := q.udpSessions.register(sessionID, dstIP, dstPort) + if err != nil { + return err + } + q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort) + go q.listenOriginUDP(sessionID, udpConn) + return nil +} + +// listenOriginUDP reads UDP from origin in a loop, and returns when it cannot write to edge or cannot read from origin +func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) { + defer func() { + q.udpSessions.unregister(sessionID) + conn.Close() + }() + readBuffer := make([]byte, MaxDatagramFrameSize) + for { + n, err := conn.Read(readBuffer) + if n > 0 { + if n > MaxDatagramFrameSize-sessionIDLen { + // TODO: TUN-5302 return ICMP packet too big message + q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, MaxDatagramFrameSize-sessionIDLen) + continue + } + msgWithID, err := quicpogs.SuffixSessionID(sessionID, readBuffer[:n]) + if err != nil { + q.logger.Err(err).Msg("Failed to suffix session ID to datagram, it will be dropped") + continue + } + if err := q.session.SendMessage(msgWithID); err != nil { + q.logger.Err(err).Msg("Failed to send datagram back to edge") + return + } + } + if err != nil { + q.logger.Err(err).Msg("Failed to read UDP from origin") + return + } + } } // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to @@ -208,7 +281,7 @@ func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadClose // metadata.Key is off the format httpHeaderKey: httpHeaderKey := strings.Split(metadata.Key, ":") if len(httpHeaderKey) != 2 { - return nil, fmt.Errorf("Header Key: %s malformed", metadata.Key) + return nil, fmt.Errorf("header Key: %s malformed", metadata.Key) } req.Header.Add(httpHeaderKey[1], metadata.Val) } diff --git a/connection/quic_test.go b/connection/quic_test.go index 1bc1a809..8e53fcb2 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -33,7 +33,8 @@ import ( // It also serves as a demonstration for communication with the QUIC connection started by a cloudflared. func TestQUICServer(t *testing.T) { quicConfig := &quic.Config{ - KeepAlive: true, + KeepAlive: true, + EnableDatagrams: true, } // Setup test. diff --git a/connection/udp_session.go b/connection/udp_session.go index a038ac96..1d0a6057 100644 --- a/connection/udp_session.go +++ b/connection/udp_session.go @@ -1,6 +1,7 @@ package connection import ( + "fmt" "net" "sync" @@ -8,18 +9,24 @@ import ( ) // TODO: TUN-5422 Unregister session +const ( + sessionIDLen = len(uuid.UUID{}) +) + type udpSessions struct { - lock sync.Mutex + lock sync.RWMutex sessions map[uuid.UUID]*net.UDPConn + localIP net.IP } -func newUDPSessions() *udpSessions { +func newUDPSessions(localIP net.IP) *udpSessions { return &udpSessions{ sessions: make(map[uuid.UUID]*net.UDPConn), + localIP: localIP, } } -func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) error { +func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) { us.lock.Lock() defer us.lock.Unlock() dstAddr := &net.UDPAddr{ @@ -28,16 +35,56 @@ func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) erro } conn, err := net.DialUDP("udp", us.localAddr(), dstAddr) if err != nil { - return err + return nil, err } us.sessions[id] = conn - return nil + return conn, nil +} + +func (us *udpSessions) unregister(id uuid.UUID) { + us.lock.Lock() + defer us.lock.Unlock() + delete(us.sessions, id) +} + +func (us *udpSessions) send(id uuid.UUID, payload []byte) error { + us.lock.RLock() + defer us.lock.RUnlock() + conn, ok := us.sessions[id] + if !ok { + return fmt.Errorf("session %s not found", id) + } + _, err := conn.Write(payload) + return err } func (ud *udpSessions) localAddr() *net.UDPAddr { // TODO: Determine the IP to bind to + return &net.UDPAddr{ - IP: net.IPv4zero, + IP: ud.localIP, Port: 0, } } + +// TODO: TUN-5421 allow user to specify which IP to bind to +func GetLocalIP() (net.IP, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + for _, addr := range addrs { + // Find the IP that is not loop back + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if !ip.IsLoopback() { + return ip, nil + } + } + return nil, fmt.Errorf("cannot determine IP to bind to") +} diff --git a/quic/datagram.go b/quic/datagram.go new file mode 100644 index 00000000..cd538d87 --- /dev/null +++ b/quic/datagram.go @@ -0,0 +1,38 @@ +package quic + +import ( + "fmt" + + "github.com/google/uuid" +) + +const ( + sessionIDLen = len(uuid.UUID{}) + MaxDatagramFrameSize = 1220 +) + +// Each QUIC datagram should be suffixed with session ID. +// ExtractSessionID extracts the session ID and a slice with only the payload +func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) { + msgLen := len(b) + if msgLen < sessionIDLen { + return uuid.Nil, nil, fmt.Errorf("session ID has %d bytes, but data only has %d", sessionIDLen, len(b)) + } + // Parse last 16 bytess as UUID and remove it from slice + sessionID, err := uuid.FromBytes(b[len(b)-sessionIDLen:]) + if err != nil { + return uuid.Nil, nil, err + } + b = b[:len(b)-sessionIDLen] + return sessionID, b, nil +} + +// SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because +// the payload slice might already have enough capacity to append the session ID at the end +func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) { + if len(b)+len(sessionID) > MaxDatagramFrameSize { + return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize) + } + b = append(b, sessionID[:]...) + return b, nil +} diff --git a/quic/datagram_test.go b/quic/datagram_test.go new file mode 100644 index 00000000..8c36b777 --- /dev/null +++ b/quic/datagram_test.go @@ -0,0 +1,41 @@ +package quic + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +var ( + testSessionID = uuid.New() +) + +func TestSuffixThenRemoveSessionID(t *testing.T) { + msg := []byte(t.Name()) + msgWithID, err := SuffixSessionID(testSessionID, msg) + require.NoError(t, err) + require.Len(t, msgWithID, len(msg)+sessionIDLen) + + sessionID, msgWithoutID, err := ExtractSessionID(msgWithID) + require.NoError(t, err) + require.Equal(t, msg, msgWithoutID) + require.Equal(t, testSessionID, sessionID) +} + +func TestRemoveSessionIDError(t *testing.T) { + // message is too short to contain session ID + msg := []byte("test") + _, _, err := ExtractSessionID(msg) + require.Error(t, err) +} + +func TestSuffixSessionIDError(t *testing.T) { + msg := make([]byte, MaxDatagramFrameSize-sessionIDLen) + _, err := SuffixSessionID(testSessionID, msg) + require.NoError(t, err) + + msg = make([]byte, MaxDatagramFrameSize-sessionIDLen+1) + _, err = SuffixSessionID(testSessionID, msg) + require.Error(t, err) +}