diff --git a/connection/connection.go b/connection/connection.go index 50464e4a..b7376e38 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -36,6 +36,13 @@ var ( flushableContentTypes = []string{sseContentType, grpcContentType} ) +// TunnelConnection represents the connection to the edge. +// The Serve method is provided to allow clients to handle any errors from the connection encountered during +// processing of the connection. Cancelling of the context provided to Serve will close the connection. +type TunnelConnection interface { + Serve(ctx context.Context) error +} + type Orchestrator interface { UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse GetConfigJSON() ([]byte, error) diff --git a/connection/quic.go b/connection/quic.go index cbf5b186..3109d77f 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -1,51 +1,16 @@ package connection import ( - "bufio" "context" "crypto/tls" "fmt" - "io" "net" - "net/http" "net/netip" "runtime" - "strconv" - "strings" "sync" - "sync/atomic" - "time" - "github.com/google/uuid" - "github.com/pkg/errors" "github.com/quic-go/quic-go" "github.com/rs/zerolog" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" - - "github.com/cloudflare/cloudflared/datagramsession" - "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/management" - "github.com/cloudflare/cloudflared/packet" - cfdquic "github.com/cloudflare/cloudflared/quic" - "github.com/cloudflare/cloudflared/tracing" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" -) - -const ( - // HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPHeaderKey = "HttpHeader" - // HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPMethodKey = "HttpMethod" - // HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP. - HTTPHostKey = "HttpHost" - - QUICMetadataFlowID = "FlowID" - // emperically this capacity has been working well - demuxChanCapacity = 16 ) var ( @@ -53,48 +18,21 @@ var ( portMapMutex sync.Mutex ) -// QUICConnection represents the type that facilitates Proxying via QUIC streams. -type QUICConnection struct { - session quic.Connection - logger *zerolog.Logger - orchestrator Orchestrator - // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer - sessionManager datagramsession.Manager - // datagramMuxer mux/demux datagrams from quic connection - datagramMuxer *cfdquic.DatagramMuxerV2 - packetRouter *ingress.PacketRouter - controlStreamHandler ControlStreamHandler - connOptions *tunnelpogs.ConnectionOptions - connIndex uint8 - - rpcTimeout time.Duration - streamWriteTimeout time.Duration - gracePeriod time.Duration -} - -// NewQUICConnection returns a new instance of QUICConnection. -func NewQUICConnection( +func DialQuic( ctx context.Context, quicConfig *quic.Config, + tlsConfig *tls.Config, edgeAddr netip.AddrPort, localAddr net.IP, connIndex uint8, - tlsConfig *tls.Config, - orchestrator Orchestrator, - connOptions *tunnelpogs.ConnectionOptions, - controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, - packetRouterConfig *ingress.GlobalRouterConfig, - rpcTimeout time.Duration, - streamWriteTimeout time.Duration, - gracePeriod time.Duration, -) (*QUICConnection, error) { +) (quic.Connection, error) { udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, logger) if err != nil { return nil, err } - session, err := quic.Dial(ctx, udpConn, net.UDPAddrFromAddrPort(edgeAddr), tlsConfig, quicConfig) + conn, err := quic.Dial(ctx, udpConn, net.UDPAddrFromAddrPort(edgeAddr), tlsConfig, quicConfig) if err != nil { // close the udp server socket in case of error connecting to the edge udpConn.Close() @@ -102,506 +40,11 @@ func NewQUICConnection( } // wrap the session, so that the UDPConn is closed after session is closed. - session = &wrapCloseableConnQuicConnection{ - session, + conn = &wrapCloseableConnQuicConnection{ + conn, udpConn, } - - sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) - datagramMuxer := cfdquic.NewDatagramMuxerV2(session, logger, sessionDemuxChan) - sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) - packetRouter := ingress.NewPacketRouter(packetRouterConfig, datagramMuxer, logger) - - return &QUICConnection{ - session: session, - orchestrator: orchestrator, - logger: logger, - sessionManager: sessionManager, - datagramMuxer: datagramMuxer, - packetRouter: packetRouter, - controlStreamHandler: controlStreamHandler, - connOptions: connOptions, - connIndex: connIndex, - rpcTimeout: rpcTimeout, - streamWriteTimeout: streamWriteTimeout, - gracePeriod: gracePeriod, - }, nil -} - -// Serve starts a QUIC session that begins accepting streams. -func (q *QUICConnection) Serve(ctx context.Context) error { - // origintunneld assumes the first stream is used for the control plane - controlStream, err := q.session.OpenStream() - if err != nil { - return fmt.Errorf("failed to open a registration control stream: %w", err) - } - - // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits - // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this - // connection). - // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the - // other goroutine as fast as possible. - ctx, cancel := context.WithCancel(ctx) - errGroup, ctx := errgroup.WithContext(ctx) - - // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control - // stream is already fully registered before the other goroutines can proceed. - errGroup.Go(func() error { - // err is equal to nil if we exit due to unregistration. If that happens we want to wait the full - // amount of the grace period, allowing requests to finish before we cancel the context, which will - // make cloudflared exit. - if err := q.serveControlStream(ctx, controlStream); err == nil { - select { - case <-ctx.Done(): - case <-time.Tick(q.gracePeriod): - } - } - cancel() - return err - }) - errGroup.Go(func() error { - defer cancel() - return q.acceptStream(ctx) - }) - errGroup.Go(func() error { - defer cancel() - return q.sessionManager.Serve(ctx) - }) - errGroup.Go(func() error { - defer cancel() - return q.datagramMuxer.ServeReceive(ctx) - }) - errGroup.Go(func() error { - defer cancel() - return q.packetRouter.Serve(ctx) - }) - - return errGroup.Wait() -} - -func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error { - // This blocks until the control plane is done. - err := q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator) - if err != nil { - // Not wrapping error here to be consistent with the http2 message. - return err - } - - return nil -} - -// Close closes the session with no errors specified. -func (q *QUICConnection) Close() { - q.session.CloseWithError(0, "") -} - -func (q *QUICConnection) acceptStream(ctx context.Context) error { - defer q.Close() - for { - quicStream, err := q.session.AcceptStream(ctx) - if err != nil { - // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. - if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() { - return nil - } - return fmt.Errorf("failed to accept QUIC stream: %w", err) - } - go q.runStream(quicStream) - } -} - -func (q *QUICConnection) runStream(quicStream quic.Stream) { - ctx := quicStream.Context() - stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) - defer stream.Close() - - // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that - // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream. - // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream. - // A call to close will simulate a close to the read-side, which will fail subsequent reads. - noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} - ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q, q, q.rpcTimeout) - if err := ss.Serve(ctx, noCloseStream); err != nil { - q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream") - - // if we received an error at this level, then close write side of stream with an error, which will result in - // RST_STREAM frame. - quicStream.CancelWrite(0) - } -} - -func (q *QUICConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error { - request, err := stream.ReadConnectRequestData() - if err != nil { - return err - } - - if err, connectResponseSent := q.dispatchRequest(ctx, stream, err, request); err != nil { - q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") - - // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is - // closed with an RST_STREAM frame - if connectResponseSent { - return err - } - - if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil { - return writeRespErr - } - } - - return nil -} - -// dispatchRequest will dispatch the request depending on the type and returns an error if it occurs. -// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream. -// This is important since it informs -func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, err error, request *pogs.ConnectRequest) (error, bool) { - originProxy, err := q.orchestrator.GetOriginProxy() - if err != nil { - return err, false - } - - switch request.Type { - case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket: - tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger) - if err != nil { - return err, false - } - w := newHTTPResponseAdapter(stream) - return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent - - case pogs.ConnectionTypeTCP: - rwa := &streamReadWriteAcker{RequestServerStream: stream} - metadata := request.MetadataMap() - return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ - Dest: request.Dest, - FlowID: metadata[QUICMetadataFlowID], - CfTraceID: metadata[tracing.TracerContextName], - ConnIndex: q.connIndex, - }), rwa.connectResponseSent - default: - return errors.Errorf("unsupported error type: %s", request.Type), false - } -} - -// RegisterUdpSession is the RPC method invoked by edge to register and run a session -func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) { - traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger) - ctx, registerSpan := traceCtx.Tracer().Start(traceCtx, "register-session", trace.WithAttributes( - attribute.String("session-id", sessionID.String()), - attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)), - )) - log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger() - // Each session is a series of datagram from an eyeball to a dstIP:dstPort. - // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. - originProxy, err := ingress.DialUDP(dstIP, dstPort) - if err != nil { - log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) - tracing.EndWithErrorStatus(registerSpan, err) - return nil, err - } - registerSpan.SetAttributes( - attribute.Bool("socket-bind-success", true), - attribute.String("src", originProxy.LocalAddr().String()), - ) - - session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) - if err != nil { - originProxy.Close() - log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session") - tracing.EndWithErrorStatus(registerSpan, err) - return nil, err - } - - go q.serveUDPSession(session, closeAfterIdleHint) - - log.Debug(). - Str("sessionID", sessionID.String()). - Str("src", originProxy.LocalAddr().String()). - Str("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)). - Msgf("Registered session") - tracing.End(registerSpan) - - resp := tunnelpogs.RegisterUdpSessionResponse{ - Spans: traceCtx.GetProtoSpans(), - } - - return &resp, nil -} - -func (q *QUICConnection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) { - ctx := q.session.Context() - closedByRemote, err := session.Serve(ctx, closeAfterIdleHint) - // If session is terminated by remote, then we know it has been unregistered from session manager and edge - if !closedByRemote { - if err != nil { - q.closeUDPSession(ctx, session.ID, err.Error()) - } else { - q.closeUDPSession(ctx, session.ID, "terminated without error") - } - } - q.logger.Debug().Err(err). - Int(management.EventTypeKey, int(management.UDP)). - Str("sessionID", session.ID.String()). - Msg("Session terminated") -} - -// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge -func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) { - q.sessionManager.UnregisterSession(ctx, sessionID, message, false) - quicStream, err := q.session.OpenStream() - if err != nil { - // Log this at debug because this is not an error if session was closed due to lost connection - // with edge - q.logger.Debug().Err(err). - Int(management.EventTypeKey, int(management.UDP)). - Str("sessionID", sessionID.String()). - Msgf("Failed to open quic stream to unregister udp session with edge") - return - } - - stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) - defer stream.Close() - rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout) - if err != nil { - // Log this at debug because this is not an error if session was closed due to lost connection - // with edge - q.logger.Err(err).Str("sessionID", sessionID.String()). - Msgf("Failed to open rpc stream to unregister udp session with edge") - return - } - defer rpcClientStream.Close() - - if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil { - q.logger.Err(err).Str("sessionID", sessionID.String()). - Msgf("Failed to unregister udp session with edge") - } -} - -// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion -func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { - return q.sessionManager.UnregisterSession(ctx, sessionID, message, true) -} - -// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration -func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { - return q.orchestrator.UpdateConfig(version, config) -} - -// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to -// the client. -type streamReadWriteAcker struct { - *rpcquic.RequestServerStream - connectResponseSent bool -} - -// AckConnection acks response back to the proxy. -func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error { - metadata := []pogs.Metadata{} - // Only add tracing if provided by origintunneld - if tracePropagation != "" { - metadata = append(metadata, pogs.Metadata{ - Key: tracing.CanonicalCloudflaredTracingHeader, - Val: tracePropagation, - }) - } - s.connectResponseSent = true - return s.WriteConnectResponseData(nil, metadata...) -} - -// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. -type httpResponseAdapter struct { - *rpcquic.RequestServerStream - headers http.Header - connectResponseSent bool -} - -func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter { - return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)} -} - -func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { - // we do not support trailers over QUIC -} - -func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { - metadata := make([]pogs.Metadata, 0) - metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) - for k, vv := range header { - for _, v := range vv { - httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k) - metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v}) - } - } - - return hrw.WriteConnectResponseData(nil, metadata...) -} - -func (hrw *httpResponseAdapter) Write(p []byte) (int, error) { - // Make sure to send WriteHeader response if not called yet - if !hrw.connectResponseSent { - hrw.WriteRespHeaders(http.StatusOK, hrw.headers) - } - return hrw.RequestServerStream.Write(p) -} - -func (hrw *httpResponseAdapter) Header() http.Header { - return hrw.headers -} - -// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here. -func (hrw *httpResponseAdapter) Flush() {} - -func (hrw *httpResponseAdapter) WriteHeader(status int) { - hrw.WriteRespHeaders(status, hrw.headers) -} - -func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - conn := &localProxyConnection{hrw.ReadWriteCloser} - readWriter := bufio.NewReadWriter( - bufio.NewReader(hrw.ReadWriteCloser), - bufio.NewWriter(hrw.ReadWriteCloser), - ) - return conn, readWriter, nil -} - -func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { - hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) -} - -func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error { - hrw.connectResponseSent = true - return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...) -} - -func buildHTTPRequest( - ctx context.Context, - connectRequest *pogs.ConnectRequest, - body io.ReadCloser, - connIndex uint8, - log *zerolog.Logger, -) (*tracing.TracedHTTPRequest, error) { - metadata := connectRequest.MetadataMap() - dest := connectRequest.Dest - method := metadata[HTTPMethodKey] - host := metadata[HTTPHostKey] - isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket - - req, err := http.NewRequestWithContext(ctx, method, dest, body) - if err != nil { - return nil, err - } - - req.Host = host - for _, metadata := range connectRequest.Metadata { - if strings.Contains(metadata.Key, HTTPHeaderKey) { - // metadata.Key is off the format httpHeaderKey: - httpHeaderKey := strings.Split(metadata.Key, ":") - if len(httpHeaderKey) != 2 { - return nil, fmt.Errorf("header Key: %s malformed", metadata.Key) - } - req.Header.Add(httpHeaderKey[1], metadata.Val) - } - } - // Go's http.Client automatically sends chunked request body if this value is not set on the - // *http.Request struct regardless of header: - // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154. - if err := setContentLength(req); err != nil { - return nil, fmt.Errorf("Error setting content-length: %w", err) - } - - // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true: - // * the request body blocks - // * the content length is not set (or set to -1) - // * the method doesn't usually have a body (GET, HEAD, DELETE, ...) - // * there is no transfer-encoding=chunked already set. - // So, if transfer cannot be chunked and content length is 0, we dont set a request body. - if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 { - req.Body = http.NoBody - } - stripWebsocketUpgradeHeader(req) - - // Check for tracing on request - tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log) - return tracedReq, err -} - -func setContentLength(req *http.Request) error { - var err error - if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" { - req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) - } - return err -} - -func isTransferEncodingChunked(req *http.Request) bool { - transferEncodingVal := req.Header.Get("Transfer-Encoding") - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma - // separated value as well. - return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") -} - -// A helper struct that guarantees a call to close only affects read side, but not write side. -type nopCloserReadWriter struct { - io.ReadWriteCloser - - // for use by Read only - // we don't need a memory barrier here because there is an implicit assumption that - // Read calls can't happen concurrently by different go-routines. - sawEOF bool - // should be updated and read using atomic primitives. - // value is read in Read method and written in Close method, which could be done by different - // go-routines. - closed uint32 -} - -func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) { - if np.sawEOF { - return 0, io.EOF - } - - if atomic.LoadUint32(&np.closed) > 0 { - return 0, fmt.Errorf("closed by handler") - } - - n, err = np.ReadWriteCloser.Read(p) - if err == io.EOF { - np.sawEOF = true - } - - return -} - -func (np *nopCloserReadWriter) Close() error { - atomic.StoreUint32(&np.closed, 1) - - return nil -} - -// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface -type muxerWrapper struct { - muxer *cfdquic.DatagramMuxerV2 -} - -func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error { - return rp.muxer.SendPacket(cfdquic.RawPacket(pk)) -} - -func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { - pk, err := rp.muxer.ReceivePacket(ctx) - if err != nil { - return packet.RawPacket{}, err - } - rawPacket, ok := pk.(cfdquic.RawPacket) - if ok { - return packet.RawPacket(rawPacket), nil - } - return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk) -} - -func (rp *muxerWrapper) Close() error { - return nil + return conn, nil } func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, logger *zerolog.Logger) (*net.UDPConn, error) { diff --git a/connection/quic_connection.go b/connection/quic_connection.go new file mode 100644 index 00000000..d0baab5e --- /dev/null +++ b/connection/quic_connection.go @@ -0,0 +1,444 @@ +package connection + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/pkg/errors" + "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" + + "github.com/cloudflare/cloudflared/packet" + cfdquic "github.com/cloudflare/cloudflared/quic" + "github.com/cloudflare/cloudflared/tracing" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" +) + +const ( + // HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP. + HTTPHeaderKey = "HttpHeader" + // HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP. + HTTPMethodKey = "HttpMethod" + // HTTPHostKey is used to get or set http host in QUIC ALPN if the underlying proxy connection type is HTTP. + HTTPHostKey = "HttpHost" + + QUICMetadataFlowID = "FlowID" +) + +// quicConnection represents the type that facilitates Proxying via QUIC streams. +type quicConnection struct { + conn quic.Connection + logger *zerolog.Logger + orchestrator Orchestrator + datagramHandler DatagramSessionHandler + controlStreamHandler ControlStreamHandler + connOptions *tunnelpogs.ConnectionOptions + connIndex uint8 + + rpcTimeout time.Duration + streamWriteTimeout time.Duration + gracePeriod time.Duration +} + +// NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic. +func NewTunnelConnection( + ctx context.Context, + conn quic.Connection, + connIndex uint8, + orchestrator Orchestrator, + datagramSessionHandler DatagramSessionHandler, + controlStreamHandler ControlStreamHandler, + connOptions *pogs.ConnectionOptions, + rpcTimeout time.Duration, + streamWriteTimeout time.Duration, + gracePeriod time.Duration, + logger *zerolog.Logger, +) (TunnelConnection, error) { + return &quicConnection{ + conn: conn, + logger: logger, + orchestrator: orchestrator, + datagramHandler: datagramSessionHandler, + controlStreamHandler: controlStreamHandler, + connOptions: connOptions, + connIndex: connIndex, + rpcTimeout: rpcTimeout, + streamWriteTimeout: streamWriteTimeout, + gracePeriod: gracePeriod, + }, nil +} + +// Serve starts a QUIC connection that begins accepting streams. +func (q *quicConnection) Serve(ctx context.Context) error { + // The edge assumes the first stream is used for the control plane + controlStream, err := q.conn.OpenStream() + if err != nil { + return fmt.Errorf("failed to open a registration control stream: %w", err) + } + + // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits + // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this + // connection). + // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the + // other goroutine as fast as possible. + ctx, cancel := context.WithCancel(ctx) + errGroup, ctx := errgroup.WithContext(ctx) + + // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control + // stream is already fully registered before the other goroutines can proceed. + errGroup.Go(func() error { + // err is equal to nil if we exit due to unregistration. If that happens we want to wait the full + // amount of the grace period, allowing requests to finish before we cancel the context, which will + // make cloudflared exit. + if err := q.serveControlStream(ctx, controlStream); err == nil { + select { + case <-ctx.Done(): + case <-time.Tick(q.gracePeriod): + } + } + cancel() + return err + + }) + errGroup.Go(func() error { + defer cancel() + return q.acceptStream(ctx) + }) + errGroup.Go(func() error { + defer cancel() + return q.datagramHandler.Serve(ctx) + }) + + return errGroup.Wait() +} + +// serveControlStream will serve the RPC; blocking until the control plane is done. +func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error { + return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator) +} + +// Close the connection with no errors specified. +func (q *quicConnection) Close() { + q.conn.CloseWithError(0, "") +} + +func (q *quicConnection) acceptStream(ctx context.Context) error { + defer q.Close() + for { + quicStream, err := q.conn.AcceptStream(ctx) + if err != nil { + // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. + if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() { + return nil + } + return fmt.Errorf("failed to accept QUIC stream: %w", err) + } + go q.runStream(quicStream) + } +} + +func (q *quicConnection) runStream(quicStream quic.Stream) { + ctx := quicStream.Context() + stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) + defer stream.Close() + + // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that + // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream. + // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream. + // A call to close will simulate a close to the read-side, which will fail subsequent reads. + noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} + ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q.datagramHandler, q, q.rpcTimeout) + if err := ss.Serve(ctx, noCloseStream); err != nil { + q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream") + + // if we received an error at this level, then close write side of stream with an error, which will result in + // RST_STREAM frame. + quicStream.CancelWrite(0) + } +} + +func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error { + request, err := stream.ReadConnectRequestData() + if err != nil { + return err + } + + if err, connectResponseSent := q.dispatchRequest(ctx, stream, request); err != nil { + q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") + + // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is + // closed with an RST_STREAM frame + if connectResponseSent { + return err + } + + if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil { + return writeRespErr + } + } + + return nil +} + +// dispatchRequest will dispatch the request to the origin depending on the type and returns an error if it occurs. +// Also returns if the connect response was sent to the downstream during processing of the origin request. +func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, request *pogs.ConnectRequest) (err error, connectResponseSent bool) { + originProxy, err := q.orchestrator.GetOriginProxy() + if err != nil { + return err, false + } + + switch request.Type { + case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket: + tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger) + if err != nil { + return err, false + } + w := newHTTPResponseAdapter(stream) + return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent + + case pogs.ConnectionTypeTCP: + rwa := &streamReadWriteAcker{RequestServerStream: stream} + metadata := request.MetadataMap() + return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ + Dest: request.Dest, + FlowID: metadata[QUICMetadataFlowID], + CfTraceID: metadata[tracing.TracerContextName], + ConnIndex: q.connIndex, + }), rwa.connectResponseSent + default: + return errors.Errorf("unsupported error type: %s", request.Type), false + } +} + +// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration +func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + return q.orchestrator.UpdateConfig(version, config) +} + +// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to +// the client. +type streamReadWriteAcker struct { + *rpcquic.RequestServerStream + connectResponseSent bool +} + +// AckConnection acks response back to the proxy. +func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error { + metadata := []pogs.Metadata{} + // Only add tracing if provided by the edge request + if tracePropagation != "" { + metadata = append(metadata, pogs.Metadata{ + Key: tracing.CanonicalCloudflaredTracingHeader, + Val: tracePropagation, + }) + } + s.connectResponseSent = true + return s.WriteConnectResponseData(nil, metadata...) +} + +// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. +type httpResponseAdapter struct { + *rpcquic.RequestServerStream + headers http.Header + connectResponseSent bool +} + +func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter { + return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)} +} + +func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { + // we do not support trailers over QUIC +} + +func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { + metadata := make([]pogs.Metadata, 0) + metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) + for k, vv := range header { + for _, v := range vv { + httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k) + metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v}) + } + } + + return hrw.WriteConnectResponseData(nil, metadata...) +} + +func (hrw *httpResponseAdapter) Write(p []byte) (int, error) { + // Make sure to send WriteHeader response if not called yet + if !hrw.connectResponseSent { + hrw.WriteRespHeaders(http.StatusOK, hrw.headers) + } + return hrw.RequestServerStream.Write(p) +} + +func (hrw *httpResponseAdapter) Header() http.Header { + return hrw.headers +} + +// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here. +func (hrw *httpResponseAdapter) Flush() {} + +func (hrw *httpResponseAdapter) WriteHeader(status int) { + hrw.WriteRespHeaders(status, hrw.headers) +} + +func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn := &localProxyConnection{hrw.ReadWriteCloser} + readWriter := bufio.NewReadWriter( + bufio.NewReader(hrw.ReadWriteCloser), + bufio.NewWriter(hrw.ReadWriteCloser), + ) + return conn, readWriter, nil +} + +func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { + hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) +} + +func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error { + hrw.connectResponseSent = true + return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...) +} + +func buildHTTPRequest( + ctx context.Context, + connectRequest *pogs.ConnectRequest, + body io.ReadCloser, + connIndex uint8, + log *zerolog.Logger, +) (*tracing.TracedHTTPRequest, error) { + metadata := connectRequest.MetadataMap() + dest := connectRequest.Dest + method := metadata[HTTPMethodKey] + host := metadata[HTTPHostKey] + isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket + + req, err := http.NewRequestWithContext(ctx, method, dest, body) + if err != nil { + return nil, err + } + + req.Host = host + for _, metadata := range connectRequest.Metadata { + if strings.Contains(metadata.Key, HTTPHeaderKey) { + // metadata.Key is off the format httpHeaderKey: + httpHeaderKey := strings.Split(metadata.Key, ":") + if len(httpHeaderKey) != 2 { + return nil, fmt.Errorf("header Key: %s malformed", metadata.Key) + } + req.Header.Add(httpHeaderKey[1], metadata.Val) + } + } + // Go's http.Client automatically sends chunked request body if this value is not set on the + // *http.Request struct regardless of header: + // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154. + if err := setContentLength(req); err != nil { + return nil, fmt.Errorf("Error setting content-length: %w", err) + } + + // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true: + // * the request body blocks + // * the content length is not set (or set to -1) + // * the method doesn't usually have a body (GET, HEAD, DELETE, ...) + // * there is no transfer-encoding=chunked already set. + // So, if transfer cannot be chunked and content length is 0, we dont set a request body. + if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 { + req.Body = http.NoBody + } + stripWebsocketUpgradeHeader(req) + + // Check for tracing on request + tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log) + return tracedReq, err +} + +func setContentLength(req *http.Request) error { + var err error + if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" { + req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + } + return err +} + +func isTransferEncodingChunked(req *http.Request) bool { + transferEncodingVal := req.Header.Get("Transfer-Encoding") + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma + // separated value as well. + return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") +} + +// A helper struct that guarantees a call to close only affects read side, but not write side. +type nopCloserReadWriter struct { + io.ReadWriteCloser + + // for use by Read only + // we don't need a memory barrier here because there is an implicit assumption that + // Read calls can't happen concurrently by different go-routines. + sawEOF bool + // should be updated and read using atomic primitives. + // value is read in Read method and written in Close method, which could be done by different + // go-routines. + closed uint32 +} + +func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) { + if np.sawEOF { + return 0, io.EOF + } + + if atomic.LoadUint32(&np.closed) > 0 { + return 0, fmt.Errorf("closed by handler") + } + + n, err = np.ReadWriteCloser.Read(p) + if err == io.EOF { + np.sawEOF = true + } + + return +} + +func (np *nopCloserReadWriter) Close() error { + atomic.StoreUint32(&np.closed, 1) + + return nil +} + +// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface +type muxerWrapper struct { + muxer *cfdquic.DatagramMuxerV2 +} + +func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error { + return rp.muxer.SendPacket(cfdquic.RawPacket(pk)) +} + +func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { + pk, err := rp.muxer.ReceivePacket(ctx) + if err != nil { + return packet.RawPacket{}, err + } + rawPacket, ok := pk.(cfdquic.RawPacket) + if ok { + return packet.RawPacket(rawPacket), nil + } + return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk) +} + +func (rp *muxerWrapper) Close() error { + return nil +} diff --git a/connection/quic_test.go b/connection/quic_connection_test.go similarity index 90% rename from connection/quic_test.go rename to connection/quic_connection_test.go index c073b850..ba052437 100644 --- a/connection/quic_test.go +++ b/connection/quic_connection_test.go @@ -15,7 +15,6 @@ import ( "net/http" "net/netip" "net/url" - "os" "strings" "testing" "time" @@ -30,10 +29,11 @@ import ( "golang.org/x/net/nettest" "github.com/cloudflare/cloudflared/datagramsession" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/packet" cfdquic "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" ) @@ -164,11 +164,11 @@ func TestQUICServer(t *testing.T) { close(serverDone) }() - qc := testQUICConnection(netip.MustParseAddrPort(udpListener.LocalAddr().String()), t, uint8(i)) + tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i)) connDone := make(chan struct{}) go func() { - qc.Serve(ctx) + tunnelConn.Serve(ctx) close(connDone) }() @@ -528,13 +528,14 @@ func TestServeUDPSession(t *testing.T) { }() // Random index to avoid reusing port - qc := testQUICConnection(netip.MustParseAddrPort(udpListener.LocalAddr().String()), t, 28) - go qc.Serve(ctx) + tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28) + go tunnelConn.Serve(ctx) edgeQUICSession := <-edgeQUICSessionChan - serveSession(ctx, qc, edgeQUICSession, closedByOrigin, io.EOF.Error(), t) - serveSession(ctx, qc, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t) - serveSession(ctx, qc, edgeQUICSession, closedByRemote, "eyeball closed connection", t) + + serveSession(ctx, datagramConn, edgeQUICSession, closedByOrigin, io.EOF.Error(), t) + serveSession(ctx, datagramConn, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t) + serveSession(ctx, datagramConn, edgeQUICSession, closedByRemote, "eyeball closed connection", t) cancel() } @@ -619,19 +620,19 @@ func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPo require.NotEqual(t, initialPort, getPortFunc(conn)) } -func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { +func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { var ( payload = []byte(t.Name()) ) sessionID := uuid.New() cfdConn, originConn := net.Pipe() // Registers and run a new session - session, err := qc.sessionManager.RegisterSession(ctx, sessionID, cfdConn) + session, err := datagramConn.sessionManager.RegisterSession(ctx, sessionID, cfdConn) require.NoError(t, err) sessionDone := make(chan struct{}) go func() { - qc.serveUDPSession(session, time.Millisecond*50) + datagramConn.serveUDPSession(session, time.Millisecond*50) close(sessionDone) }() @@ -655,7 +656,7 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic. case closedByOrigin: originConn.Close() case closedByRemote: - err = qc.UnregisterUdpSession(ctx, sessionID, expectedReason) + err = datagramConn.UnregisterUdpSession(ctx, sessionID, expectedReason) require.NoError(t, err) case closedByTimeout: } @@ -726,33 +727,58 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI return nil } -func testQUICConnection(udpListenerAddr netip.AddrPort, t *testing.T, index uint8) *QUICConnection { +func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) { tlsClientConfig := &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"argotunnel"}, } // Start a mock httpProxy - log := zerolog.New(os.Stdout) + log := zerolog.New(io.Discard) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - qc, err := NewQUICConnection( + + // Dial the QUIC connection to the edge + conn, err := DialQuic( ctx, testQUICConfig, - udpListenerAddr, - nil, - index, tlsClientConfig, - &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, - &tunnelpogs.ConnectionOptions{}, - fakeControlStream{}, + serverAddr, + nil, // connect on a random port + index, &log, - nil, + ) + + // Start a session manager for the connection + sessionDemuxChan := make(chan *packet.Session, 4) + datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan) + sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan) + packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, &log) + + datagramConn := &datagramV2Connection{ + conn, + sessionManager, + datagramMuxer, + packetRouter, + 15 * time.Second, + 0 * time.Second, + &log, + } + + tunnelConn, err := NewTunnelConnection( + ctx, + conn, + index, + &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, + datagramConn, + fakeControlStream{}, + &pogs.ConnectionOptions{}, 15*time.Second, 0*time.Second, 0*time.Second, + &log, ) require.NoError(t, err) - return qc + return tunnelConn, datagramConn } type mockReaderNoopWriter struct { diff --git a/connection/quic_datagram_v2.go b/connection/quic_datagram_v2.go new file mode 100644 index 00000000..1cedaa41 --- /dev/null +++ b/connection/quic_datagram_v2.go @@ -0,0 +1,200 @@ +package connection + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/google/uuid" + "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + + "github.com/cloudflare/cloudflared/datagramsession" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/management" + "github.com/cloudflare/cloudflared/packet" + cfdquic "github.com/cloudflare/cloudflared/quic" + "github.com/cloudflare/cloudflared/tracing" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" +) + +const ( + // emperically this capacity has been working well + demuxChanCapacity = 16 +) + +// DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming +// connection streams. +type DatagramSessionHandler interface { + Serve(context.Context) error + + pogs.SessionManager +} + +type datagramV2Connection struct { + conn quic.Connection + + // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer + sessionManager datagramsession.Manager + // datagramMuxer mux/demux datagrams from quic connection + datagramMuxer *cfdquic.DatagramMuxerV2 + packetRouter *ingress.PacketRouter + + rpcTimeout time.Duration + streamWriteTimeout time.Duration + + logger *zerolog.Logger +} + +func NewDatagramV2Connection(ctx context.Context, + conn quic.Connection, + packetConfig *ingress.GlobalRouterConfig, + rpcTimeout time.Duration, + streamWriteTimeout time.Duration, + logger *zerolog.Logger, +) DatagramSessionHandler { + sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) + datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, logger, sessionDemuxChan) + sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) + packetRouter := ingress.NewPacketRouter(packetConfig, datagramMuxer, logger) + + return &datagramV2Connection{ + conn, + sessionManager, + datagramMuxer, + packetRouter, + rpcTimeout, + streamWriteTimeout, + logger, + } +} + +func (d *datagramV2Connection) Serve(ctx context.Context) error { + // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits + // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this + // connection). + // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the + // other goroutine as fast as possible. + ctx, cancel := context.WithCancel(ctx) + errGroup, ctx := errgroup.WithContext(ctx) + + errGroup.Go(func() error { + defer cancel() + return d.sessionManager.Serve(ctx) + }) + errGroup.Go(func() error { + defer cancel() + return d.datagramMuxer.ServeReceive(ctx) + }) + errGroup.Go(func() error { + defer cancel() + return d.packetRouter.Serve(ctx) + }) + + return errGroup.Wait() +} + +// RegisterUdpSession is the RPC method invoked by edge to register and run a session +func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) { + traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger) + ctx, registerSpan := traceCtx.Tracer().Start(traceCtx, "register-session", trace.WithAttributes( + attribute.String("session-id", sessionID.String()), + attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)), + )) + log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger() + // Each session is a series of datagram from an eyeball to a dstIP:dstPort. + // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. + originProxy, err := ingress.DialUDP(dstIP, dstPort) + if err != nil { + log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) + tracing.EndWithErrorStatus(registerSpan, err) + return nil, err + } + registerSpan.SetAttributes( + attribute.Bool("socket-bind-success", true), + attribute.String("src", originProxy.LocalAddr().String()), + ) + + session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) + if err != nil { + originProxy.Close() + log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session") + tracing.EndWithErrorStatus(registerSpan, err) + return nil, err + } + + go q.serveUDPSession(session, closeAfterIdleHint) + + log.Debug(). + Str("sessionID", sessionID.String()). + Str("src", originProxy.LocalAddr().String()). + Str("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)). + Msgf("Registered session") + tracing.End(registerSpan) + + resp := tunnelpogs.RegisterUdpSessionResponse{ + Spans: traceCtx.GetProtoSpans(), + } + + return &resp, nil +} + +// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion +func (q *datagramV2Connection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { + return q.sessionManager.UnregisterSession(ctx, sessionID, message, true) +} + +func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) { + ctx := q.conn.Context() + closedByRemote, err := session.Serve(ctx, closeAfterIdleHint) + // If session is terminated by remote, then we know it has been unregistered from session manager and edge + if !closedByRemote { + if err != nil { + q.closeUDPSession(ctx, session.ID, err.Error()) + } else { + q.closeUDPSession(ctx, session.ID, "terminated without error") + } + } + q.logger.Debug().Err(err). + Int(management.EventTypeKey, int(management.UDP)). + Str("sessionID", session.ID.String()). + Msg("Session terminated") +} + +// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge +func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) { + q.sessionManager.UnregisterSession(ctx, sessionID, message, false) + quicStream, err := q.conn.OpenStream() + if err != nil { + // Log this at debug because this is not an error if session was closed due to lost connection + // with edge + q.logger.Debug().Err(err). + Int(management.EventTypeKey, int(management.UDP)). + Str("sessionID", sessionID.String()). + Msgf("Failed to open quic stream to unregister udp session with edge") + return + } + + stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) + defer stream.Close() + rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout) + if err != nil { + // Log this at debug because this is not an error if session was closed due to lost connection + // with edge + q.logger.Err(err).Str("sessionID", sessionID.String()). + Msgf("Failed to open rpc stream to unregister udp session with edge") + return + } + defer rpcClientStream.Close() + + if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil { + q.logger.Err(err).Str("sessionID", sessionID.String()). + Msgf("Failed to unregister udp session with edge") + } +} diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index c30bdb7a..7de2cbd0 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -590,32 +590,55 @@ func (e *EdgeTunnelServer) serveQUIC( InitialPacketSize: initialPacketSize, } - quicConn, err := connection.NewQUICConnection( + // Dial the QUIC connection to the edge + conn, err := connection.DialQuic( ctx, quicConfig, + tlsConfig, edgeAddr, e.edgeBindAddr, connIndex, - tlsConfig, - e.orchestrator, - connOptions, - controlStreamHandler, connLogger.Logger(), - e.config.PacketConfig, - e.config.RPCTimeout, - e.config.WriteStreamTimeout, - e.config.GracePeriod, ) if err != nil { - connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") + connLogger.ConnAwareLogger().Err(err).Msgf("Failed to dial a quic connection") return err, true } + datagramSessionManager := connection.NewDatagramV2Connection( + ctx, + conn, + e.config.PacketConfig, + e.config.RPCTimeout, + e.config.WriteStreamTimeout, + connLogger.Logger(), + ) + + // Wrap the [quic.Connection] as a TunnelConnection + tunnelConn, err := connection.NewTunnelConnection( + ctx, + conn, + connIndex, + e.orchestrator, + datagramSessionManager, + controlStreamHandler, + connOptions, + e.config.RPCTimeout, + e.config.WriteStreamTimeout, + e.config.GracePeriod, + connLogger.Logger(), + ) + if err != nil { + connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new tunnel connection") + return err, true + } + + // Serve the TunnelConnection errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { - err := quicConn.Serve(serveCtx) + err := tunnelConn.Serve(serveCtx) if err != nil { - connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve quic connection") + connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve tunnel connection") } return err }) @@ -624,8 +647,8 @@ func (e *EdgeTunnelServer) serveQUIC( err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC) if err != nil { // forcefully break the connection (this is only used for testing) - // errgroup will return context canceled for the quicConn.Serve - connLogger.Logger().Debug().Msg("Forcefully breaking quic connection") + // errgroup will return context canceled for the tunnelConn.Serve + connLogger.Logger().Debug().Msg("Forcefully breaking tunnel connection") } return err })