323 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			323 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
| package connection
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"crypto/tls"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/google/uuid"
 | |
| 	"github.com/lucas-clemente/quic-go"
 | |
| 	"github.com/pkg/errors"
 | |
| 	"github.com/rs/zerolog"
 | |
| 	"golang.org/x/sync/errgroup"
 | |
| 
 | |
| 	quicpogs "github.com/cloudflare/cloudflared/quic"
 | |
| 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | |
| )
 | |
| 
 | |
const (
	// HTTPHeaderKey is the metadata key prefix used to get or set HTTP headers on a QUIC
	// connect request/response when the proxied connection type is HTTP. Each header is
	// carried as a separate entry whose key is "HttpHeader:<header name>".
	HTTPHeaderKey = "HttpHeader"
	// HTTPMethodKey is the metadata key used to get or set the HTTP method on a QUIC
	// connect request when the proxied connection type is HTTP.
	HTTPMethodKey = "HttpMethod"
	// HTTPHostKey is the metadata key used to get or set the HTTP host on a QUIC
	// connect request when the proxied connection type is HTTP.
	HTTPHostKey = "HttpHost"
	// MaxDatagramFrameSize is the largest datagram, in bytes, this connection builds
	// from origin UDP traffic. A datagram carries the UDP payload plus a session ID,
	// so payloads larger than MaxDatagramFrameSize minus the session ID length are
	// dropped.
	MaxDatagramFrameSize = 1220
)
 | |
| 
 | |
| // QUICConnection represents the type that facilitates Proxying via QUIC streams.
 | |
| type QUICConnection struct {
 | |
| 	session     quic.Session
 | |
| 	logger      *zerolog.Logger
 | |
| 	httpProxy   OriginProxy
 | |
| 	udpSessions *udpSessions
 | |
| }
 | |
| 
 | |
// NewQUICConnection returns a new instance of QUICConnection.
//
// It picks the local IP used as the source address for proxied UDP packets,
// dials the edge at edgeAddr over QUIC, then opens a registration stream and
// hands it to controlStreamHandler.ServeControlStream. If any step after the
// dial fails, the QUIC session is closed before returning so it is not leaked.
func NewQUICConnection(
	ctx context.Context,
	quicConfig *quic.Config,
	edgeAddr net.Addr,
	tlsConfig *tls.Config,
	httpProxy OriginProxy,
	connOptions *tunnelpogs.ConnectionOptions,
	controlStreamHandler ControlStreamHandler,
	observer *Observer,
) (*QUICConnection, error) {
	localIP, err := GetLocalIP()
	if err != nil {
		return nil, err
	}
	observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP)
	udpSessions := newUDPSessions(localIP)

	// Dial with ctx so that cancelling it also aborts a dial still in progress.
	session, err := quic.DialAddrContext(ctx, edgeAddr.String(), tlsConfig, quicConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to dial to edge: %w", err)
	}

	registrationStream, err := session.OpenStream()
	if err != nil {
		// The session is not handed to anyone on failure; close it here.
		_ = session.CloseWithError(0, "")
		return nil, fmt.Errorf("failed to open a registration stream: %w", err)
	}

	err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions, false)
	if err != nil {
		_ = session.CloseWithError(0, "")
		// Not wrapping error here to be consistent with the http2 message.
		return nil, err
	}

	return &QUICConnection{
		session:     session,
		httpProxy:   httpProxy,
		logger:      observer.log,
		udpSessions: udpSessions,
	}, nil
}
 | |
| 
 | |
// Serve starts a QUIC session that begins accepting streams.
//
// It concurrently relays datagrams from the edge to origin UDP sessions and
// accepts streams opened by the edge. It runs until either loop fails or ctx is
// cancelled, then closes the session so that both loops stop.
//
// It returns the first error that was not caused by the shutdown itself, or nil
// when the connection was simply cancelled.
func (q *QUICConnection) Serve(ctx context.Context) error {
	errGroup, ctx := errgroup.WithContext(ctx)
	errGroup.Go(func() error {
		err := q.listenEdgeDatagram()
		// listenEdgeDatagram only returns once ReceiveMessage fails, which is
		// expected after the session has been closed for shutdown.
		if ctx.Err() != nil {
			return nil
		}
		return err
	})

	errGroup.Go(func() error {
		err := q.acceptStream(ctx)
		// Accepting may fail with the session-closed error rather than
		// context.Canceled when the shutdown closes the session first.
		if ctx.Err() != nil {
			return nil
		}
		return err
	})

	// listenEdgeDatagram has no context to watch, so the session must be closed
	// to unblock it; otherwise Wait would never return after cancellation.
	errGroup.Go(func() error {
		<-ctx.Done()
		q.Close()
		return nil
	})
	return errGroup.Wait()
}
 | |
| 
 | |
