TUN-5299: Send/receive QUIC datagram from edge and proxy to origin as UDP
This commit is contained in:
parent
fc2333c934
commit
dd32dc1364
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -26,6 +27,7 @@ const (
|
||||||
HTTPMethodKey = "HttpMethod"
|
HTTPMethodKey = "HttpMethod"
|
||||||
// HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP.
|
// HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||||
HTTPHostKey = "HttpHost"
|
HTTPHostKey = "HttpHost"
|
||||||
|
MaxDatagramFrameSize = 1220
|
||||||
)
|
)
|
||||||
|
|
||||||
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
||||||
|
@ -33,8 +35,6 @@ type QUICConnection struct {
|
||||||
session quic.Session
|
session quic.Session
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
httpProxy OriginProxy
|
httpProxy OriginProxy
|
||||||
gracefulShutdownC <-chan struct{}
|
|
||||||
stoppedGracefully bool
|
|
||||||
udpSessions *udpSessions
|
udpSessions *udpSessions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,6 +49,12 @@ func NewQUICConnection(
|
||||||
controlStreamHandler ControlStreamHandler,
|
controlStreamHandler ControlStreamHandler,
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
) (*QUICConnection, error) {
|
) (*QUICConnection, error) {
|
||||||
|
localIP, err := GetLocalIP()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP)
|
||||||
|
udpSessions := newUDPSessions(localIP)
|
||||||
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial to edge: %w", err)
|
return nil, fmt.Errorf("failed to dial to edge: %w", err)
|
||||||
|
@ -69,15 +75,24 @@ func NewQUICConnection(
|
||||||
session: session,
|
session: session,
|
||||||
httpProxy: httpProxy,
|
httpProxy: httpProxy,
|
||||||
logger: observer.log,
|
logger: observer.log,
|
||||||
udpSessions: newUDPSessions(),
|
udpSessions: udpSessions,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve starts a QUIC session that begins accepting streams.
|
// Serve starts a QUIC session that begins accepting streams.
|
||||||
func (q *QUICConnection) Serve(ctx context.Context) error {
|
func (q *QUICConnection) Serve(ctx context.Context) error {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
defer cancel()
|
errGroup.Go(func() error {
|
||||||
|
return q.listenEdgeDatagram()
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
return q.acceptStream(ctx)
|
||||||
|
})
|
||||||
|
return errGroup.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||||
for {
|
for {
|
||||||
stream, err := q.session.AcceptStream(ctx)
|
stream, err := q.session.AcceptStream(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -96,6 +111,26 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload
|
||||||
|
func (q *QUICConnection) listenEdgeDatagram() error {
|
||||||
|
for {
|
||||||
|
msg, err := q.session.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go func(msg []byte) {
|
||||||
|
sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg)
|
||||||
|
if err != nil {
|
||||||
|
q.logger.Err(err).Msg("Failed to parse session ID from datagram")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil {
|
||||||
|
q.logger.Err(err).Msg("Failed to send UDP to origin")
|
||||||
|
}
|
||||||
|
}(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Close closes the session with no errors specified.
|
// Close closes the session with no errors specified.
|
||||||
func (q *QUICConnection) Close() {
|
func (q *QUICConnection) Close() {
|
||||||
q.session.CloseWithError(0, "")
|
q.session.CloseWithError(0, "")
|
||||||
|
@ -120,7 +155,7 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error {
|
||||||
}
|
}
|
||||||
return q.handleRPCStream(rpcStream)
|
return q.handleRPCStream(rpcStream)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Unknown protocol %v", signature)
|
return fmt.Errorf("unknown protocol %v", signature)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,7 +186,45 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||||
return q.udpSessions.register(sessionID, dstIP, dstPort)
|
udpConn, err := q.udpSessions.register(sessionID, dstIP, dstPort)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort)
|
||||||
|
go q.listenOriginUDP(sessionID, udpConn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// listenOriginUDP reads UDP from origin in a loop, and returns when it cannot write to edge or cannot read from origin
|
||||||
|
func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) {
|
||||||
|
defer func() {
|
||||||
|
q.udpSessions.unregister(sessionID)
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
|
readBuffer := make([]byte, MaxDatagramFrameSize)
|
||||||
|
for {
|
||||||
|
n, err := conn.Read(readBuffer)
|
||||||
|
if n > 0 {
|
||||||
|
if n > MaxDatagramFrameSize-sessionIDLen {
|
||||||
|
// TODO: TUN-5302 return ICMP packet too big message
|
||||||
|
q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, MaxDatagramFrameSize-sessionIDLen)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msgWithID, err := quicpogs.SuffixSessionID(sessionID, readBuffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
q.logger.Err(err).Msg("Failed to suffix session ID to datagram, it will be dropped")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := q.session.SendMessage(msgWithID); err != nil {
|
||||||
|
q.logger.Err(err).Msg("Failed to send datagram back to edge")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
q.logger.Err(err).Msg("Failed to read UDP from origin")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||||
|
@ -208,7 +281,7 @@ func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadClose
|
||||||
// metadata.Key is off the format httpHeaderKey:<HTTPHeader>
|
// metadata.Key is off the format httpHeaderKey:<HTTPHeader>
|
||||||
httpHeaderKey := strings.Split(metadata.Key, ":")
|
httpHeaderKey := strings.Split(metadata.Key, ":")
|
||||||
if len(httpHeaderKey) != 2 {
|
if len(httpHeaderKey) != 2 {
|
||||||
return nil, fmt.Errorf("Header Key: %s malformed", metadata.Key)
|
return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
|
||||||
}
|
}
|
||||||
req.Header.Add(httpHeaderKey[1], metadata.Val)
|
req.Header.Add(httpHeaderKey[1], metadata.Val)
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,7 @@ import (
|
||||||
func TestQUICServer(t *testing.T) {
|
func TestQUICServer(t *testing.T) {
|
||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
|
EnableDatagrams: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup test.
|
// Setup test.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package connection
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -8,18 +9,24 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: TUN-5422 Unregister session
|
// TODO: TUN-5422 Unregister session
|
||||||
|
const (
|
||||||
|
sessionIDLen = len(uuid.UUID{})
|
||||||
|
)
|
||||||
|
|
||||||
type udpSessions struct {
|
type udpSessions struct {
|
||||||
lock sync.Mutex
|
lock sync.RWMutex
|
||||||
sessions map[uuid.UUID]*net.UDPConn
|
sessions map[uuid.UUID]*net.UDPConn
|
||||||
|
localIP net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPSessions() *udpSessions {
|
func newUDPSessions(localIP net.IP) *udpSessions {
|
||||||
return &udpSessions{
|
return &udpSessions{
|
||||||
sessions: make(map[uuid.UUID]*net.UDPConn),
|
sessions: make(map[uuid.UUID]*net.UDPConn),
|
||||||
|
localIP: localIP,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
|
||||||
us.lock.Lock()
|
us.lock.Lock()
|
||||||
defer us.lock.Unlock()
|
defer us.lock.Unlock()
|
||||||
dstAddr := &net.UDPAddr{
|
dstAddr := &net.UDPAddr{
|
||||||
|
@ -28,16 +35,56 @@ func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) erro
|
||||||
}
|
}
|
||||||
conn, err := net.DialUDP("udp", us.localAddr(), dstAddr)
|
conn, err := net.DialUDP("udp", us.localAddr(), dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
us.sessions[id] = conn
|
us.sessions[id] = conn
|
||||||
return nil
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (us *udpSessions) unregister(id uuid.UUID) {
|
||||||
|
us.lock.Lock()
|
||||||
|
defer us.lock.Unlock()
|
||||||
|
delete(us.sessions, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (us *udpSessions) send(id uuid.UUID, payload []byte) error {
|
||||||
|
us.lock.RLock()
|
||||||
|
defer us.lock.RUnlock()
|
||||||
|
conn, ok := us.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("session %s not found", id)
|
||||||
|
}
|
||||||
|
_, err := conn.Write(payload)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ud *udpSessions) localAddr() *net.UDPAddr {
|
func (ud *udpSessions) localAddr() *net.UDPAddr {
|
||||||
// TODO: Determine the IP to bind to
|
// TODO: Determine the IP to bind to
|
||||||
|
|
||||||
return &net.UDPAddr{
|
return &net.UDPAddr{
|
||||||
IP: net.IPv4zero,
|
IP: ud.localIP,
|
||||||
Port: 0,
|
Port: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: TUN-5421 allow user to specify which IP to bind to
|
||||||
|
func GetLocalIP() (net.IP, error) {
|
||||||
|
addrs, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
// Find the IP that is not loop back
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
}
|
||||||
|
if !ip.IsLoopback() {
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("cannot determine IP to bind to")
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sessionIDLen = len(uuid.UUID{})
|
||||||
|
MaxDatagramFrameSize = 1220
|
||||||
|
)
|
||||||
|
|
||||||
|
// Each QUIC datagram should be suffixed with session ID.
|
||||||
|
// ExtractSessionID extracts the session ID and a slice with only the payload
|
||||||
|
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {
|
||||||
|
msgLen := len(b)
|
||||||
|
if msgLen < sessionIDLen {
|
||||||
|
return uuid.Nil, nil, fmt.Errorf("session ID has %d bytes, but data only has %d", sessionIDLen, len(b))
|
||||||
|
}
|
||||||
|
// Parse last 16 bytess as UUID and remove it from slice
|
||||||
|
sessionID, err := uuid.FromBytes(b[len(b)-sessionIDLen:])
|
||||||
|
if err != nil {
|
||||||
|
return uuid.Nil, nil, err
|
||||||
|
}
|
||||||
|
b = b[:len(b)-sessionIDLen]
|
||||||
|
return sessionID, b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because
|
||||||
|
// the payload slice might already have enough capacity to append the session ID at the end
|
||||||
|
func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) {
|
||||||
|
if len(b)+len(sessionID) > MaxDatagramFrameSize {
|
||||||
|
return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize)
|
||||||
|
}
|
||||||
|
b = append(b, sessionID[:]...)
|
||||||
|
return b, nil
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testSessionID = uuid.New()
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSuffixThenRemoveSessionID(t *testing.T) {
|
||||||
|
msg := []byte(t.Name())
|
||||||
|
msgWithID, err := SuffixSessionID(testSessionID, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, msgWithID, len(msg)+sessionIDLen)
|
||||||
|
|
||||||
|
sessionID, msgWithoutID, err := ExtractSessionID(msgWithID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msg, msgWithoutID)
|
||||||
|
require.Equal(t, testSessionID, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveSessionIDError(t *testing.T) {
|
||||||
|
// message is too short to contain session ID
|
||||||
|
msg := []byte("test")
|
||||||
|
_, _, err := ExtractSessionID(msg)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSuffixSessionIDError(t *testing.T) {
|
||||||
|
msg := make([]byte, MaxDatagramFrameSize-sessionIDLen)
|
||||||
|
_, err := SuffixSessionID(testSessionID, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
msg = make([]byte, MaxDatagramFrameSize-sessionIDLen+1)
|
||||||
|
_, err = SuffixSessionID(testSessionID, msg)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
Loading…
Reference in New Issue