Browse Source

TUN-5299: Send/receive QUIC datagram from edge and proxy to origin as UDP

pull/561/head
cthuang 6 months ago committed by Arég Harutyunyan
parent
commit
dd32dc1364
  1. 99
      connection/quic.go
  2. 3
      connection/quic_test.go
  3. 59
      connection/udp_session.go
  4. 38
      quic/datagram.go
  5. 41
      quic/datagram_test.go

99
connection/quic.go

@ -14,6 +14,7 @@ import (
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
quicpogs "github.com/cloudflare/cloudflared/quic"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -25,17 +26,16 @@ const (
// HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP.
HTTPMethodKey = "HttpMethod"
// HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP.
HTTPHostKey = "HttpHost"
HTTPHostKey = "HttpHost"
MaxDatagramFrameSize = 1220
)
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct {
session quic.Session
logger *zerolog.Logger
httpProxy OriginProxy
gracefulShutdownC <-chan struct{}
stoppedGracefully bool
udpSessions *udpSessions
session quic.Session
logger *zerolog.Logger
httpProxy OriginProxy
udpSessions *udpSessions
}
// NewQUICConnection returns a new instance of QUICConnection.
@ -49,6 +49,12 @@ func NewQUICConnection(
controlStreamHandler ControlStreamHandler,
observer *Observer,
) (*QUICConnection, error) {
localIP, err := GetLocalIP()
if err != nil {
return nil, err
}
observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP)
udpSessions := newUDPSessions(localIP)
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
if err != nil {
return nil, fmt.Errorf("failed to dial to edge: %w", err)
@ -69,15 +75,24 @@ func NewQUICConnection(
session: session,
httpProxy: httpProxy,
logger: observer.log,
udpSessions: newUDPSessions(),
udpSessions: udpSessions,
}, nil
}
// Serve starts a QUIC session that begins accepting streams.
func (q *QUICConnection) Serve(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return q.listenEdgeDatagram()
})
errGroup.Go(func() error {
return q.acceptStream(ctx)
})
return errGroup.Wait()
}
func (q *QUICConnection) acceptStream(ctx context.Context) error {
for {
stream, err := q.session.AcceptStream(ctx)
if err != nil {
@ -96,6 +111,26 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
}
}
// listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload
func (q *QUICConnection) listenEdgeDatagram() error {
for {
msg, err := q.session.ReceiveMessage()
if err != nil {
return err
}
go func(msg []byte) {
sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg)
if err != nil {
q.logger.Err(err).Msg("Failed to parse session ID from datagram")
return
}
if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil {
q.logger.Err(err).Msg("Failed to send UDP to origin")
}
}(msg)
}
}
// Close closes the session with no errors specified.
func (q *QUICConnection) Close() {
q.session.CloseWithError(0, "")
@ -120,7 +155,7 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error {
}
return q.handleRPCStream(rpcStream)
default:
return fmt.Errorf("Unknown protocol %v", signature)
return fmt.Errorf("unknown protocol %v", signature)
}
}
@ -151,7 +186,45 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er
}
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
return q.udpSessions.register(sessionID, dstIP, dstPort)
udpConn, err := q.udpSessions.register(sessionID, dstIP, dstPort)
if err != nil {
return err
}
q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort)
go q.listenOriginUDP(sessionID, udpConn)
return nil
}
// listenOriginUDP reads UDP from origin in a loop, and returns when it cannot write to edge or cannot read from origin
func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) {
defer func() {
q.udpSessions.unregister(sessionID)
conn.Close()
}()
readBuffer := make([]byte, MaxDatagramFrameSize)
for {
n, err := conn.Read(readBuffer)
if n > 0 {
if n > MaxDatagramFrameSize-sessionIDLen {
// TODO: TUN-5302 return ICMP packet too big message
q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, MaxDatagramFrameSize-sessionIDLen)
continue
}
msgWithID, err := quicpogs.SuffixSessionID(sessionID, readBuffer[:n])
if err != nil {
q.logger.Err(err).Msg("Failed to suffix session ID to datagram, it will be dropped")
continue
}
if err := q.session.SendMessage(msgWithID); err != nil {
q.logger.Err(err).Msg("Failed to send datagram back to edge")
return
}
}
if err != nil {
q.logger.Err(err).Msg("Failed to read UDP from origin")
return
}
}
}
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
@ -208,7 +281,7 @@ func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadClose
// metadata.Key is off the format httpHeaderKey:<HTTPHeader>
httpHeaderKey := strings.Split(metadata.Key, ":")
if len(httpHeaderKey) != 2 {
return nil, fmt.Errorf("Header Key: %s malformed", metadata.Key)
return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
}
req.Header.Add(httpHeaderKey[1], metadata.Val)
}

