TUN-5299: Send/receive QUIC datagram from edge and proxy to origin as UDP
This commit is contained in:
parent
fc2333c934
commit
dd32dc1364
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
|
@ -26,6 +27,7 @@ const (
|
|||
HTTPMethodKey = "HttpMethod"
|
||||
// HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||
HTTPHostKey = "HttpHost"
|
||||
MaxDatagramFrameSize = 1220
|
||||
)
|
||||
|
||||
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
||||
|
@ -33,8 +35,6 @@ type QUICConnection struct {
|
|||
session quic.Session
|
||||
logger *zerolog.Logger
|
||||
httpProxy OriginProxy
|
||||
gracefulShutdownC <-chan struct{}
|
||||
stoppedGracefully bool
|
||||
udpSessions *udpSessions
|
||||
}
|
||||
|
||||
|
@ -49,6 +49,12 @@ func NewQUICConnection(
|
|||
controlStreamHandler ControlStreamHandler,
|
||||
observer *Observer,
|
||||
) (*QUICConnection, error) {
|
||||
localIP, err := GetLocalIP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP)
|
||||
udpSessions := newUDPSessions(localIP)
|
||||
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial to edge: %w", err)
|
||||
|
@ -69,15 +75,24 @@ func NewQUICConnection(
|
|||
session: session,
|
||||
httpProxy: httpProxy,
|
||||
logger: observer.log,
|
||||
udpSessions: newUDPSessions(),
|
||||
udpSessions: udpSessions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve starts a QUIC session that begins accepting streams.
|
||||
func (q *QUICConnection) Serve(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
return q.listenEdgeDatagram()
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
return q.acceptStream(ctx)
|
||||
})
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||
for {
|
||||
stream, err := q.session.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
|
@ -96,6 +111,26 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
// listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload
|
||||
func (q *QUICConnection) listenEdgeDatagram() error {
|
||||
for {
|
||||
msg, err := q.session.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func(msg []byte) {
|
||||
sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg)
|
||||
if err != nil {
|
||||
q.logger.Err(err).Msg("Failed to parse session ID from datagram")
|
||||
return
|
||||
}
|
||||
if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil {
|
||||
q.logger.Err(err).Msg("Failed to send UDP to origin")
|
||||
}
|
||||
}(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the session with no errors specified.
|
||||
func (q *QUICConnection) Close() {
|
||||
q.session.CloseWithError(0, "")
|
||||
|
@ -120,7 +155,7 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error {
|
|||
}
|
||||
return q.handleRPCStream(rpcStream)
|
||||
default:
|
||||
return fmt.Errorf("Unknown protocol %v", signature)
|
||||
return fmt.Errorf("unknown protocol %v", signature)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,7 +186,45 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er
|
|||
}
|
||||
|
||||
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
return q.udpSessions.register(sessionID, dstIP, dstPort)
|
||||
udpConn, err := q.udpSessions.register(sessionID, dstIP, dstPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort)
|
||||
go q.listenOriginUDP(sessionID, udpConn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// listenOriginUDP reads UDP from origin in a loop, and returns when it cannot write to edge or cannot read from origin
|
||||
func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) {
|
||||
defer func() {
|
||||
q.udpSessions.unregister(sessionID)
|
||||
conn.Close()
|
||||
}()
|
||||
readBuffer := make([]byte, MaxDatagramFrameSize)
|
||||
for {
|
||||
n, err := conn.Read(readBuffer)
|
||||
if n > 0 {
|
||||
if n > MaxDatagramFrameSize-sessionIDLen {
|
||||
// TODO: TUN-5302 return ICMP packet too big message
|
||||
q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, MaxDatagramFrameSize-sessionIDLen)
|
||||
continue
|
||||
}
|
||||
msgWithID, err := quicpogs.SuffixSessionID(sessionID, readBuffer[:n])
|
||||
if err != nil {
|
||||
q.logger.Err(err).Msg("Failed to suffix session ID to datagram, it will be dropped")
|
||||
continue
|
||||
}
|
||||
if err := q.session.SendMessage(msgWithID); err != nil {
|
||||
q.logger.Err(err).Msg("Failed to send datagram back to edge")
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
q.logger.Err(err).Msg("Failed to read UDP from origin")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||
|
@ -208,7 +281,7 @@ func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadClose
|
|||
// metadata.Key is off the format httpHeaderKey:<HTTPHeader>
|
||||
httpHeaderKey := strings.Split(metadata.Key, ":")
|
||||
if len(httpHeaderKey) != 2 {
|
||||
return nil, fmt.Errorf("Header Key: %s malformed", metadata.Key)
|
||||
return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
|
||||
}
|
||||
req.Header.Add(httpHeaderKey[1], metadata.Val)
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
func TestQUICServer(t *testing.T) {
|
||||
quicConfig := &quic.Config{
|
||||
KeepAlive: true,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
|
||||
// Setup test.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
|
@ -8,18 +9,24 @@ import (
|
|||
)
|
||||
|
||||
// TODO: TUN-5422 Unregister session
|
||||
const (
|
||||
sessionIDLen = len(uuid.UUID{})
|
||||
)
|
||||
|
||||
type udpSessions struct {
|
||||
lock sync.Mutex
|
||||
lock sync.RWMutex
|
||||
sessions map[uuid.UUID]*net.UDPConn
|
||||
localIP net.IP
|
||||
}
|
||||
|
||||
func newUDPSessions() *udpSessions {
|
||||
func newUDPSessions(localIP net.IP) *udpSessions {
|
||||
return &udpSessions{
|
||||
sessions: make(map[uuid.UUID]*net.UDPConn),
|
||||
localIP: localIP,
|
||||
}
|
||||
}
|
||||
|
||||
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
|
||||
us.lock.Lock()
|
||||
defer us.lock.Unlock()
|
||||
dstAddr := &net.UDPAddr{
|
||||
|
@ -28,16 +35,56 @@ func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) erro
|
|||
}
|
||||
conn, err := net.DialUDP("udp", us.localAddr(), dstAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
us.sessions[id] = conn
|
||||
return nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (us *udpSessions) unregister(id uuid.UUID) {
|
||||
us.lock.Lock()
|
||||
defer us.lock.Unlock()
|
||||
delete(us.sessions, id)
|
||||
}
|
||||
|
||||
func (us *udpSessions) send(id uuid.UUID, payload []byte) error {
|
||||
us.lock.RLock()
|
||||
defer us.lock.RUnlock()
|
||||
conn, ok := us.sessions[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("session %s not found", id)
|
||||
}
|
||||
_, err := conn.Write(payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ud *udpSessions) localAddr() *net.UDPAddr {
|
||||
// TODO: Determine the IP to bind to
|
||||
|
||||
return &net.UDPAddr{
|
||||
IP: net.IPv4zero,
|
||||
IP: ud.localIP,
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: TUN-5421 allow user to specify which IP to bind to
|
||||
func GetLocalIP() (net.IP, error) {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
// Find the IP that is not loop back
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
}
|
||||
if !ip.IsLoopback() {
|
||||
return ip, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("cannot determine IP to bind to")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionIDLen = len(uuid.UUID{})
|
||||
MaxDatagramFrameSize = 1220
|
||||
)
|
||||
|
||||
// Each QUIC datagram should be suffixed with session ID.
|
||||
// ExtractSessionID extracts the session ID and a slice with only the payload
|
||||
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {
|
||||
msgLen := len(b)
|
||||
if msgLen < sessionIDLen {
|
||||
return uuid.Nil, nil, fmt.Errorf("session ID has %d bytes, but data only has %d", sessionIDLen, len(b))
|
||||
}
|
||||
// Parse last 16 bytess as UUID and remove it from slice
|
||||
sessionID, err := uuid.FromBytes(b[len(b)-sessionIDLen:])
|
||||
if err != nil {
|
||||
return uuid.Nil, nil, err
|
||||
}
|
||||
b = b[:len(b)-sessionIDLen]
|
||||
return sessionID, b, nil
|
||||
}
|
||||
|
||||
// SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because
|
||||
// the payload slice might already have enough capacity to append the session ID at the end
|
||||
func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) {
|
||||
if len(b)+len(sessionID) > MaxDatagramFrameSize {
|
||||
return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize)
|
||||
}
|
||||
b = append(b, sessionID[:]...)
|
||||
return b, nil
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
testSessionID = uuid.New()
|
||||
)
|
||||
|
||||
func TestSuffixThenRemoveSessionID(t *testing.T) {
|
||||
msg := []byte(t.Name())
|
||||
msgWithID, err := SuffixSessionID(testSessionID, msg)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, msgWithID, len(msg)+sessionIDLen)
|
||||
|
||||
sessionID, msgWithoutID, err := ExtractSessionID(msgWithID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg, msgWithoutID)
|
||||
require.Equal(t, testSessionID, sessionID)
|
||||
}
|
||||
|
||||
func TestRemoveSessionIDError(t *testing.T) {
|
||||
// message is too short to contain session ID
|
||||
msg := []byte("test")
|
||||
_, _, err := ExtractSessionID(msg)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSuffixSessionIDError(t *testing.T) {
|
||||
msg := make([]byte, MaxDatagramFrameSize-sessionIDLen)
|
||||
_, err := SuffixSessionID(testSessionID, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg = make([]byte, MaxDatagramFrameSize-sessionIDLen+1)
|
||||
_, err = SuffixSessionID(testSessionID, msg)
|
||||
require.Error(t, err)
|
||||
}
|
Loading…
Reference in New Issue