cloudflared-mirror/datagramsession/manager.go

178 lines
5.6 KiB
Go

package datagramsession
import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/rs/zerolog"

	"github.com/cloudflare/cloudflared/packet"
)
// Tunables for the session manager.
const (
	// requestChanCapacity is a channel capacity of 16. It is not referenced in
	// this file; presumably it is used elsewhere in the package.
	requestChanCapacity = 16
	// defaultReqTimeout is the timeout NewManager installs for
	// RegisterSession and UnregisterSession calls.
	defaultReqTimeout = 5 * time.Second
)
// errSessionManagerClosed is returned by RegisterSession and UnregisterSession
// when the manager's closedChan fires, and is the fallback close reason used by
// shutdownSessions when it is given a nil error.
var errSessionManagerClosed = fmt.Errorf("session manager closed")
// Manager defines the APIs to manage sessions from the same transport.
// The implementation in this file (manager) runs a single event loop in Serve;
// RegisterSession and UnregisterSession hand their requests to that loop.
type Manager interface {
// Serve starts the event loop. In this file's implementation it blocks until
// ctx is done, closes all tracked sessions, and returns ctx.Err().
Serve(ctx context.Context) error
// RegisterSession starts tracking a session. Caller is responsible for starting the session.
// dstConn is stored on the new Session as its destination connection. In this
// file's implementation an error is returned when the request cannot be handed
// to the event loop before ctx is done or the manager's timeout expires.
RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error)
// UnregisterSession stops tracking the session and terminates it.
// message and byRemote are carried in the close error passed to the session.
// In this file's implementation a nil return only means the request was handed
// to the event loop; an unknown sessionID is silently ignored by the loop.
UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error
// UpdateLogger updates the logger used by the Manager
UpdateLogger(log *zerolog.Logger)
}
// manager is the Manager implementation in this package. In this file the
// sessions map is only touched by helpers called from Serve, so a single
// goroutine running Serve owns it; RegisterSession and UnregisterSession reach
// that goroutine through the unbuffered registration/unregistration channels.
type manager struct {
	// registrationChan carries RegisterSession requests to the event loop.
	registrationChan chan *registerSessionEvent
	// unregistrationChan carries UnregisterSession requests to the event loop.
	unregistrationChan chan *unregisterSessionEvent
	// sendFunc is handed to every Session created by newSession.
	sendFunc transportSender
	// receiveChan delivers datagrams from the transport to the event loop.
	receiveChan <-chan *packet.Session
	// closedChan is closed exactly once when Serve returns, so that pending and
	// later RegisterSession/UnregisterSession calls return
	// errSessionManagerClosed immediately instead of waiting for their timeout.
	closedChan chan struct{}
	// closeOnce guards close(closedChan) against a repeated Serve call.
	closeOnce sync.Once
	// sessions maps session IDs to the sessions currently tracked.
	sessions map[uuid.UUID]*Session
	log      *zerolog.Logger
	// timeout waiting for an API to finish. This can be overridden in test
	timeout time.Duration
}

// NewManager returns a manager that logs to log, gives sendF to every session
// it creates, and reads incoming datagrams from receiveChan. Request timeouts
// default to defaultReqTimeout. The event loop does not run until Serve is
// called.
func NewManager(log *zerolog.Logger, sendF transportSender, receiveChan <-chan *packet.Session) *manager {
	return &manager{
		registrationChan:   make(chan *registerSessionEvent),
		unregistrationChan: make(chan *unregisterSessionEvent),
		sendFunc:           sendF,
		receiveChan:        receiveChan,
		closedChan:         make(chan struct{}),
		sessions:           make(map[uuid.UUID]*Session),
		log:                log,
		timeout:            defaultReqTimeout,
	}
}

// UpdateLogger replaces the logger used by the manager.
//
// NOTE(review): the pointer is written without synchronization while other
// goroutines (Serve, RegisterSession, UnregisterSession) may read it. The
// original author considered this acceptable because either the old or the new
// logger is usable, but the race detector will still report it; switching the
// field to an atomic pointer would need a coordinated change to every reader.
func (m *manager) UpdateLogger(log *zerolog.Logger) {
	m.log = log
}

