TUN-5301: Separate datagram multiplex and session management logic from quic connection logic
This commit is contained in:
parent
dd32dc1364
commit
eea3d11e40
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/datagramsession"
|
||||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
@ -35,7 +36,8 @@ type QUICConnection struct {
|
||||||
session quic.Session
|
session quic.Session
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
httpProxy OriginProxy
|
httpProxy OriginProxy
|
||||||
udpSessions *udpSessions
|
sessionManager datagramsession.Manager
|
||||||
|
localIP net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQUICConnection returns a new instance of QUICConnection.
|
// NewQUICConnection returns a new instance of QUICConnection.
|
||||||
|
@ -49,12 +51,6 @@ func NewQUICConnection(
|
||||||
controlStreamHandler ControlStreamHandler,
|
controlStreamHandler ControlStreamHandler,
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
) (*QUICConnection, error) {
|
) (*QUICConnection, error) {
|
||||||
localIP, err := GetLocalIP()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP)
|
|
||||||
udpSessions := newUDPSessions(localIP)
|
|
||||||
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial to edge: %w", err)
|
return nil, fmt.Errorf("failed to dial to edge: %w", err)
|
||||||
|
@ -71,24 +67,36 @@ func NewQUICConnection(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
datagramMuxer, err := quicpogs.NewDatagramMuxer(session)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionManager := datagramsession.NewManager(datagramMuxer, observer.log)
|
||||||
|
|
||||||
|
localIP, err := getLocalIP()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &QUICConnection{
|
return &QUICConnection{
|
||||||
session: session,
|
session: session,
|
||||||
httpProxy: httpProxy,
|
httpProxy: httpProxy,
|
||||||
logger: observer.log,
|
logger: observer.log,
|
||||||
udpSessions: udpSessions,
|
sessionManager: sessionManager,
|
||||||
|
localIP: localIP,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve starts a QUIC session that begins accepting streams.
|
// Serve starts a QUIC session that begins accepting streams.
|
||||||
func (q *QUICConnection) Serve(ctx context.Context) error {
|
func (q *QUICConnection) Serve(ctx context.Context) error {
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
|
||||||
return q.listenEdgeDatagram()
|
|
||||||
})
|
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
return q.acceptStream(ctx)
|
return q.acceptStream(ctx)
|
||||||
})
|
})
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
return q.sessionManager.Serve(ctx)
|
||||||
|
})
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,26 +119,6 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload
|
|
||||||
func (q *QUICConnection) listenEdgeDatagram() error {
|
|
||||||
for {
|
|
||||||
msg, err := q.session.ReceiveMessage()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
go func(msg []byte) {
|
|
||||||
sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg)
|
|
||||||
if err != nil {
|
|
||||||
q.logger.Err(err).Msg("Failed to parse session ID from datagram")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil {
|
|
||||||
q.logger.Err(err).Msg("Failed to send UDP to origin")
|
|
||||||
}
|
|
||||||
}(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the session with no errors specified.
|
// Close closes the session with no errors specified.
|
||||||
func (q *QUICConnection) Close() {
|
func (q *QUICConnection) Close() {
|
||||||
q.session.CloseWithError(0, "")
|
q.session.CloseWithError(0, "")
|
||||||
|
@ -186,46 +174,29 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||||
udpConn, err := q.udpSessions.register(sessionID, dstIP, dstPort)
|
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
|
||||||
|
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
|
||||||
|
originProxy, err := q.newUDPProxy(dstIP, dstPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
q.logger.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort)
|
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
|
||||||
go q.listenOriginUDP(sessionID, udpConn)
|
if err != nil {
|
||||||
|
q.logger.Err(err).Msgf("Failed to register udp session %s", sessionID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer q.sessionManager.UnregisterSession(q.session.Context(), sessionID)
|
||||||
|
if err := session.Serve(q.session.Context()); err != nil {
|
||||||
|
q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).Msg("session terminated")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
q.logger.Debug().Msgf("Registered session %v, %v, %v", sessionID, dstIP, dstPort)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// listenOriginUDP reads UDP from origin in a loop, and returns when it cannot write to edge or cannot read from origin
|
// TODO: TUN-5422 Implement UnregisterUdpSession RPC
|
||||||
func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) {
|
|
||||||
defer func() {
|
|
||||||
q.udpSessions.unregister(sessionID)
|
|
||||||
conn.Close()
|
|
||||||
}()
|
|
||||||
readBuffer := make([]byte, MaxDatagramFrameSize)
|
|
||||||
for {
|
|
||||||
n, err := conn.Read(readBuffer)
|
|
||||||
if n > 0 {
|
|
||||||
if n > MaxDatagramFrameSize-sessionIDLen {
|
|
||||||
// TODO: TUN-5302 return ICMP packet too big message
|
|
||||||
q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, MaxDatagramFrameSize-sessionIDLen)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
msgWithID, err := quicpogs.SuffixSessionID(sessionID, readBuffer[:n])
|
|
||||||
if err != nil {
|
|
||||||
q.logger.Err(err).Msg("Failed to suffix session ID to datagram, it will be dropped")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := q.session.SendMessage(msgWithID); err != nil {
|
|
||||||
q.logger.Err(err).Msg("Failed to send datagram back to edge")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
q.logger.Err(err).Msg("Failed to read UDP from origin")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||||
// the client.
|
// the client.
|
||||||
|
@ -320,3 +291,35 @@ func isTransferEncodingChunked(req *http.Request) bool {
|
||||||
// separated value as well.
|
// separated value as well.
|
||||||
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
|
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: TUN-5303: Define an UDPProxy in ingress package
|
||||||
|
func (q *QUICConnection) newUDPProxy(dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
|
||||||
|
dstAddr := &net.UDPAddr{
|
||||||
|
IP: dstIP,
|
||||||
|
Port: int(dstPort),
|
||||||
|
}
|
||||||
|
return net.DialUDP("udp", nil, dstAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: TUN-5303: Find the local IP once in ingress package
|
||||||
|
// TODO: TUN-5421 allow user to specify which IP to bind to
|
||||||
|
func getLocalIP() (net.IP, error) {
|
||||||
|
addrs, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
// Find the IP that is not loop back
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
}
|
||||||
|
if !ip.IsLoopback() {
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("cannot determine IP to bind to")
|
||||||
|
}
|
||||||
|
|
|
@ -1,90 +0,0 @@
|
||||||
package connection
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: TUN-5422 Unregister session
|
|
||||||
const (
|
|
||||||
sessionIDLen = len(uuid.UUID{})
|
|
||||||
)
|
|
||||||
|
|
||||||
type udpSessions struct {
|
|
||||||
lock sync.RWMutex
|
|
||||||
sessions map[uuid.UUID]*net.UDPConn
|
|
||||||
localIP net.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
func newUDPSessions(localIP net.IP) *udpSessions {
|
|
||||||
return &udpSessions{
|
|
||||||
sessions: make(map[uuid.UUID]*net.UDPConn),
|
|
||||||
localIP: localIP,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
|
|
||||||
us.lock.Lock()
|
|
||||||
defer us.lock.Unlock()
|
|
||||||
dstAddr := &net.UDPAddr{
|
|
||||||
IP: dstIP,
|
|
||||||
Port: int(dstPort),
|
|
||||||
}
|
|
||||||
conn, err := net.DialUDP("udp", us.localAddr(), dstAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
us.sessions[id] = conn
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (us *udpSessions) unregister(id uuid.UUID) {
|
|
||||||
us.lock.Lock()
|
|
||||||
defer us.lock.Unlock()
|
|
||||||
delete(us.sessions, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (us *udpSessions) send(id uuid.UUID, payload []byte) error {
|
|
||||||
us.lock.RLock()
|
|
||||||
defer us.lock.RUnlock()
|
|
||||||
conn, ok := us.sessions[id]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("session %s not found", id)
|
|
||||||
}
|
|
||||||
_, err := conn.Write(payload)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ud *udpSessions) localAddr() *net.UDPAddr {
|
|
||||||
// TODO: Determine the IP to bind to
|
|
||||||
|
|
||||||
return &net.UDPAddr{
|
|
||||||
IP: ud.localIP,
|
|
||||||
Port: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: TUN-5421 allow user to specify which IP to bind to
|
|
||||||
func GetLocalIP() (net.IP, error) {
|
|
||||||
addrs, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, addr := range addrs {
|
|
||||||
// Find the IP that is not loop back
|
|
||||||
var ip net.IP
|
|
||||||
switch v := addr.(type) {
|
|
||||||
case *net.IPNet:
|
|
||||||
ip = v.IP
|
|
||||||
case *net.IPAddr:
|
|
||||||
ip = v.IP
|
|
||||||
}
|
|
||||||
if !ip.IsLoopback() {
|
|
||||||
return ip, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("cannot determine IP to bind to")
|
|
||||||
}
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
package datagramsession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerSessionEvent is an event to start tracking a new session
|
||||||
|
type registerSessionEvent struct {
|
||||||
|
sessionID uuid.UUID
|
||||||
|
originProxy io.ReadWriteCloser
|
||||||
|
resultChan chan *Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRegisterSessionEvent(sessionID uuid.UUID, originProxy io.ReadWriteCloser) *registerSessionEvent {
|
||||||
|
return ®isterSessionEvent{
|
||||||
|
sessionID: sessionID,
|
||||||
|
originProxy: originProxy,
|
||||||
|
resultChan: make(chan *Session, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unregisterSessionEvent is an event to stop tracking and terminate the session.
|
||||||
|
type unregisterSessionEvent struct {
|
||||||
|
sessionID uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDatagram is an event when transport receives new datagram
|
||||||
|
type newDatagram struct {
|
||||||
|
sessionID uuid.UUID
|
||||||
|
payload []byte
|
||||||
|
}
|
|
@ -0,0 +1,134 @@
|
||||||
|
package datagramsession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
requestChanCapacity = 16
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager defines the APIs to manage sessions from the same transport.
|
||||||
|
type Manager interface {
|
||||||
|
// Serve starts the event loop
|
||||||
|
Serve(ctx context.Context) error
|
||||||
|
// RegisterSession starts tracking a session. Caller is responsible for starting the session
|
||||||
|
RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error)
|
||||||
|
// UnregisterSession stops tracking the session and terminates it
|
||||||
|
UnregisterSession(ctx context.Context, sessionID uuid.UUID) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type manager struct {
|
||||||
|
registrationChan chan *registerSessionEvent
|
||||||
|
unregistrationChan chan *unregisterSessionEvent
|
||||||
|
datagramChan chan *newDatagram
|
||||||
|
transport transport
|
||||||
|
sessions map[uuid.UUID]*Session
|
||||||
|
log *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(transport transport, log *zerolog.Logger) Manager {
|
||||||
|
return &manager{
|
||||||
|
registrationChan: make(chan *registerSessionEvent),
|
||||||
|
unregistrationChan: make(chan *unregisterSessionEvent),
|
||||||
|
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
|
||||||
|
datagramChan: make(chan *newDatagram, requestChanCapacity),
|
||||||
|
transport: transport,
|
||||||
|
sessions: make(map[uuid.UUID]*Session),
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) Serve(ctx context.Context) error {
|
||||||
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
for {
|
||||||
|
sessionID, payload, err := m.transport.ReceiveFrom()
|
||||||
|
if err != nil {
|
||||||
|
m.log.Err(err).Msg("Failed to receive datagram from transport, closing session manager")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
datagram := &newDatagram{
|
||||||
|
sessionID: sessionID,
|
||||||
|
payload: payload,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
|
||||||
|
// Send the datagram to the event loop. It will find the session to send to
|
||||||
|
case m.datagramChan <- datagram:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case datagram := <-m.datagramChan:
|
||||||
|
m.sendToSession(datagram)
|
||||||
|
case registration := <-m.registrationChan:
|
||||||
|
m.registerSession(ctx, registration)
|
||||||
|
// TODO: TUN-5422: Unregister inactive session upon timeout
|
||||||
|
case unregistration := <-m.unregistrationChan:
|
||||||
|
m.unregisterSession(unregistration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return errGroup.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
|
||||||
|
event := newRegisterSessionEvent(sessionID, originProxy)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case m.registrationChan <- event:
|
||||||
|
session := <-event.resultChan
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
|
||||||
|
session := newSession(registration.sessionID, m.transport, registration.originProxy)
|
||||||
|
m.sessions[registration.sessionID] = session
|
||||||
|
registration.resultChan <- session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID) error {
|
||||||
|
event := &unregisterSessionEvent{sessionID: sessionID}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case m.unregistrationChan <- event:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
|
||||||
|
session, ok := m.sessions[unregistration.sessionID]
|
||||||
|
if ok {
|
||||||
|
delete(m.sessions, unregistration.sessionID)
|
||||||
|
session.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) sendToSession(datagram *newDatagram) {
|
||||||
|
session, ok := m.sessions[datagram.sessionID]
|
||||||
|
if !ok {
|
||||||
|
m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
|
||||||
|
// need to run in another go routine
|
||||||
|
_, err := session.writeToDst(datagram.payload)
|
||||||
|
if err != nil {
|
||||||
|
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,214 @@
|
||||||
|
package datagramsession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManagerServe(t *testing.T) {
|
||||||
|
const (
|
||||||
|
sessions = 20
|
||||||
|
msgs = 50
|
||||||
|
)
|
||||||
|
log := zerolog.Nop()
|
||||||
|
transport := &mockQUICTransport{
|
||||||
|
reqChan: newDatagramChannel(),
|
||||||
|
respChan: newDatagramChannel(),
|
||||||
|
}
|
||||||
|
mg := NewManager(transport, &log)
|
||||||
|
|
||||||
|
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
|
||||||
|
for i := 0; i < sessions; i++ {
|
||||||
|
sessionID := uuid.New()
|
||||||
|
eyeballTracker[sessionID] = newDatagramChannel()
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
serveDone := make(chan struct{})
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
mg.Serve(ctx)
|
||||||
|
close(serveDone)
|
||||||
|
}(ctx)
|
||||||
|
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
sessionID, payload, err := transport.respChan.Receive(ctx)
|
||||||
|
if err != nil {
|
||||||
|
require.Equal(t, context.Canceled, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
respChan := eyeballTracker[sessionID]
|
||||||
|
require.NoError(t, respChan.Send(ctx, sessionID, payload))
|
||||||
|
}
|
||||||
|
}(ctx)
|
||||||
|
|
||||||
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
for sID, receiver := range eyeballTracker {
|
||||||
|
// Assign loop variables to local variables
|
||||||
|
sessionID := sID
|
||||||
|
eyeballRespReceiver := receiver
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
payload := testPayload(sessionID)
|
||||||
|
expectResp := testResponse(payload)
|
||||||
|
|
||||||
|
cfdConn, originConn := net.Pipe()
|
||||||
|
|
||||||
|
origin := mockOrigin{
|
||||||
|
expectMsgCount: msgs,
|
||||||
|
expectedMsg: payload,
|
||||||
|
expectedResp: expectResp,
|
||||||
|
conn: originConn,
|
||||||
|
}
|
||||||
|
eyeball := mockEyeball{
|
||||||
|
expectMsgCount: msgs,
|
||||||
|
expectedMsg: expectResp,
|
||||||
|
expectSessionID: sessionID,
|
||||||
|
respReceiver: eyeballRespReceiver,
|
||||||
|
}
|
||||||
|
|
||||||
|
reqErrGroup, reqCtx := errgroup.WithContext(ctx)
|
||||||
|
reqErrGroup.Go(func() error {
|
||||||
|
return origin.serve()
|
||||||
|
})
|
||||||
|
reqErrGroup.Go(func() error {
|
||||||
|
return eyeball.serve(reqCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sessionDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
session.Serve(ctx)
|
||||||
|
close(sessionDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < msgs; i++ {
|
||||||
|
require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure eyeball and origin have received all messages before unregistering the session
|
||||||
|
require.NoError(t, reqErrGroup.Wait())
|
||||||
|
|
||||||
|
require.NoError(t, mg.UnregisterSession(ctx, sessionID))
|
||||||
|
<-sessionDone
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, errGroup.Wait())
|
||||||
|
cancel()
|
||||||
|
transport.close()
|
||||||
|
<-serveDone
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockOrigin struct {
|
||||||
|
expectMsgCount int
|
||||||
|
expectedMsg []byte
|
||||||
|
expectedResp []byte
|
||||||
|
conn io.ReadWriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mo *mockOrigin) serve() error {
|
||||||
|
expectedMsgLen := len(mo.expectedMsg)
|
||||||
|
readBuffer := make([]byte, expectedMsgLen+1)
|
||||||
|
for i := 0; i < mo.expectMsgCount; i++ {
|
||||||
|
n, err := mo.conn.Read(readBuffer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n != expectedMsgLen {
|
||||||
|
return fmt.Errorf("Expect to read %d bytes, read %d", expectedMsgLen, n)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
|
||||||
|
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = mo.conn.Write(mo.expectedResp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testPayload(sessionID uuid.UUID) []byte {
|
||||||
|
return []byte(fmt.Sprintf("Message from %s", sessionID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testResponse(msg []byte) []byte {
|
||||||
|
return []byte(fmt.Sprintf("Response to %v", msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockEyeball struct {
|
||||||
|
expectMsgCount int
|
||||||
|
expectedMsg []byte
|
||||||
|
expectSessionID uuid.UUID
|
||||||
|
respReceiver *datagramChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (me *mockEyeball) serve(ctx context.Context) error {
|
||||||
|
for i := 0; i < me.expectMsgCount; i++ {
|
||||||
|
sessionID, msg, err := me.respReceiver.Receive(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if sessionID != me.expectSessionID {
|
||||||
|
return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(msg, me.expectedMsg) {
|
||||||
|
return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// datagramChannel is a channel for Datagram with wrapper to send/receive with context
|
||||||
|
type datagramChannel struct {
|
||||||
|
datagramChan chan *newDatagram
|
||||||
|
closedChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDatagramChannel() *datagramChannel {
|
||||||
|
return &datagramChannel{
|
||||||
|
datagramChan: make(chan *newDatagram, 1),
|
||||||
|
closedChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-rc.closedChan:
|
||||||
|
return fmt.Errorf("datagram channel closed")
|
||||||
|
case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return uuid.Nil, nil, ctx.Err()
|
||||||
|
case <-rc.closedChan:
|
||||||
|
return uuid.Nil, nil, fmt.Errorf("datagram channel closed")
|
||||||
|
case msg := <-rc.datagramChan:
|
||||||
|
return msg.sessionID, msg.payload, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *datagramChannel) Close() {
|
||||||
|
// No need to close msgChan, it will be garbage collect once there is no reference to it
|
||||||
|
close(rc.closedChan)
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
package datagramsession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Each Session is a bidirectional pipe of datagrams between transport and dstConn
|
||||||
|
// Currently the only implementation of transport is quic DatagramMuxer
|
||||||
|
// Destination can be a connection with origin or with eyeball
|
||||||
|
// When the destination is origin:
|
||||||
|
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the
|
||||||
|
// write method of the Session to send to origin
|
||||||
|
// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball
|
||||||
|
// When the destination is eyeball:
|
||||||
|
// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared
|
||||||
|
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the
|
||||||
|
// write method of the Session to send to eyeball
|
||||||
|
type Session struct {
|
||||||
|
id uuid.UUID
|
||||||
|
transport transport
|
||||||
|
dstConn io.ReadWriteCloser
|
||||||
|
doneChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session {
|
||||||
|
return &Session{
|
||||||
|
id: id,
|
||||||
|
transport: transport,
|
||||||
|
dstConn: dstConn,
|
||||||
|
doneChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Serve(ctx context.Context) error {
|
||||||
|
serveCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-serveCtx.Done():
|
||||||
|
case <-s.doneChan:
|
||||||
|
}
|
||||||
|
s.dstConn.Close()
|
||||||
|
}()
|
||||||
|
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
||||||
|
// This makes it safe to share readBuffer between iterations
|
||||||
|
readBuffer := make([]byte, 1280)
|
||||||
|
for {
|
||||||
|
// TODO: TUN-5303: origin proxy should determine the buffer size
|
||||||
|
n, err := s.dstConn.Read(readBuffer)
|
||||||
|
if n > 0 {
|
||||||
|
if err := s.transport.SendTo(s.id, readBuffer[:n]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) writeToDst(payload []byte) (int, error) {
|
||||||
|
return s.dstConn.Write(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) close() {
|
||||||
|
close(s.doneChan)
|
||||||
|
}
|
|
@ -0,0 +1,59 @@
|
||||||
|
package datagramsession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCloseSession makes sure a session will stop after context is done
|
||||||
|
func TestSessionCtxDone(t *testing.T) {
|
||||||
|
testSessionReturns(t, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCloseSession makes sure a session will stop after close method is called
|
||||||
|
func TestCloseSession(t *testing.T) {
|
||||||
|
testSessionReturns(t, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSessionReturns(t *testing.T, closeByContext bool) {
|
||||||
|
sessionID := uuid.New()
|
||||||
|
cfdConn, originConn := net.Pipe()
|
||||||
|
payload := testPayload(sessionID)
|
||||||
|
transport := &mockQUICTransport{
|
||||||
|
reqChan: newDatagramChannel(),
|
||||||
|
respChan: newDatagramChannel(),
|
||||||
|
}
|
||||||
|
session := newSession(sessionID, transport, cfdConn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
sessionDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
session.Serve(ctx)
|
||||||
|
close(sessionDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
n, err := session.writeToDst(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(payload), n)
|
||||||
|
}()
|
||||||
|
|
||||||
|
readBuffer := make([]byte, len(payload)+1)
|
||||||
|
n, err := originConn.Read(readBuffer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(payload), n)
|
||||||
|
|
||||||
|
if closeByContext {
|
||||||
|
cancel()
|
||||||
|
} else {
|
||||||
|
session.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
<-sessionDone
|
||||||
|
// call cancelled again otherwise the linter will warn about possible context leak
|
||||||
|
cancel()
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
package datagramsession
|
||||||
|
|
||||||
|
import "github.com/google/uuid"
|
||||||
|
|
||||||
|
// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
|
||||||
|
type transport interface {
|
||||||
|
// SendTo writes payload for a session to the transport
|
||||||
|
SendTo(sessionID uuid.UUID, payload []byte) error
|
||||||
|
// ReceiveFrom reads the next datagram from the transport
|
||||||
|
ReceiveFrom() (uuid.UUID, []byte, error)
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package datagramsession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockQUICTransport struct {
|
||||||
|
reqChan *datagramChannel
|
||||||
|
respChan *datagramChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
|
||||||
|
buf := make([]byte, len(payload))
|
||||||
|
// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
||||||
|
copy(buf, payload)
|
||||||
|
return mt.respChan.Send(context.Background(), sessionID, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
|
||||||
|
return mt.reqChan.Receive(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
|
||||||
|
return mt.reqChan.Send(ctx, sessionID, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mt *mockQUICTransport) close() {
|
||||||
|
mt.reqChan.Close()
|
||||||
|
mt.respChan.Close()
|
||||||
|
}
|
|
@ -4,13 +4,59 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sessionIDLen = len(uuid.UUID{})
|
|
||||||
MaxDatagramFrameSize = 1220
|
MaxDatagramFrameSize = 1220
|
||||||
|
sessionIDLen = len(uuid.UUID{})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type DatagramMuxer struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
session quic.Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatagramMuxer(quicSession quic.Session) (*DatagramMuxer, error) {
|
||||||
|
muxerID, err := uuid.NewRandom()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &DatagramMuxer{
|
||||||
|
ID: muxerID,
|
||||||
|
session: quicSession,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex
|
||||||
|
// the payload from multiple datagram sessions
|
||||||
|
func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error {
|
||||||
|
if len(payload) > MaxDatagramFrameSize-sessionIDLen {
|
||||||
|
// TODO: TUN-5302 return ICMP packet too big message
|
||||||
|
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), MaxDatagramFrameSize-sessionIDLen)
|
||||||
|
}
|
||||||
|
msgWithID, err := SuffixSessionID(sessionID, payload)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
||||||
|
}
|
||||||
|
if err := dm.session.SendMessage(msgWithID); err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager
|
||||||
|
// which determines how to proxy to the origin. It assumes the datagram session has already been
|
||||||
|
// registered with session manager through other side channel
|
||||||
|
func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
|
||||||
|
msg, err := dm.session.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return uuid.Nil, nil, err
|
||||||
|
}
|
||||||
|
return ExtractSessionID(msg)
|
||||||
|
}
|
||||||
|
|
||||||
// Each QUIC datagram should be suffixed with session ID.
|
// Each QUIC datagram should be suffixed with session ID.
|
||||||
// ExtractSessionID extracts the session ID and a slice with only the payload
|
// ExtractSessionID extracts the session ID and a slice with only the payload
|
||||||
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {
|
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {
|
||||||
|
|
Loading…
Reference in New Issue