package datagramsession import ( "context" "fmt" "io" "net" "time" "github.com/google/uuid" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/packet" ) const ( defaultCloseIdleAfter = time.Second * 210 ) func SessionIdleErr(timeout time.Duration) error { return fmt.Errorf("session idle for %v", timeout) } type transportSender func(session *packet.Session) error // Session is a bidirectional pipe of datagrams between transport and dstConn // Destination can be a connection with origin or with eyeball // When the destination is origin: // - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to origin // - Datagrams from origin are read from conn and Send to transport using the transportSender callback. Transport will return them to eyeball // When the destination is eyeball: // - Datagrams from eyeball are read from conn and Send to transport. Transport will send them to cloudflared using the transportSender callback. // - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to the eyeball type Session struct { ID uuid.UUID sendFunc transportSender dstConn io.ReadWriteCloser // activeAtChan is used to communicate the last read/write time activeAtChan chan time.Time closeChan chan error log *zerolog.Logger } func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) { go func() { // QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 // This makes it safe to share readBuffer between iterations const maxPacketSize = 1500 readBuffer := make([]byte, maxPacketSize) for { if closeSession, err := s.dstToTransport(readBuffer); err != nil { if err != net.ErrClosed { s.log.Error().Err(err).Msg("Failed to send session payload from destination to transport") } else { s.log.Debug().Msg("Session cannot read from destination because the connection is closed") } if closeSession { s.closeChan <- err return } } } }() err = s.waitForCloseCondition(ctx, closeAfterIdle) if closeSession, ok := err.(*errClosedSession); ok { closedByRemote = closeSession.byRemote } return closedByRemote, err } func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error { // Closing dstConn cancels read so dstToTransport routine in Serve() can return defer s.dstConn.Close() if closeAfterIdle == 0 { // provide deafult is caller doesn't specify one closeAfterIdle = defaultCloseIdleAfter } checkIdleFreq := closeAfterIdle / 8 checkIdleTicker := time.NewTicker(checkIdleFreq) defer checkIdleTicker.Stop() activeAt := time.Now() for { select { case <-ctx.Done(): return ctx.Err() case reason := <-s.closeChan: return reason // TODO: TUN-5423 evaluate if using atomic is more efficient case now := <-checkIdleTicker.C: // The session is considered inactive if current time is after (last active time + allowed idle time) if now.After(activeAt.Add(closeAfterIdle)) { return SessionIdleErr(closeAfterIdle) } case activeAt = <-s.activeAtChan: // Update last active time } } } func (s *Session) dstToTransport(buffer []byte) (closeSession bool, err error) { n, err := s.dstConn.Read(buffer) s.markActive() // https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes if n > 0 || err == nil { session := packet.Session{ ID: s.ID, Payload: buffer[:n], } if sendErr := s.sendFunc(&session); sendErr != nil { return false, sendErr } } return err != nil, err } func (s *Session) transportToDst(payload []byte) (int, error) { s.markActive() n, err := s.dstConn.Write(payload) if err != nil { s.log.Err(err).Msg("Failed to write payload to session") } return n, err } // Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there // are many concurrent read/write. It is fine to lose some precision func (s *Session) markActive() { select { case s.activeAtChan <- time.Now(): default: } } func (s *Session) close(err *errClosedSession) { s.closeChan <- err }