// Serve runs the event loop until ctx is done. On exit it closes every tracked
// session, marks the manager as closed, and returns ctx.Err().
//
// Serve is meant to be called once per manager: after it returns, closedChan
// stays closed and RegisterSession/UnregisterSession report
// errSessionManagerClosed.
func (m *manager) Serve(ctx context.Context) error {
	// Signal API callers that nothing is reading the request channels anymore.
	// The nil check keeps a zero-value manager (not built by NewManager) from
	// panicking on close(nil).
	defer m.closeOnce.Do(func() {
		if m.closedChan != nil {
			close(m.closedChan)
		}
	})

	// Local copy so the receive case can be disabled without mutating the field.
	receiveChan := m.receiveChan
	for {
		select {
		case <-ctx.Done():
			m.shutdownSessions(ctx.Err())
			return ctx.Err()
		// receiveChan is buffered, so the transport can read more datagrams from transport while the event loop is
		// processing other events
		case datagram, ok := <-receiveChan:
			if !ok {
				// A closed channel would otherwise yield nil forever and
				// sendToSession would dereference it. Receiving from a nil
				// channel blocks, which turns this case off for good.
				receiveChan = nil
				continue
			}
			m.sendToSession(datagram)
		case registration := <-m.registrationChan:
			m.registerSession(ctx, registration)
		case unregistration := <-m.unregistrationChan:
			m.unregisterSession(unregistration)
		}
	}
}
// shutdownSessions closes every tracked session with a single shared close
// error derived from err. A nil err is replaced by errSessionManagerClosed.
// Sessions are not removed from the map.
func (m *manager) shutdownSessions(err error) {
	cause := err
	if cause == nil {
		cause = errSessionManagerClosed
	}

	// Usually the connection with the remote is already gone, so byRemote is
	// set to skip unregistering from the remote. context.Canceled is the
	// exception: it means our side is the one closing the session.
	closeErr := &errClosedSession{
		message:  cause.Error(),
		byRemote: !errors.Is(cause, context.Canceled),
	}

	for _, tracked := range m.sessions {
		tracked.close(closeErr)
	}
}
// RegisterSession asks the event loop to create and track a session for
// sessionID whose destination connection is originProxy, and returns that
// session.
//
// The hand-off is bounded by m.timeout (and by ctx). It returns ctx's error if
// the loop does not accept the request in time, or errSessionManagerClosed if
// closedChan fires first. Once the request has been accepted, the call waits
// for the loop's reply without a further timeout.
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
	reqCtx, done := context.WithTimeout(ctx, m.timeout)
	defer done()

	req := newRegisterSessionEvent(sessionID, originProxy)

	select {
	case m.registrationChan <- req:
		// The event loop took the request; it answers on resultChan.
		return <-req.resultChan, nil
	case <-m.closedChan:
		// registrationChan is unbuffered, so when the loop is not reading the
		// send above can never succeed.
		return nil, errSessionManagerClosed
	case <-reqCtx.Done():
		m.log.Error().Msg("Datagram session registration timeout")
		return nil, reqCtx.Err()
	}
}
// registerSession handles one registration request on the event loop: it
// builds the session, stores it under the requested ID (replacing any entry
// already there), and sends it back on the request's resultChan. ctx is
// currently unused.
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
	id := registration.sessionID
	created := m.newSession(id, registration.originProxy)
	m.sessions[id] = created
	registration.resultChan <- created
}
// newSession builds a Session for id that writes to dstConn, uses the
// manager's sendFunc, and logs through a child of the manager's logger tagged
// with the session ID. The session is not stored in the manager here.
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
	sessionLog := m.log.With().Str("sessionID", id.String()).Logger()

	// activeAtChan has low capacity and can fill up under many concurrent
	// reads/writes. markActive() drops instead of blocking, because the last
	// active time only needs to be approximate.
	activeAt := make(chan time.Time, 2)

	// Capacity 2 because both close() and the dstToTransport routine in
	// Serve() can write to this channel.
	closed := make(chan error, 2)

	return &Session{
		ID:           id,
		sendFunc:     m.sendFunc,
		dstConn:      dstConn,
		activeAtChan: activeAt,
		closeChan:    closed,
		log:          &sessionLog,
	}
}
// UnregisterSession asks the event loop to stop tracking sessionID and close
// the session with an error carrying message and byRemote.
//
// A nil return only means the loop accepted the request; the removal itself
// happens on the loop. The hand-off is bounded by m.timeout (and by ctx). It
// returns ctx's error if the loop does not accept the request in time, or
// errSessionManagerClosed if closedChan fires first.
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
	reqCtx, done := context.WithTimeout(ctx, m.timeout)
	defer done()

	reason := &errClosedSession{
		message:  message,
		byRemote: byRemote,
	}
	req := &unregisterSessionEvent{
		sessionID: sessionID,
		err:       reason,
	}

	select {
	case m.unregistrationChan <- req:
		return nil
	case <-m.closedChan:
		return errSessionManagerClosed
	case <-reqCtx.Done():
		m.log.Error().Msg("Datagram session unregistration timeout")
		return reqCtx.Err()
	}
}
// unregisterSession handles one unregistration request on the event loop. If
// the session is tracked, it is removed from the map and then closed with the
// request's error; an unknown ID is ignored.
func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
	id := unregistration.sessionID
	tracked, found := m.sessions[id]
	if !found {
		return
	}
	delete(m.sessions, id)
	tracked.close(unregistration.err)
}
// sendToSession forwards a datagram's payload to the session named by
// datagram.ID. A datagram for an untracked session is logged and dropped.
// datagram must be non-nil.
func (m *manager) sendToSession(datagram *packet.Session) {
	if target, found := m.sessions[datagram.ID]; found {
		// The session writes to its destination over a connected UDP socket,
		// which should not block, so the write runs inline on the event loop
		// rather than in its own goroutine.
		target.transportToDst(datagram.Payload)
		return
	}
	m.log.Error().Str("sessionID", datagram.ID.String()).Msg("session not found")
}