package datagramsession

import (
	"context"
	"errors"
	"io"
	"time"

	"github.com/google/uuid"
	"github.com/lucas-clemente/quic-go"
	"github.com/rs/zerolog"
	"golang.org/x/sync/errgroup"
)

const (
	// requestChanCapacity is the buffer size of the channel that carries
	// datagrams from the transport reader to the event loop.
	requestChanCapacity = 16
	// defaultReqTimeout is the default time RegisterSession and
	// UnregisterSession wait for the event loop to accept their request.
	defaultReqTimeout = time.Second * 5
)

// Manager defines the APIs to manage sessions from the same transport.
type Manager interface {
	// Serve starts the event loop
	Serve(ctx context.Context) error
	// RegisterSession starts tracking a session. Caller is responsible for starting the session
	RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error)
	// UnregisterSession stops tracking the session and terminates it
	UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error
}

// manager is the Manager implementation. All access to the sessions map is
// funneled through the event loop started by Serve, so the map itself needs
// no locking.
type manager struct {
	// registrationChan carries registration requests to the event loop.
	registrationChan chan *registerSessionEvent
	// unregistrationChan carries unregistration requests to the event loop.
	unregistrationChan chan *unregisterSessionEvent
	// datagramChan carries datagrams read from the transport to the event loop.
	datagramChan chan *newDatagram
	transport    transport
	// sessions must only be read or written from the event loop goroutine.
	sessions map[uuid.UUID]*Session
	log      *zerolog.Logger
	// timeout waiting for an API to finish. This can be overridden in test
	timeout time.Duration
}

// NewManager returns a manager that reads datagrams from transport and logs
// to log. The returned manager does nothing until Serve is called.
func NewManager(transport transport, log *zerolog.Logger) *manager {
	return &manager{
		registrationChan:   make(chan *registerSessionEvent),
		unregistrationChan: make(chan *unregisterSessionEvent),
		// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
		datagramChan: make(chan *newDatagram, requestChanCapacity),
		transport:    transport,
		sessions:     make(map[uuid.UUID]*Session),
		log:          log,
		timeout:      defaultReqTimeout,
	}
}

// Serve runs two goroutines until they exit: one reads datagrams from the
// transport and forwards them to the event loop, the other is the event loop
// that dispatches datagrams and handles session (un)registration.
//
// It returns the first non-nil error produced by either goroutine. A QUIC
// application error carrying the NoError code from the transport is not
// reported as a failure.
func (m *manager) Serve(ctx context.Context) error {
	errGroup, ctx := errgroup.WithContext(ctx)

	// Transport reader: pulls datagrams and hands them to the event loop.
	errGroup.Go(func() error {
		for {
			sessionID, payload, err := m.transport.ReceiveFrom()
			if err != nil {
				// errors.As also recognizes the application error when it has
				// been wrapped, which a plain type assertion would miss.
				var appErr *quic.ApplicationError
				if errors.As(err, &appErr) && uint64(appErr.ErrorCode) == uint64(quic.NoError) {
					return nil
				}
				return err
			}
			datagram := &newDatagram{
				sessionID: sessionID,
				payload:   payload,
			}
			select {
			case <-ctx.Done():
				return ctx.Err()
			// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
			// Send the datagram to the event loop. It will find the session to send to
			case m.datagramChan <- datagram:
			}
		}
	})

	// Event loop: the only goroutine that touches m.sessions.
	errGroup.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case datagram := <-m.datagramChan:
				m.sendToSession(datagram)
			case registration := <-m.registrationChan:
				m.registerSession(ctx, registration)
			// TODO: TUN-5422: Unregister inactive session upon timeout
			case unregistration := <-m.unregistrationChan:
				m.unregisterSession(unregistration)
			}
		}
	})

	return errGroup.Wait()
}

// RegisterSession asks the event loop to create and track a session for
// sessionID that proxies to originProxy, and returns the created session.
//
// If the event loop does not accept the request before ctx is done or
// m.timeout elapses, it logs the timeout and returns ctx.Err(). Once the
// request has been accepted, it waits for the event loop's reply.
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
	ctx, cancel := context.WithTimeout(ctx, m.timeout)
	defer cancel()

	event := newRegisterSessionEvent(sessionID, originProxy)

	select {
	case <-ctx.Done():
		m.log.Error().Msg("Datagram session registration timeout")
		return nil, ctx.Err()
	case m.registrationChan <- event:
		// The event loop replies on resultChan right after storing the session.
		session := <-event.resultChan
		return session, nil
	}
}

// registerSession creates the session described by registration, stores it in
// the sessions map (replacing any existing entry with the same ID) and sends
// it back on the event's result channel. It must run on the event loop
// goroutine.
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
	session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log)
	m.sessions[registration.sessionID] = session
	registration.resultChan <- session
}

// UnregisterSession asks the event loop to stop tracking sessionID and close
// it with an errClosedSession built from message and byRemote.
//
// It returns nil as soon as the event loop has accepted the request. If the
// request is not accepted before ctx is done or m.timeout elapses, it logs
// the timeout and returns ctx.Err().
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
	ctx, cancel := context.WithTimeout(ctx, m.timeout)
	defer cancel()

	event := &unregisterSessionEvent{
		sessionID: sessionID,
		err: &errClosedSession{
			message:  message,
			byRemote: byRemote,
		},
	}

	select {
	case <-ctx.Done():
		m.log.Error().Msg("Datagram session unregistration timeout")
		return ctx.Err()
	case m.unregistrationChan <- event:
		return nil
	}
}

// unregisterSession removes the session named by unregistration from the
// sessions map and closes it with the event's error. Unknown session IDs are
// ignored. It must run on the event loop goroutine.
func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
	session, ok := m.sessions[unregistration.sessionID]
	if ok {
		delete(m.sessions, unregistration.sessionID)
		session.close(unregistration.err)
	}
}

// sendToSession forwards the datagram payload to the session it is addressed
// to. A datagram for an unknown session, or a failed write, is logged and
// dropped. It must run on the event loop goroutine.
func (m *manager) sendToSession(datagram *newDatagram) {
	session, ok := m.sessions[datagram.sessionID]
	if !ok {
		m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
		return
	}
	// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
	// need to run in another go routine
	_, err := session.transportToDst(datagram.payload)
	if err != nil {
		m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
	}
}