3
connection/quic_test.go

@ -33,7 +33,8 @@ import (
// It also serves as a demonstration for communication with the QUIC connection started by a cloudflared.
func TestQUICServer(t *testing.T) {
quicConfig := &quic.Config{
KeepAlive: true,
KeepAlive: true,
EnableDatagrams: true,
}
// Setup test.

59
connection/udp_session.go

@ -1,6 +1,7 @@
package connection
import (
"fmt"
"net"
"sync"
@ -8,18 +9,24 @@ import (
)
// TODO: TUN-5422 Unregister session
const (
sessionIDLen = len(uuid.UUID{})
)
type udpSessions struct {
lock sync.Mutex
lock sync.RWMutex
sessions map[uuid.UUID]*net.UDPConn
localIP net.IP
}
func newUDPSessions() *udpSessions {
func newUDPSessions(localIP net.IP) *udpSessions {
return &udpSessions{
sessions: make(map[uuid.UUID]*net.UDPConn),
localIP: localIP,
}
}
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) error {
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
us.lock.Lock()
defer us.lock.Unlock()
dstAddr := &net.UDPAddr{
@ -28,16 +35,56 @@ func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) erro
}
conn, err := net.DialUDP("udp", us.localAddr(), dstAddr)
if err != nil {
return err
return nil, err
}
us.sessions[id] = conn
return nil
return conn, nil
}
func (us *udpSessions) unregister(id uuid.UUID) {
us.lock.Lock()
defer us.lock.Unlock()
delete(us.sessions, id)
}
func (us *udpSessions) send(id uuid.UUID, payload []byte) error {
us.lock.RLock()
defer us.lock.RUnlock()
conn, ok := us.sessions[id]
if !ok {
return fmt.Errorf("session %s not found", id)
}
_, err := conn.Write(payload)
return err
}
func (ud *udpSessions) localAddr() *net.UDPAddr {
// TODO: Determine the IP to bind to
return &net.UDPAddr{
IP: net.IPv4zero,
IP: ud.localIP,
Port: 0,
}
}
// TODO: TUN-5421 allow user to specify which IP to bind to
func GetLocalIP() (net.IP, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
// Find the IP that is not loop back
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if !ip.IsLoopback() {
return ip, nil
}
}
return nil, fmt.Errorf("cannot determine IP to bind to")
}

38
quic/datagram.go

@ -0,0 +1,38 @@
package quic
import (
"fmt"
"github.com/google/uuid"
)
const (
sessionIDLen = len(uuid.UUID{})
MaxDatagramFrameSize = 1220
)
// Each QUIC datagram should be suffixed with session ID.
// ExtractSessionID extracts the session ID and a slice with only the payload
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {
msgLen := len(b)
if msgLen < sessionIDLen {
return uuid.Nil, nil, fmt.Errorf("session ID has %d bytes, but data only has %d", sessionIDLen, len(b))
}
// Parse last 16 bytess as UUID and remove it from slice
sessionID, err := uuid.FromBytes(b[len(b)-sessionIDLen:])
if err != nil {
return uuid.Nil, nil, err
}
b = b[:len(b)-sessionIDLen]
return sessionID, b, nil
}
// SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because
// the payload slice might already have enough capacity to append the session ID at the end
func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) {
if len(b)+len(sessionID) > MaxDatagramFrameSize {
return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize)
}
b = append(b, sessionID[:]...)
return b, nil
}

41
quic/datagram_test.go

@ -0,0 +1,41 @@
package quic
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
var (
testSessionID = uuid.New()
)
func TestSuffixThenRemoveSessionID(t *testing.T) {
msg := []byte(t.Name())
msgWithID, err := SuffixSessionID(testSessionID, msg)
require.NoError(t, err)
require.Len(t, msgWithID, len(msg)+sessionIDLen)
sessionID, msgWithoutID, err := ExtractSessionID(msgWithID)
require.NoError(t, err)
require.Equal(t, msg, msgWithoutID)
require.Equal(t, testSessionID, sessionID)
}
func TestRemoveSessionIDError(t *testing.T) {
// message is too short to contain session ID
msg := []byte("test")
_, _, err := ExtractSessionID(msg)
require.Error(t, err)
}
func TestSuffixSessionIDError(t *testing.T) {
msg := make([]byte, MaxDatagramFrameSize-sessionIDLen)
_, err := SuffixSessionID(testSessionID, msg)
require.NoError(t, err)
msg = make([]byte, MaxDatagramFrameSize-sessionIDLen+1)
_, err = SuffixSessionID(testSessionID, msg)
require.Error(t, err)
}
Loading…
Cancel
Save