package connection

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"

	"github.com/lucas-clemente/quic-go"
	"github.com/pkg/errors"
	"github.com/rs/zerolog"

	quicpogs "github.com/cloudflare/cloudflared/quic"
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

const (
	// HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP.
	HTTPHeaderKey = "HttpHeader"
	// HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP.
	HTTPMethodKey = "HttpMethod"
	// HTTPHostKey is used to get or set http host in QUIC ALPN if the underlying proxy connection type is HTTP.
	HTTPHostKey = "HttpHost"
)

// QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct {
	// session is the QUIC session dialed to the edge; streams are accepted from it.
	session quic.Session
	// logger receives per-stream failure messages.
	logger *zerolog.Logger
	// httpProxy handles the HTTP, websocket and TCP requests carried by accepted streams.
	httpProxy OriginProxy
	// NOTE(review): gracefulShutdownC and stoppedGracefully are not set or read by the
	// code visible here; confirm whether other files in the package use them.
	gracefulShutdownC <-chan struct{}
	stoppedGracefully bool
}

// NewQUICConnection returns a new instance of QUICConnection.
//
// It dials edgeAddr, opens a registration stream on the resulting session and
// passes that stream to controlStreamHandler.ServeControlStream. ctx is only
// handed to the control stream handler; the dial itself does not take it.
//
// If opening the stream or serving the control stream fails, the session is
// closed before the error is returned. The caller never receives the session
// in that case, so nothing else could close it.
func NewQUICConnection(
	ctx context.Context,
	quicConfig *quic.Config,
	edgeAddr net.Addr,
	tlsConfig *tls.Config,
	httpProxy OriginProxy,
	connOptions *tunnelpogs.ConnectionOptions,
	controlStreamHandler ControlStreamHandler,
	observer *Observer,
) (*QUICConnection, error) {
	session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to dial to edge: %w", err)
	}

	registrationStream, err := session.OpenStream()
	if err != nil {
		// Release the dialed session; it would otherwise be leaked.
		session.CloseWithError(0, "")
		return nil, fmt.Errorf("failed to open a registration stream: %w", err)
	}

	err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions, false)
	if err != nil {
		// Release the dialed session; it would otherwise be leaked.
		session.CloseWithError(0, "")
		// Not wrapping error here to be consistent with the http2 message.
		return nil, err
	}

	return &QUICConnection{
		session:   session,
		httpProxy: httpProxy,
		logger:    observer.log,
	}, nil
}

// Serve starts a QUIC session that begins accepting streams.
func (q *QUICConnection) Serve(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) defer cancel() for { stream, err := q.session.AcceptStream(ctx) if err != nil { // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. if errors.Is(err, context.Canceled) { return nil } return fmt.Errorf("failed to accept QUIC stream: %w", err) } go func() { defer stream.Close() if err = q.handleStream(stream); err != nil { q.logger.Err(err).Msg("Failed to handle QUIC stream") } }() } } // Close closes the session with no errors specified. func (q *QUICConnection) Close() { q.session.CloseWithError(0, "") } func (q *QUICConnection) handleStream(stream quic.Stream) error { connectRequest, err := quicpogs.ReadConnectRequestData(stream) if err != nil { return err } switch connectRequest.Type { case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: req, err := buildHTTPRequest(connectRequest, stream) if err != nil { return err } w := newHTTPResponseAdapter(stream) return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) case quicpogs.ConnectionTypeTCP: rwa := &streamReadWriteAcker{ ReadWriter: stream, } return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) } return nil } // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // the client. type streamReadWriteAcker struct { io.ReadWriter } // AckConnection acks response back to the proxy. func (s *streamReadWriteAcker) AckConnection() error { return quicpogs.WriteConnectResponseData(s, nil) } // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. 
type httpResponseAdapter struct { io.Writer } func newHTTPResponseAdapter(w io.Writer) httpResponseAdapter { return httpResponseAdapter{w} } func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { metadata := make([]quicpogs.Metadata, 0) metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) for k, vv := range header { for _, v := range vv { httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k) metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v}) } } return quicpogs.WriteConnectResponseData(hrw, nil, metadata...) } func (hrw httpResponseAdapter) WriteErrorResponse(err error) { quicpogs.WriteConnectResponseData(hrw, err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*http.Request, error) { metadata := connectRequest.MetadataMap() dest := connectRequest.Dest method := metadata[HTTPMethodKey] host := metadata[HTTPHostKey] isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket req, err := http.NewRequest(method, dest, body) if err != nil { return nil, err } req.Host = host for _, metadata := range connectRequest.Metadata { if strings.Contains(metadata.Key, HTTPHeaderKey) { // metadata.Key is off the format httpHeaderKey: httpHeaderKey := strings.Split(metadata.Key, ":") if len(httpHeaderKey) != 2 { return nil, fmt.Errorf("Header Key: %s malformed", metadata.Key) } req.Header.Add(httpHeaderKey[1], metadata.Val) } } // Go's http.Client automatically sends chunked request body if this value is not set on the // *http.Request struct regardless of header: // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154. 
if err := setContentLength(req); err != nil { return nil, fmt.Errorf("Error setting content-length: %w", err) } // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true: // * the request body blocks // * the content length is not set (or set to -1) // * the method doesn't usually have a body (GET, HEAD, DELETE, ...) // * there is no transfer-encoding=chunked already set. // So, if transfer cannot be chunked and content length is 0, we dont set a request body. if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 { req.Body = nil } stripWebsocketUpgradeHeader(req) return req, err } func setContentLength(req *http.Request) error { var err error if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" { req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) } return err } func isTransferEncodingChunked(req *http.Request) bool { transferEncodingVal := req.Header.Get("Transfer-Encoding") // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma // separated value as well. return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") }