// acceptStream accepts streams opened by the edge until accepting fails.
// Each stream is handled in its own goroutine and closed once handled; handler
// errors are logged, not returned.
//
// It returns nil if accepting fails with context.Canceled, and a wrapped error
// otherwise.
func (q *QUICConnection) acceptStream(ctx context.Context) error {
	for {
		stream, err := q.session.AcceptStream(ctx)
		if err != nil {
			// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
			if errors.Is(err, context.Canceled) {
				return nil
			}
			return fmt.Errorf("failed to accept QUIC stream: %w", err)
		}
		// Pass the stream explicitly and keep err local to the goroutine instead
		// of writing to a variable captured from the accept loop.
		go func(stream quic.Stream) {
			defer stream.Close()
			if err := q.handleStream(stream); err != nil {
				q.logger.Err(err).Msg("Failed to handle QUIC stream")
			}
		}(stream)
	}
}
 | |
| 
 | |
| // listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload
 | |
| func (q *QUICConnection) listenEdgeDatagram() error {
 | |
| 	for {
 | |
| 		msg, err := q.session.ReceiveMessage()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		go func(msg []byte) {
 | |
| 			sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg)
 | |
| 			if err != nil {
 | |
| 				q.logger.Err(err).Msg("Failed to parse session ID from datagram")
 | |
| 				return
 | |
| 			}
 | |
| 			if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil {
 | |
| 				q.logger.Err(err).Msg("Failed to send UDP to origin")
 | |
| 			}
 | |
| 		}(msg)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Close closes the session with no errors specified.
 | |
| func (q *QUICConnection) Close() {
 | |
| 	q.session.CloseWithError(0, "")
 | |
| }
 | |
| 
 | |
// handleStream reads the protocol signature at the start of a stream opened by
// the edge and dispatches the stream to the data-stream or RPC handler.
//
// It returns any error from reading the signature, from wrapping the stream in
// its typed server stream, or from the handler itself. An unrecognised signature
// is also reported as an error.
func (q *QUICConnection) handleStream(stream quic.Stream) error {
	signature, err := quicpogs.DetermineProtocol(stream)
	if err != nil {
		return err
	}
	switch signature {
	case quicpogs.DataStreamProtocolSignature:
		reqServerStream, err := quicpogs.NewRequestServerStream(stream, signature)
		if err != nil {
			return err
		}
		return q.handleDataStream(reqServerStream)
	case quicpogs.RPCStreamProtocolSignature:
		rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
		if err != nil {
			return err
		}
		return q.handleRPCStream(rpcStream)
	default:
		return fmt.Errorf("unknown protocol %v", signature)
	}
}
 | |
| 
 | |
// handleDataStream reads the connect request at the start of a data stream and
// proxies the stream according to the request type:
//   - HTTP and websocket requests are converted to *http.Request and passed to
//     ProxyHTTP, with responses written back through an httpResponseAdapter.
//   - TCP requests are passed to ProxyTCP, using the stream itself as the
//     connection.
//
// It returns an error if the connect request cannot be read or built, if
// proxying fails, or if the connection type is not one of the above.
func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error {
	connectRequest, err := stream.ReadConnectRequestData()
	if err != nil {
		return err
	}

	switch connectRequest.Type {
	case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
		req, err := buildHTTPRequest(connectRequest, stream)
		if err != nil {
			return err
		}

		w := newHTTPResponseAdapter(stream)
		return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
	case quicpogs.ConnectionTypeTCP:
		rwa := &streamReadWriteAcker{stream}
		// NOTE(review): ProxyTCP receives context.Background(), so it is not tied
		// to Serve's ctx; consider threading a context through if cancellation matters.
		return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
	default:
		// Surface unhandled types instead of silently dropping the stream.
		return fmt.Errorf("unsupported connection type %v", connectRequest.Type)
	}
}
 | |
| 
 | |
| func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
 | |
| 	return rpcStream.Serve(q, q.logger)
 | |
| }
 | |
| 
 | |
// RegisterUdpSession registers sessionID against the origin dstIP:dstPort. On
// success it starts a goroutine that relays datagrams from that origin back to
// the edge. The registration error, if any, is returned as-is. ctx is currently
// unused.
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
	originConn, regErr := q.udpSessions.register(sessionID, dstIP, dstPort)
	if regErr == nil {
		q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort)
		go q.listenOriginUDP(sessionID, originConn)
	}
	return regErr
}
 | |
| 
 | |
// listenOriginUDP relays datagrams read from the origin conn to the edge, each
// suffixed with sessionID. Payloads that leave no room for the session ID, or
// that cannot be suffixed, are logged and dropped. The loop ends when a
// datagram cannot be sent to the edge or reading from the origin fails. On exit
// the session is unregistered and conn is closed.
func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) {
	defer func() {
		q.udpSessions.unregister(sessionID)
		conn.Close()
	}()
	maxPayload := MaxDatagramFrameSize - sessionIDLen
	buf := make([]byte, MaxDatagramFrameSize)
	for {
		n, readErr := conn.Read(buf)
		// Bytes returned alongside an error are still forwarded before the error
		// is acted on. Note that the dropping branches below `continue` and so
		// skip the read-error check for this iteration.
		if n > 0 {
			if n > maxPayload {
				// TODO: TUN-5302 return ICMP packet too big message
				q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, maxPayload)
				continue
			}
			datagram, suffixErr := quicpogs.SuffixSessionID(sessionID, buf[:n])
			if suffixErr != nil {
				q.logger.Err(suffixErr).Msg("Failed to suffix session ID to datagram, it will be dropped")
				continue
			}
			if sendErr := q.session.SendMessage(datagram); sendErr != nil {
				q.logger.Err(sendErr).Msg("Failed to send datagram back to edge")
				return
			}
		}
		if readErr != nil {
			q.logger.Err(readErr).Msg("Failed to read UDP from origin")
			return
		}
	}
}
 | |
