From 16ecf60800a8b55379740ce246e42d31fd829a28 Mon Sep 17 00:00:00 2001 From: Devin Carr Date: Thu, 24 Oct 2024 11:42:02 -0700 Subject: [PATCH] TUN-8661: Refactor connection methods to support future different datagram muxing methods The current supervisor serves the quic connection by performing all of the following in one method: 1. Dial QUIC edge connection 2. Initialize datagram muxer for UDP sessions and ICMP 3. Wrap all together in a single struct to serve the process loops In an effort to better support modularity, each of these steps were broken out into their own separate methods that the supervisor will compose together to create the TunnelConnection and run its `Serve` method. This also provides us with the capability to better interchange the functionality supported by the datagram session manager in the future with a new mechanism. Closes TUN-8661 --- connection/connection.go | 7 + connection/quic.go | 571 +----------------- connection/quic_connection.go | 444 ++++++++++++++ .../{quic_test.go => quic_connection_test.go} | 74 ++- connection/quic_datagram_v2.go | 200 ++++++ supervisor/tunnel.go | 51 +- 6 files changed, 745 insertions(+), 602 deletions(-) create mode 100644 connection/quic_connection.go rename connection/{quic_test.go => quic_connection_test.go} (90%) create mode 100644 connection/quic_datagram_v2.go diff --git a/connection/connection.go b/connection/connection.go index 50464e4a..b7376e38 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -36,6 +36,13 @@ var ( flushableContentTypes = []string{sseContentType, grpcContentType} ) +// TunnelConnection represents the connection to the edge. +// The Serve method is provided to allow clients to handle any errors from the connection encountered during +// processing of the connection. Cancelling of the context provided to Serve will close the connection. +type TunnelConnection interface { + Serve(ctx context.Context) error +} + type Orchestrator interface { UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse GetConfigJSON() ([]byte, error) diff --git a/connection/quic.go b/connection/quic.go index cbf5b186..3109d77f 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -1,51 +1,16 @@ package connection import ( - "bufio" "context" "crypto/tls" "fmt" - "io" "net" - "net/http" "net/netip" "runtime" - "strconv" - "strings" "sync" - "sync/atomic" - "time" - "github.com/google/uuid" - "github.com/pkg/errors" "github.com/quic-go/quic-go" "github.com/rs/zerolog" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" - - "github.com/cloudflare/cloudflared/datagramsession" - "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/management" - "github.com/cloudflare/cloudflared/packet" - cfdquic "github.com/cloudflare/cloudflared/quic" - "github.com/cloudflare/cloudflared/tracing" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" -) - -const ( - // HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPHeaderKey = "HttpHeader" - // HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPMethodKey = "HttpMethod" - // HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPHostKey = "HttpHost" - - QUICMetadataFlowID = "FlowID" - // emperically this capacity has been working well - demuxChanCapacity = 16 ) var ( @@ -53,48 +18,21 @@ var ( portMapMutex sync.Mutex ) -// QUICConnection represents the type that facilitates Proxying via QUIC streams. -type QUICConnection struct { - session quic.Connection - logger *zerolog.Logger - orchestrator Orchestrator - // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer - sessionManager datagramsession.Manager - // datagramMuxer mux/demux datagrams from quic connection - datagramMuxer *cfdquic.DatagramMuxerV2 - packetRouter *ingress.PacketRouter - controlStreamHandler ControlStreamHandler - connOptions *tunnelpogs.ConnectionOptions - connIndex uint8 - - rpcTimeout time.Duration - streamWriteTimeout time.Duration - gracePeriod time.Duration -} - -// NewQUICConnection returns a new instance of QUICConnection. -func NewQUICConnection( +func DialQuic( ctx context.Context, quicConfig *quic.Config, + tlsConfig *tls.Config, edgeAddr netip.AddrPort, localAddr net.IP, connIndex uint8, - tlsConfig *tls.Config, - orchestrator Orchestrator, - connOptions *tunnelpogs.ConnectionOptions, - controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, - packetRouterConfig *ingress.GlobalRouterConfig, - rpcTimeout time.Duration, - streamWriteTimeout time.Duration, - gracePeriod time.Duration, -) (*QUICConnection, error) { +) (quic.Connection, error) { udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, logger) if err != nil { return nil, err } - session, err := quic.Dial(ctx, udpConn, net.UDPAddrFromAddrPort(edgeAddr), tlsConfig, quicConfig) + conn, err := quic.Dial(ctx, udpConn, net.UDPAddrFromAddrPort(edgeAddr), tlsConfig, quicConfig) if err != nil { // close the udp server socket in case of error connecting to the edge udpConn.Close() @@ -102,506 +40,11 @@ func NewQUICConnection( } // wrap the session, so that the UDPConn is closed after session is closed. - session = &wrapCloseableConnQuicConnection{ - session, + conn = &wrapCloseableConnQuicConnection{ + conn, udpConn, } - - sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) - datagramMuxer := cfdquic.NewDatagramMuxerV2(session, logger, sessionDemuxChan) - sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) - packetRouter := ingress.NewPacketRouter(packetRouterConfig, datagramMuxer, logger) - - return &QUICConnection{ - session: session, - orchestrator: orchestrator, - logger: logger, - sessionManager: sessionManager, - datagramMuxer: datagramMuxer, - packetRouter: packetRouter, - controlStreamHandler: controlStreamHandler, - connOptions: connOptions, - connIndex: connIndex, - rpcTimeout: rpcTimeout, - streamWriteTimeout: streamWriteTimeout, - gracePeriod: gracePeriod, - }, nil -} - -// Serve starts a QUIC session that begins accepting streams. -func (q *QUICConnection) Serve(ctx context.Context) error { - // origintunneld assumes the first stream is used for the control plane - controlStream, err := q.session.OpenStream() - if err != nil { - return fmt.Errorf("failed to open a registration control stream: %w", err) - } - - // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits - // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this - // connection). - // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the - // other goroutine as fast as possible. - ctx, cancel := context.WithCancel(ctx) - errGroup, ctx := errgroup.WithContext(ctx) - - // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control - // stream is already fully registered before the other goroutines can proceed. - errGroup.Go(func() error { - // err is equal to nil if we exit due to unregistration. If that happens we want to wait the full - // amount of the grace period, allowing requests to finish before we cancel the context, which will - // make cloudflared exit. - if err := q.serveControlStream(ctx, controlStream); err == nil { - select { - case <-ctx.Done(): - case <-time.Tick(q.gracePeriod): - } - } - cancel() - return err - }) - errGroup.Go(func() error { - defer cancel() - return q.acceptStream(ctx) - }) - errGroup.Go(func() error { - defer cancel() - return q.sessionManager.Serve(ctx) - }) - errGroup.Go(func() error { - defer cancel() - return q.datagramMuxer.ServeReceive(ctx) - }) - errGroup.Go(func() error { - defer cancel() - return q.packetRouter.Serve(ctx) - }) - - return errGroup.Wait() -} - -func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error { - // This blocks until the control plane is done. - err := q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator) - if err != nil { - // Not wrapping error here to be consistent with the http2 message. - return err - } - - return nil -} - -// Close closes the session with no errors specified. -func (q *QUICConnection) Close() { - q.session.CloseWithError(0, "") -} - -func (q *QUICConnection) acceptStream(ctx context.Context) error { - defer q.Close() - for { - quicStream, err := q.session.AcceptStream(ctx) - if err != nil { - // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. - if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() { - return nil - } - return fmt.Errorf("failed to accept QUIC stream: %w", err) - } - go q.runStream(quicStream) - } -} - -func (q *QUICConnection) runStream(quicStream quic.Stream) { - ctx := quicStream.Context() - stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) - defer stream.Close() - - // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that - // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream. - // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream. - // A call to close will simulate a close to the read-side, which will fail subsequent reads. - noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} - ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q, q, q.rpcTimeout) - if err := ss.Serve(ctx, noCloseStream); err != nil { - q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream") - - // if we received an error at this level, then close write side of stream with an error, which will result in - // RST_STREAM frame. - quicStream.CancelWrite(0) - } -} - -func (q *QUICConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error { - request, err := stream.ReadConnectRequestData() - if err != nil { - return err - } - - if err, connectResponseSent := q.dispatchRequest(ctx, stream, err, request); err != nil { - q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") - - // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is - // closed with an RST_STREAM frame - if connectResponseSent { - return err - } - - if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil { - return writeRespErr - } - } - - return nil -} - -// dispatchRequest will dispatch the request depending on the type and returns an error if it occurs. -// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream. -// This is important since it informs -func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, err error, request *pogs.ConnectRequest) (error, bool) { - originProxy, err := q.orchestrator.GetOriginProxy() - if err != nil { - return err, false - } - - switch request.Type { - case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket: - tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger) - if err != nil { - return err, false - } - w := newHTTPResponseAdapter(stream) - return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent - - case pogs.ConnectionTypeTCP: - rwa := &streamReadWriteAcker{RequestServerStream: stream} - metadata := request.MetadataMap() - return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ - Dest: request.Dest, - FlowID: metadata[QUICMetadataFlowID], - CfTraceID: metadata[tracing.TracerContextName], - ConnIndex: q.connIndex, - }), rwa.connectResponseSent - default: - return errors.Errorf("unsupported error type: %s", request.Type), false - } -} - -// RegisterUdpSession is the RPC method invoked by edge to register and run a session -func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) { - traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger) - ctx, registerSpan := traceCtx.Tracer().Start(traceCtx, "register-session", trace.WithAttributes( - attribute.String("session-id", sessionID.String()), - attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)), - )) - log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger() - // Each session is a series of datagram from an eyeball to a dstIP:dstPort. - // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. - originProxy, err := ingress.DialUDP(dstIP, dstPort) - if err != nil { - log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) - tracing.EndWithErrorStatus(registerSpan, err) - return nil, err - } - registerSpan.SetAttributes( - attribute.Bool("socket-bind-success", true), - attribute.String("src", originProxy.LocalAddr().String()), - ) - - session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) - if err != nil { - originProxy.Close() - log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session") - tracing.EndWithErrorStatus(registerSpan, err) - return nil, err - } - - go q.serveUDPSession(session, closeAfterIdleHint) - - log.Debug(). - Str("sessionID", sessionID.String()). - Str("src", originProxy.LocalAddr().String()). - Str("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)). - Msgf("Registered session") - tracing.End(registerSpan) - - resp := tunnelpogs.RegisterUdpSessionResponse{ - Spans: traceCtx.GetProtoSpans(), - } - - return &resp, nil -} - -func (q *QUICConnection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) { - ctx := q.session.Context() - closedByRemote, err := session.Serve(ctx, closeAfterIdleHint) - // If session is terminated by remote, then we know it has been unregistered from session manager and edge - if !closedByRemote { - if err != nil { - q.closeUDPSession(ctx, session.ID, err.Error()) - } else { - q.closeUDPSession(ctx, session.ID, "terminated without error") - } - } - q.logger.Debug().Err(err). - Int(management.EventTypeKey, int(management.UDP)). - Str("sessionID", session.ID.String()). - Msg("Session terminated") -} - -// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge -func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) { - q.sessionManager.UnregisterSession(ctx, sessionID, message, false) - quicStream, err := q.session.OpenStream() - if err != nil { - // Log this at debug because this is not an error if session was closed due to lost connection - // with edge - q.logger.Debug().Err(err). - Int(management.EventTypeKey, int(management.UDP)). - Str("sessionID", sessionID.String()). - Msgf("Failed to open quic stream to unregister udp session with edge") - return - } - - stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) - defer stream.Close() - rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout) - if err != nil { - // Log this at debug because this is not an error if session was closed due to lost connection - // with edge - q.logger.Err(err).Str("sessionID", sessionID.String()). - Msgf("Failed to open rpc stream to unregister udp session with edge") - return - } - defer rpcClientStream.Close() - - if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil { - q.logger.Err(err).Str("sessionID", sessionID.String()). - Msgf("Failed to unregister udp session with edge") - } -} - -// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion -func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { - return q.sessionManager.UnregisterSession(ctx, sessionID, message, true) -} - -// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration -func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { - return q.orchestrator.UpdateConfig(version, config) -} - -// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to -// the client. -type streamReadWriteAcker struct { - *rpcquic.RequestServerStream - connectResponseSent bool -} - -// AckConnection acks response back to the proxy. -func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error { - metadata := []pogs.Metadata{} - // Only add tracing if provided by origintunneld - if tracePropagation != "" { - metadata = append(metadata, pogs.Metadata{ - Key: tracing.CanonicalCloudflaredTracingHeader, - Val: tracePropagation, - }) - } - s.connectResponseSent = true - return s.WriteConnectResponseData(nil, metadata...) -} - -// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. -type httpResponseAdapter struct { - *rpcquic.RequestServerStream - headers http.Header - connectResponseSent bool -} - -func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter { - return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)} -} - -func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { - // we do not support trailers over QUIC -} - -func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { - metadata := make([]pogs.Metadata, 0) - metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) - for k, vv := range header { - for _, v := range vv { - httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k) - metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v}) - } - } - - return hrw.WriteConnectResponseData(nil, metadata...) -} - -func (hrw *httpResponseAdapter) Write(p []byte) (int, error) { - // Make sure to send WriteHeader response if not called yet - if !hrw.connectResponseSent { - hrw.WriteRespHeaders(http.StatusOK, hrw.headers) - } - return hrw.RequestServerStream.Write(p) -} - -func (hrw *httpResponseAdapter) Header() http.Header { - return hrw.headers -} - -// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here. -func (hrw *httpResponseAdapter) Flush() {} - -func (hrw *httpResponseAdapter) WriteHeader(status int) { - hrw.WriteRespHeaders(status, hrw.headers) -} - -func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - conn := &localProxyConnection{hrw.ReadWriteCloser} - readWriter := bufio.NewReadWriter( - bufio.NewReader(hrw.ReadWriteCloser), - bufio.NewWriter(hrw.ReadWriteCloser), - ) - return conn, readWriter, nil -} - -func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { - hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) -} - -func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error { - hrw.connectResponseSent = true - return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...) -} - -func buildHTTPRequest( - ctx context.Context, - connectRequest *pogs.ConnectRequest, - body io.ReadCloser, - connIndex uint8, - log *zerolog.Logger, -) (*tracing.TracedHTTPRequest, error) { - metadata := connectRequest.MetadataMap() - dest := connectRequest.Dest - method := metadata[HTTPMethodKey] - host := metadata[HTTPHostKey] - isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket - - req, err := http.NewRequestWithContext(ctx, method, dest, body) - if err != nil { - return nil, err - } - - req.Host = host - for _, metadata := range connectRequest.Metadata { - if strings.Contains(metadata.Key, HTTPHeaderKey) { - // metadata.Key is off the format httpHeaderKey: - httpHeaderKey := strings.Split(metadata.Key, ":") - if len(httpHeaderKey) != 2 { - return nil, fmt.Errorf("header Key: %s malformed", metadata.Key) - } - req.Header.Add(httpHeaderKey[1], metadata.Val) - } - } - // Go's http.Client automatically sends chunked request body if this value is not set on the - // *http.Request struct regardless of header: - // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154. - if err := setContentLength(req); err != nil { - return nil, fmt.Errorf("Error setting content-length: %w", err) - } - - // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true: - // * the request body blocks - // * the content length is not set (or set to -1) - // * the method doesn't usually have a body (GET, HEAD, DELETE, ...) - // * there is no transfer-encoding=chunked already set. - // So, if transfer cannot be chunked and content length is 0, we dont set a request body. - if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 { - req.Body = http.NoBody - } - stripWebsocketUpgradeHeader(req) - - // Check for tracing on request - tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log) - return tracedReq, err -} - -func setContentLength(req *http.Request) error { - var err error - if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" { - req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) - } - return err -} - -func isTransferEncodingChunked(req *http.Request) bool { - transferEncodingVal := req.Header.Get("Transfer-Encoding") - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma - // separated value as well. - return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") -} - -// A helper struct that guarantees a call to close only affects read side, but not write side. -type nopCloserReadWriter struct { - io.ReadWriteCloser - - // for use by Read only - // we don't need a memory barrier here because there is an implicit assumption that - // Read calls can't happen concurrently by different go-routines. - sawEOF bool - // should be updated and read using atomic primitives. - // value is read in Read method and written in Close method, which could be done by different - // go-routines. - closed uint32 -} - -func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) { - if np.sawEOF { - return 0, io.EOF - } - - if atomic.LoadUint32(&np.closed) > 0 { - return 0, fmt.Errorf("closed by handler") - } - - n, err = np.ReadWriteCloser.Read(p) - if err == io.EOF { - np.sawEOF = true - } - - return -} - -func (np *nopCloserReadWriter) Close() error { - atomic.StoreUint32(&np.closed, 1) - - return nil -} - -// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface -type muxerWrapper struct { - muxer *cfdquic.DatagramMuxerV2 -} - -func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error { - return rp.muxer.SendPacket(cfdquic.RawPacket(pk)) -} - -func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { - pk, err := rp.muxer.ReceivePacket(ctx) - if err != nil { - return packet.RawPacket{}, err - } - rawPacket, ok := pk.(cfdquic.RawPacket) - if ok { - return packet.RawPacket(rawPacket), nil - } - return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk) -} - -func (rp *muxerWrapper) Close() error { - return nil + return conn, nil } func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, logger *zerolog.Logger) (*net.UDPConn, error) { diff --git a/connection/quic_connection.go b/connection/quic_connection.go new file mode 100644 index 00000000..d0baab5e --- /dev/null +++ b/connection/quic_connection.go @@ -0,0 +1,444 @@ +package connection + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/pkg/errors" + "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" + + "github.com/cloudflare/cloudflared/packet" + cfdquic "github.com/cloudflare/cloudflared/quic" + "github.com/cloudflare/cloudflared/tracing" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" +) + +const ( + // HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP. + HTTPHeaderKey = "HttpHeader" + // HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP. + HTTPMethodKey = "HttpMethod" + // HTTPHostKey is used to get or set http host in QUIC ALPN if the underlying proxy connection type is HTTP. + HTTPHostKey = "HttpHost" + + QUICMetadataFlowID = "FlowID" +) + +// quicConnection represents the type that facilitates Proxying via QUIC streams. +type quicConnection struct { + conn quic.Connection + logger *zerolog.Logger + orchestrator Orchestrator + datagramHandler DatagramSessionHandler + controlStreamHandler ControlStreamHandler + connOptions *tunnelpogs.ConnectionOptions + connIndex uint8 + + rpcTimeout time.Duration + streamWriteTimeout time.Duration + gracePeriod time.Duration +} + +// NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic. +func NewTunnelConnection( + ctx context.Context, + conn quic.Connection, + connIndex uint8, + orchestrator Orchestrator, + datagramSessionHandler DatagramSessionHandler, + controlStreamHandler ControlStreamHandler, + connOptions *pogs.ConnectionOptions, + rpcTimeout time.Duration, + streamWriteTimeout time.Duration, + gracePeriod time.Duration, + logger *zerolog.Logger, +) (TunnelConnection, error) { + return &quicConnection{ + conn: conn, + logger: logger, + orchestrator: orchestrator, + datagramHandler: datagramSessionHandler, + controlStreamHandler: controlStreamHandler, + connOptions: connOptions, + connIndex: connIndex, + rpcTimeout: rpcTimeout, + streamWriteTimeout: streamWriteTimeout, + gracePeriod: gracePeriod, + }, nil +} + +// Serve starts a QUIC connection that begins accepting streams. +func (q *quicConnection) Serve(ctx context.Context) error { + // The edge assumes the first stream is used for the control plane + controlStream, err := q.conn.OpenStream() + if err != nil { + return fmt.Errorf("failed to open a registration control stream: %w", err) + } + + // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits + // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this + // connection). + // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the + // other goroutine as fast as possible. + ctx, cancel := context.WithCancel(ctx) + errGroup, ctx := errgroup.WithContext(ctx) + + // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control + // stream is already fully registered before the other goroutines can proceed. + errGroup.Go(func() error { + // err is equal to nil if we exit due to unregistration. If that happens we want to wait the full + // amount of the grace period, allowing requests to finish before we cancel the context, which will + // make cloudflared exit. + if err := q.serveControlStream(ctx, controlStream); err == nil { + select { + case <-ctx.Done(): + case <-time.Tick(q.gracePeriod): + } + } + cancel() + return err + + }) + errGroup.Go(func() error { + defer cancel() + return q.acceptStream(ctx) + }) + errGroup.Go(func() error { + defer cancel() + return q.datagramHandler.Serve(ctx) + }) + + return errGroup.Wait() +} + +// serveControlStream will serve the RPC; blocking until the control plane is done. +func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error { + return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator) +} + +// Close the connection with no errors specified. +func (q *quicConnection) Close() { + q.conn.CloseWithError(0, "") +} + +func (q *quicConnection) acceptStream(ctx context.Context) error { + defer q.Close() + for { + quicStream, err := q.conn.AcceptStream(ctx) + if err != nil { + // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. + if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() { + return nil + } + return fmt.Errorf("failed to accept QUIC stream: %w", err) + } + go q.runStream(quicStream) + } +} + +func (q *quicConnection) runStream(quicStream quic.Stream) { + ctx := quicStream.Context() + stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) + defer stream.Close() + + // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that + // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream. + // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream. + // A call to close will simulate a close to the read-side, which will fail subsequent reads. + noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} + ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q.datagramHandler, q, q.rpcTimeout) + if err := ss.Serve(ctx, noCloseStream); err != nil { + q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream") + + // if we received an error at this level, then close write side of stream with an error, which will result in + // RST_STREAM frame. + quicStream.CancelWrite(0) + } +} + +func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error { + request, err := stream.ReadConnectRequestData() + if err != nil { + return err + } + + if err, connectResponseSent := q.dispatchRequest(ctx, stream, request); err != nil { + q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") + + // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is + // closed with an RST_STREAM frame + if connectResponseSent { + return err + } + + if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil { + return writeRespErr + } + } + + return nil +} + +// dispatchRequest will dispatch the request to the origin depending on the type and returns an error if it occurs. +// Also returns if the connect response was sent to the downstream during processing of the origin request. +func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, request *pogs.ConnectRequest) (err error, connectResponseSent bool) { + originProxy, err := q.orchestrator.GetOriginProxy() + if err != nil { + return err, false + } + + switch request.Type { + case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket: + tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger) + if err != nil { + return err, false + } + w := newHTTPResponseAdapter(stream) + return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent + + case pogs.ConnectionTypeTCP: + rwa := &streamReadWriteAcker{RequestServerStream: stream} + metadata := request.MetadataMap() + return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ + Dest: request.Dest, + FlowID: metadata[QUICMetadataFlowID], + CfTraceID: metadata[tracing.TracerContextName], + ConnIndex: q.connIndex, + }), rwa.connectResponseSent + default: + return errors.Errorf("unsupported error type: %s", request.Type), false + } +} + +// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration +func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + return q.orchestrator.UpdateConfig(version, config) +} + +// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to +// the client. +type streamReadWriteAcker struct { + *rpcquic.RequestServerStream + connectResponseSent bool +} + +// AckConnection acks response back to the proxy. +func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error { + metadata := []pogs.Metadata{} + // Only add tracing if provided by the edge request + if tracePropagation != "" { + metadata = append(metadata, pogs.Metadata{ + Key: tracing.CanonicalCloudflaredTracingHeader, + Val: tracePropagation, + }) + } + s.connectResponseSent = true + return s.WriteConnectResponseData(nil, metadata...) +} + +// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. +type httpResponseAdapter struct { + *rpcquic.RequestServerStream + headers http.Header + connectResponseSent bool +} + +func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter { + return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)} +} + +func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { + // we do not support trailers over QUIC +} + +func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { + metadata := make([]pogs.Metadata, 0) + metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) + for k, vv := range header { + for _, v := range vv { + httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k) + metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v}) + } + } + + return hrw.WriteConnectResponseData(nil, metadata...) +} + +func (hrw *httpResponseAdapter) Write(p []byte) (int, error) { + // Make sure to send WriteHeader response if not called yet + if !hrw.connectResponseSent { + hrw.WriteRespHeaders(http.StatusOK, hrw.headers) + } + return hrw.RequestServerStream.Write(p) +} + +func (hrw *httpResponseAdapter) Header() http.Header { + return hrw.headers +} + +// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here. +func (hrw *httpResponseAdapter) Flush() {} + +func (hrw *httpResponseAdapter) WriteHeader(status int) { + hrw.WriteRespHeaders(status, hrw.headers) +} + +func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn := &localProxyConnection{hrw.ReadWriteCloser} + readWriter := bufio.NewReadWriter( + bufio.NewReader(hrw.ReadWriteCloser), + bufio.NewWriter(hrw.ReadWriteCloser), + ) + return conn, readWriter, nil +} + +func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { + hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) +} + +func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error { + hrw.connectResponseSent = true + return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...) +} + +func buildHTTPRequest( + ctx context.Context, + connectRequest *pogs.ConnectRequest, + body io.ReadCloser, + connIndex uint8, + log *zerolog.Logger, +) (*tracing.TracedHTTPRequest, error) { + metadata := connectRequest.MetadataMap() + dest := connectRequest.Dest + method := metadata[HTTPMethodKey] + host := metadata[HTTPHostKey] + isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket + + req, err := http.NewRequestWithContext(ctx, method, dest, body) + if err != nil { + return nil, err + } + + req.Host = host + for _, metadata := range connectRequest.Metadata { + if strings.Contains(metadata.Key, HTTPHeaderKey) { + // metadata.Key is off the format httpHeaderKey: + httpHeaderKey := strings.Split(metadata.Key, ":") + if len(httpHeaderKey) != 2 { + return nil, fmt.Errorf("header Key: %s malformed", metadata.Key) + } + req.Header.Add(httpHeaderKey[1], metadata.Val) + } + } + // Go's http.Client automatically sends chunked request body if this value is not set on the + // *http.Request struct regardless of header: + // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154. + if err := setContentLength(req); err != nil { + return nil, fmt.Errorf("Error setting content-length: %w", err) + } + + // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true: + // * the request body blocks + // * the content length is not set (or set to -1) + // * the method doesn't usually have a body (GET, HEAD, DELETE, ...) + // * there is no transfer-encoding=chunked already set. + // So, if transfer cannot be chunked and content length is 0, we dont set a request body. + if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 { + req.Body = http.NoBody + } + stripWebsocketUpgradeHeader(req) + + // Check for tracing on request + tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log) + return tracedReq, err +} + +func setContentLength(req *http.Request) error { + var err error + if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" { + req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + } + return err +} + +func isTransferEncodingChunked(req *http.Request) bool { + transferEncodingVal := req.Header.Get("Transfer-Encoding") + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma + // separated value as well. + return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") +} + +// A helper struct that guarantees a call to close only affects read side, but not write side. +type nopCloserReadWriter struct { + io.ReadWriteCloser + + // for use by Read only + // we don't need a memory barrier here because there is an implicit assumption that + // Read calls can't happen concurrently by different go-routines. + sawEOF bool + // should be updated and read using atomic primitives. + // value is read in Read method and written in Close method, which could be done by different + // go-routines. + closed uint32 +} + +func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) { + if np.sawEOF { + return 0, io.EOF + } + + if atomic.LoadUint32(&np.closed) > 0 { + return 0, fmt.Errorf("closed by handler") + } + + n, err = np.ReadWriteCloser.Read(p) + if err == io.EOF { + np.sawEOF = true + } + + return +} + +func (np *nopCloserReadWriter) Close() error { + atomic.StoreUint32(&np.closed, 1) + + return nil +} + +// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface +type muxerWrapper struct { + muxer *cfdquic.DatagramMuxerV2 +} + +func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error { + return rp.muxer.SendPacket(cfdquic.RawPacket(pk)) +} + +func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { + pk, err := rp.muxer.ReceivePacket(ctx) + if err != nil { + return packet.RawPacket{}, err + } + rawPacket, ok := pk.(cfdquic.RawPacket) + if ok { + return packet.RawPacket(rawPacket), nil + } + return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk) +} + +func (rp *muxerWrapper) Close() error { + return nil +} diff --git a/connection/quic_test.go b/connection/quic_connection_test.go similarity index 90% rename from connection/quic_test.go rename to connection/quic_connection_test.go index c073b850..ba052437 100644 --- a/connection/quic_test.go +++ b/connection/quic_connection_test.go @@ -15,7 +15,6 @@ import ( "net/http" "net/netip" "net/url" - "os" "strings" "testing" "time" @@ -30,10 +29,11 @@ import ( "golang.org/x/net/nettest" "github.com/cloudflare/cloudflared/datagramsession" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/packet" cfdquic "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" ) @@ -164,11 +164,11 @@ func TestQUICServer(t *testing.T) { close(serverDone) }() - qc := testQUICConnection(netip.MustParseAddrPort(udpListener.LocalAddr().String()), t, uint8(i)) + tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i)) connDone := make(chan struct{}) go func() { - qc.Serve(ctx) + tunnelConn.Serve(ctx) close(connDone) }() @@ -528,13 +528,14 @@ func TestServeUDPSession(t *testing.T) { }() // Random index to avoid reusing port - qc := testQUICConnection(netip.MustParseAddrPort(udpListener.LocalAddr().String()), t, 28) - go qc.Serve(ctx) + tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28) + go tunnelConn.Serve(ctx) edgeQUICSession := <-edgeQUICSessionChan - serveSession(ctx, qc, edgeQUICSession, closedByOrigin, io.EOF.Error(), t) - serveSession(ctx, qc, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t) - serveSession(ctx, qc, edgeQUICSession, closedByRemote, "eyeball closed connection", t) + + serveSession(ctx, datagramConn, edgeQUICSession, closedByOrigin, io.EOF.Error(), t) + serveSession(ctx, datagramConn, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t) + serveSession(ctx, datagramConn, edgeQUICSession, closedByRemote, "eyeball closed connection", t) cancel() } @@ -619,19 +620,19 @@ func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPo require.NotEqual(t, initialPort, getPortFunc(conn)) } -func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { +func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { var ( payload = []byte(t.Name()) ) sessionID := uuid.New() cfdConn, originConn := net.Pipe() // Registers and run a new session - session, err := qc.sessionManager.RegisterSession(ctx, sessionID, cfdConn) + session, err := datagramConn.sessionManager.RegisterSession(ctx, sessionID, cfdConn) require.NoError(t, err) sessionDone := make(chan struct{}) go func() { - qc.serveUDPSession(session, time.Millisecond*50) + datagramConn.serveUDPSession(session, time.Millisecond*50) close(sessionDone) }() @@ -655,7 +656,7 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic. case closedByOrigin: originConn.Close() case closedByRemote: - err = qc.UnregisterUdpSession(ctx, sessionID, expectedReason) + err = datagramConn.UnregisterUdpSession(ctx, sessionID, expectedReason) require.NoError(t, err) case closedByTimeout: } @@ -726,33 +727,58 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI return nil } -func testQUICConnection(udpListenerAddr netip.AddrPort, t *testing.T, index uint8) *QUICConnection { +func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) { tlsClientConfig := &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"argotunnel"}, } // Start a mock httpProxy - log := zerolog.New(os.Stdout) + log := zerolog.New(io.Discard) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - qc, err := NewQUICConnection( + + // Dial the QUIC connection to the edge + conn, err := DialQuic( ctx, testQUICConfig, - udpListenerAddr, - nil, - index, tlsClientConfig, - &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, - &tunnelpogs.ConnectionOptions{}, - fakeControlStream{}, + serverAddr, + nil, // connect on a random port + index, &log, - nil, + ) + + // Start a session manager for the connection + sessionDemuxChan := make(chan *packet.Session, 4) + datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan) + sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan) + packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, &log) + + datagramConn := &datagramV2Connection{ + conn, + sessionManager, + datagramMuxer, + packetRouter, + 15 * time.Second, + 0 * time.Second, + &log, + } + + tunnelConn, err := NewTunnelConnection( + ctx, + conn, + index, + &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, + datagramConn, + fakeControlStream{}, + &pogs.ConnectionOptions{}, 15*time.Second, 0*time.Second, 0*time.Second, + &log, ) require.NoError(t, err) - return qc + return tunnelConn, datagramConn } type mockReaderNoopWriter struct { diff --git a/connection/quic_datagram_v2.go b/connection/quic_datagram_v2.go new file mode 100644 index 00000000..1cedaa41 --- /dev/null +++ b/connection/quic_datagram_v2.go @@ -0,0 +1,200 @@ +package connection + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/google/uuid" + "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + + "github.com/cloudflare/cloudflared/datagramsession" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/management" + "github.com/cloudflare/cloudflared/packet" + cfdquic "github.com/cloudflare/cloudflared/quic" + "github.com/cloudflare/cloudflared/tracing" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" +) + +const ( + // emperically this capacity has been working well + demuxChanCapacity = 16 +) + +// DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming +// connection streams. +type DatagramSessionHandler interface { + Serve(context.Context) error + + pogs.SessionManager +} + +type datagramV2Connection struct { + conn quic.Connection + + // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer + sessionManager datagramsession.Manager + // datagramMuxer mux/demux datagrams from quic connection + datagramMuxer *cfdquic.DatagramMuxerV2 + packetRouter *ingress.PacketRouter + + rpcTimeout time.Duration + streamWriteTimeout time.Duration + + logger *zerolog.Logger +} + +func NewDatagramV2Connection(ctx context.Context, + conn quic.Connection, + packetConfig *ingress.GlobalRouterConfig, + rpcTimeout time.Duration, + streamWriteTimeout time.Duration, + logger *zerolog.Logger, +) DatagramSessionHandler { + sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) + datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, logger, sessionDemuxChan) + sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) + packetRouter := ingress.NewPacketRouter(packetConfig, datagramMuxer, logger) + + return &datagramV2Connection{ + conn, + sessionManager, + datagramMuxer, + packetRouter, + rpcTimeout, + streamWriteTimeout, + logger, + } +} + +func (d *datagramV2Connection) Serve(ctx context.Context) error { + // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits + // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this + // connection). + // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the + // other goroutine as fast as possible. + ctx, cancel := context.WithCancel(ctx) + errGroup, ctx := errgroup.WithContext(ctx) + + errGroup.Go(func() error { + defer cancel() + return d.sessionManager.Serve(ctx) + }) + errGroup.Go(func() error { + defer cancel() + return d.datagramMuxer.ServeReceive(ctx) + }) + errGroup.Go(func() error { + defer cancel() + return d.packetRouter.Serve(ctx) + }) + + return errGroup.Wait() +} + +// RegisterUdpSession is the RPC method invoked by edge to register and run a session +func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) { + traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger) + ctx, registerSpan := traceCtx.Tracer().Start(traceCtx, "register-session", trace.WithAttributes( + attribute.String("session-id", sessionID.String()), + attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)), + )) + log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger() + // Each session is a series of datagram from an eyeball to a dstIP:dstPort. + // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. + originProxy, err := ingress.DialUDP(dstIP, dstPort) + if err != nil { + log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) + tracing.EndWithErrorStatus(registerSpan, err) + return nil, err + } + registerSpan.SetAttributes( + attribute.Bool("socket-bind-success", true), + attribute.String("src", originProxy.LocalAddr().String()), + ) + + session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) + if err != nil { + originProxy.Close() + log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session") + tracing.EndWithErrorStatus(registerSpan, err) + return nil, err + } + + go q.serveUDPSession(session, closeAfterIdleHint) + + log.Debug(). + Str("sessionID", sessionID.String()). + Str("src", originProxy.LocalAddr().String()). + Str("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)). + Msgf("Registered session") + tracing.End(registerSpan) + + resp := tunnelpogs.RegisterUdpSessionResponse{ + Spans: traceCtx.GetProtoSpans(), + } + + return &resp, nil +} + +// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion +func (q *datagramV2Connection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { + return q.sessionManager.UnregisterSession(ctx, sessionID, message, true) +} + +func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) { + ctx := q.conn.Context() + closedByRemote, err := session.Serve(ctx, closeAfterIdleHint) + // If session is terminated by remote, then we know it has been unregistered from session manager and edge + if !closedByRemote { + if err != nil { + q.closeUDPSession(ctx, session.ID, err.Error()) + } else { + q.closeUDPSession(ctx, session.ID, "terminated without error") + } + } + q.logger.Debug().Err(err). + Int(management.EventTypeKey, int(management.UDP)). + Str("sessionID", session.ID.String()). + Msg("Session terminated") +} + +// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge +func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) { + q.sessionManager.UnregisterSession(ctx, sessionID, message, false) + quicStream, err := q.conn.OpenStream() + if err != nil { + // Log this at debug because this is not an error if session was closed due to lost connection + // with edge + q.logger.Debug().Err(err). + Int(management.EventTypeKey, int(management.UDP)). + Str("sessionID", sessionID.String()). + Msgf("Failed to open quic stream to unregister udp session with edge") + return + } + + stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) + defer stream.Close() + rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout) + if err != nil { + // Log this at debug because this is not an error if session was closed due to lost connection + // with edge + q.logger.Err(err).Str("sessionID", sessionID.String()). + Msgf("Failed to open rpc stream to unregister udp session with edge") + return + } + defer rpcClientStream.Close() + + if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil { + q.logger.Err(err).Str("sessionID", sessionID.String()). + Msgf("Failed to unregister udp session with edge") + } +} diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index c30bdb7a..7de2cbd0 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -590,32 +590,55 @@ func (e *EdgeTunnelServer) serveQUIC( InitialPacketSize: initialPacketSize, } - quicConn, err := connection.NewQUICConnection( + // Dial the QUIC connection to the edge + conn, err := connection.DialQuic( ctx, quicConfig, + tlsConfig, edgeAddr, e.edgeBindAddr, connIndex, - tlsConfig, - e.orchestrator, - connOptions, - controlStreamHandler, connLogger.Logger(), - e.config.PacketConfig, - e.config.RPCTimeout, - e.config.WriteStreamTimeout, - e.config.GracePeriod, ) if err != nil { - connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") + connLogger.ConnAwareLogger().Err(err).Msgf("Failed to dial a quic connection") return err, true } + datagramSessionManager := connection.NewDatagramV2Connection( + ctx, + conn, + e.config.PacketConfig, + e.config.RPCTimeout, + e.config.WriteStreamTimeout, + connLogger.Logger(), + ) + + // Wrap the [quic.Connection] as a TunnelConnection + tunnelConn, err := connection.NewTunnelConnection( + ctx, + conn, + connIndex, + e.orchestrator, + datagramSessionManager, + controlStreamHandler, + connOptions, + e.config.RPCTimeout, + e.config.WriteStreamTimeout, + e.config.GracePeriod, + connLogger.Logger(), + ) + if err != nil { + connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new tunnel connection") + return err, true + } + + // Serve the TunnelConnection errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { - err := quicConn.Serve(serveCtx) + err := tunnelConn.Serve(serveCtx) if err != nil { - connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve quic connection") + connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve tunnel connection") } return err }) @@ -624,8 +647,8 @@ func (e *EdgeTunnelServer) serveQUIC( err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC) if err != nil { // forcefully break the connection (this is only used for testing) - // errgroup will return context canceled for the quicConn.Serve - connLogger.Logger().Debug().Msg("Forcefully breaking quic connection") + // errgroup will return context canceled for the tunnelConn.Serve + connLogger.Logger().Debug().Msg("Forcefully breaking tunnel connection") } return err })