| 
 | |
| // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
 | |
| // the client.
 | |
| type streamReadWriteAcker struct {
 | |
| 	*quicpogs.RequestServerStream
 | |
| }
 | |
| 
 | |
| // AckConnection acks response back to the proxy.
 | |
| func (s *streamReadWriteAcker) AckConnection() error {
 | |
| 	return s.WriteConnectResponseData(nil)
 | |
| }
 | |
| 
 | |
| // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
 | |
| type httpResponseAdapter struct {
 | |
| 	*quicpogs.RequestServerStream
 | |
| }
 | |
| 
 | |
| func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter {
 | |
| 	return httpResponseAdapter{s}
 | |
| }
 | |
| 
 | |
// WriteRespHeaders sends the response status and headers to the edge as
// connect-response metadata. The status goes under "HttpStatus"; every header
// value gets its own entry keyed "HttpHeader:<name>".
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
	entries := make([]quicpogs.Metadata, 1, 1+len(header))
	entries[0] = quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}
	for name, values := range header {
		key := HTTPHeaderKey + ":" + name
		for _, value := range values {
			entries = append(entries, quicpogs.Metadata{Key: key, Val: value})
		}
	}
	return hrw.WriteConnectResponseData(nil, entries...)
}
 | |
| 
 | |
// WriteErrorResponse reports err to the edge as a connect response whose
// "HttpStatus" metadata is 502 (Bad Gateway). The write error, if any, is
// dropped because this method has no return value.
func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
	badGateway := quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}
	_ = hrw.WriteConnectResponseData(err, badGateway)
}
 | |
| 
 | |
// buildHTTPRequest converts a connect request received from the edge into an
// *http.Request for the origin proxy.
//
// The request fields are populated as follows:
//   - The URL is connectRequest.Dest.
//   - The method comes from the HTTPMethodKey metadata entry, and Host from the
//     HTTPHostKey entry.
//   - Every metadata entry keyed "HttpHeader:<name>" is added as header <name>.
//   - body becomes the request body, except that it is dropped for
//     non-websocket requests that declare no body (see below).
//
// It returns an error if http.NewRequest fails, if a key starting with
// HTTPHeaderKey is not exactly "HttpHeader:<name>" (non-empty name, no further
// ':'), or if the Content-Length header is not a valid integer.
func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*http.Request, error) {
	metadata := connectRequest.MetadataMap()
	dest := connectRequest.Dest
	method := metadata[HTTPMethodKey]
	host := metadata[HTTPHostKey]
	isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket

	req, err := http.NewRequest(method, dest, body)
	if err != nil {
		return nil, err
	}

	req.Host = host
	headerPrefix := HTTPHeaderKey + ":"
	for _, entry := range connectRequest.Metadata {
		// Match on the prefix rather than a substring so that unrelated keys
		// merely containing "HttpHeader" are not treated as headers.
		if !strings.HasPrefix(entry.Key, HTTPHeaderKey) {
			continue
		}
		name := strings.TrimPrefix(entry.Key, headerPrefix)
		if name == entry.Key || name == "" || strings.Contains(name, ":") {
			return nil, fmt.Errorf("header Key: %s malformed", entry.Key)
		}
		req.Header.Add(name, entry.Val)
	}
	// Go's http.Client automatically sends chunked request body if this value is not set on the
	// *http.Request struct regardless of header:
	// https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154.
	if err := setContentLength(req); err != nil {
		return nil, fmt.Errorf("Error setting content-length: %w", err)
	}

	// Go's client defaults to chunked encoding after a 200ms delay if the following cases are true:
	//   * the request body blocks
	//   * the content length is not set (or set to -1)
	//   * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
	//   * there is no transfer-encoding=chunked already set.
	// So, if transfer cannot be chunked and content length is 0, we dont set a request body.
	if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 {
		req.Body = nil
	}
	stripWebsocketUpgradeHeader(req)
	return req, nil
}
 | |
| 
 | |
| func setContentLength(req *http.Request) error {
 | |
| 	var err error
 | |
| 	if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" {
 | |
| 		req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
// isTransferEncodingChunked reports whether the first Transfer-Encoding header
// value of req contains "chunked", ignoring case. A substring match is used
// because the value may be a comma-separated list, as described at
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding.
func isTransferEncodingChunked(req *http.Request) bool {
	encoding := strings.ToLower(req.Header.Get("Transfer-Encoding"))
	return strings.Contains(encoding, "chunked")
}
 |