239 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			239 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Go
		
	
	
	
| package connection
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/pkg/errors"
 | |
| 	"github.com/rs/zerolog"
 | |
| 	"golang.org/x/sync/errgroup"
 | |
| 
 | |
| 	"github.com/cloudflare/cloudflared/h2mux"
 | |
| 	"github.com/cloudflare/cloudflared/tracing"
 | |
| 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | |
| 	"github.com/cloudflare/cloudflared/websocket"
 | |
| )
 | |
| 
 | |
// Timeouts applied to the h2mux transport.
const (
	// muxerTimeout is passed to the h2mux muxer as its Timeout setting.
	muxerTimeout time.Duration = 5 * time.Second
	// openStreamTimeout bounds how long opening an RPC stream on the muxer may take.
	openStreamTimeout time.Duration = 30 * time.Second
)
 | |
| 
 | |
// h2muxConnection holds the state of one h2mux connection to the edge: the
// muxer itself plus the collaborators used while serving streams on it.
type h2muxConnection struct {
	// orchestrator supplies the origin proxy used to serve each incoming stream.
	orchestrator Orchestrator
	gracePeriod  time.Duration
	// muxerConfig is the configuration the muxer was created from; its
	// MetricsUpdateFreq also drives the metrics ticker in controlLoop.
	muxerConfig *MuxerConfig
	// muxer is set by NewH2muxConnection once the handshake has succeeded.
	muxer *h2mux.Muxer
	// connIndexStr is connIndex rendered as a string. It is only used by metrics,
	// because prometheus requires labels to be strings.
	connIndexStr string
	connIndex    uint8

	observer *Observer
	// gracefulShutdownC signals that a graceful shutdown was requested. controlLoop
	// sets it to nil after the first signal so the signal is handled only once.
	gracefulShutdownC <-chan struct{}
	// stoppedGracefully is set by controlLoop when a graceful shutdown was started;
	// ServeNamedTunnel reads it to decide whether a muxer stop was expected.
	stoppedGracefully bool

	log *zerolog.Logger

	// newRPCClientFunc allows us to mock RPCs during testing
	newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
}
 | |
| 
 | |
// MuxerConfig holds the tunable settings of an h2mux connection. All fields
// except MetricsUpdateFreq are forwarded to the h2mux package by H2MuxerConfig.
type MuxerConfig struct {
	// HeartbeatInterval is forwarded as the muxer's HeartbeatInterval.
	HeartbeatInterval time.Duration
	// MaxHeartbeats is forwarded as the muxer's MaxHeartbeats.
	MaxHeartbeats uint64
	// CompressionSetting is forwarded as the muxer's CompressionQuality.
	CompressionSetting h2mux.CompressionSetting
	// MetricsUpdateFreq is the period at which the connection's control loop
	// publishes muxer metrics. It is not part of the h2mux configuration.
	MetricsUpdateFreq time.Duration
}
 | |
| 
 | |
| func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Logger) *h2mux.MuxerConfig {
 | |
| 	return &h2mux.MuxerConfig{
 | |
| 		Timeout:            muxerTimeout,
 | |
| 		Handler:            h,
 | |
| 		IsClient:           true,
 | |
| 		HeartbeatInterval:  mc.HeartbeatInterval,
 | |
| 		MaxHeartbeats:      mc.MaxHeartbeats,
 | |
| 		Log:                log,
 | |
| 		CompressionQuality: mc.CompressionSetting,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
 | |
| func NewH2muxConnection(
 | |
| 	orchestrator Orchestrator,
 | |
| 	gracePeriod time.Duration,
 | |
| 	muxerConfig *MuxerConfig,
 | |
| 	edgeConn net.Conn,
 | |
| 	connIndex uint8,
 | |
| 	observer *Observer,
 | |
| 	gracefulShutdownC <-chan struct{},
 | |
| 	log *zerolog.Logger,
 | |
| ) (*h2muxConnection, error, bool) {
 | |
| 	h := &h2muxConnection{
 | |
| 		orchestrator:      orchestrator,
 | |
| 		gracePeriod:       gracePeriod,
 | |
| 		muxerConfig:       muxerConfig,
 | |
| 		connIndexStr:      uint8ToString(connIndex),
 | |
| 		connIndex:         connIndex,
 | |
| 		observer:          observer,
 | |
| 		gracefulShutdownC: gracefulShutdownC,
 | |
| 		newRPCClientFunc:  newRegistrationRPCClient,
 | |
| 		log:               log,
 | |
| 	}
 | |
| 
 | |
| 	// Establish a muxed connection with the edge
 | |
| 	// Client mux handshake with agent server
 | |
| 	muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer.logTransport), h2mux.ActiveStreams)
 | |
| 	if err != nil {
 | |
| 		recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
 | |
| 		return nil, err, recoverable
 | |
| 	}
 | |
| 	h.muxer = muxer
 | |
| 	return h, nil, false
 | |
| }
 | |
| 
 | |
| func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelProperties, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
 | |
| 	errGroup, serveCtx := errgroup.WithContext(ctx)
 | |
| 	errGroup.Go(func() error {
 | |
| 		return h.serveMuxer(serveCtx)
 | |
| 	})
 | |
| 
 | |
| 	errGroup.Go(func() error {
 | |
| 		if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		connectedFuse.Connected()
 | |
| 		return nil
 | |
| 	})
 | |
| 
 | |
| 	errGroup.Go(func() error {
 | |
| 		h.controlLoop(serveCtx, connectedFuse, true)
 | |
| 		return nil
 | |
| 	})
 | |
| 
 | |
| 	err := errGroup.Wait()
 | |
| 	if err == errMuxerStopped {
 | |
| 		if h.stoppedGracefully {
 | |
| 			return nil
 | |
| 		}
 | |
| 		h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
 | |
| 	// All routines should stop when muxer finish serving. When muxer is shutdown
 | |
| 	// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
 | |
| 	// here to notify other routines to stop
 | |
| 	err := h.muxer.Serve(ctx)
 | |
| 	if err == nil {
 | |
| 		return errMuxerStopped
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
 | |
| 	updateMetricsTicker := time.NewTicker(h.muxerConfig.MetricsUpdateFreq)
 | |
| 	defer updateMetricsTicker.Stop()
 | |
| 	var shutdownCompleted <-chan struct{}
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-h.gracefulShutdownC:
 | |
| 			if connectedFuse.IsConnected() {
 | |
| 				h.unregister(isNamedTunnel)
 | |
| 			}
 | |
| 			h.stoppedGracefully = true
 | |
| 			h.gracefulShutdownC = nil
 | |
| 			shutdownCompleted = h.muxer.Shutdown()
 | |
| 
 | |
| 		case <-shutdownCompleted:
 | |
| 			return
 | |
| 
 | |
| 		case <-ctx.Done():
 | |
| 			// UnregisterTunnel blocks until the RPC call returns
 | |
| 			if !h.stoppedGracefully && connectedFuse.IsConnected() {
 | |
| 				h.unregister(isNamedTunnel)
 | |
| 			}
 | |
| 			h.muxer.Shutdown()
 | |
| 			// don't wait for shutdown to finish when context is closed, this is the hard termination path
 | |
| 			return
 | |
| 
 | |
| 		case <-updateMetricsTicker.C:
 | |
| 			h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
 | |
| 	openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
 | |
| 	defer openStreamCancel()
 | |
| 	stream, err := h.muxer.OpenRPCStream(openStreamCtx)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return stream, nil
 | |
| }
 | |
| 
 | |
| func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
 | |
| 	respWriter := &h2muxRespWriter{stream}
 | |
| 
 | |
| 	req, reqErr := h.newRequest(stream)
 | |
| 	if reqErr != nil {
 | |
| 		respWriter.WriteErrorResponse()
 | |
| 		return reqErr
 | |
| 	}
 | |
| 
 | |
| 	var sourceConnectionType = TypeHTTP
 | |
| 	if websocket.IsWebSocketUpgrade(req) {
 | |
| 		sourceConnectionType = TypeWebsocket
 | |
| 	}
 | |
| 
 | |
| 	originProxy, err := h.orchestrator.GetOriginProxy()
 | |
| 	if err != nil {
 | |
| 		respWriter.WriteErrorResponse()
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, h.connIndex, h.log), sourceConnectionType == TypeWebsocket)
 | |
| 	if err != nil {
 | |
| 		respWriter.WriteErrorResponse()
 | |
| 	}
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
 | |
| 	req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
 | |
| 	if err != nil {
 | |
| 		return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
 | |
| 	}
 | |
| 	err = H2RequestHeadersToH1Request(stream.Headers, req)
 | |
| 	if err != nil {
 | |
| 		return nil, errors.Wrap(err, "invalid request received")
 | |
| 	}
 | |
| 	return req, nil
 | |
| }
 | |
| 
 | |
// h2muxRespWriter wraps a muxed stream with the response-writing methods used
// when proxying a request (WriteRespHeaders, AddTrailer, WriteErrorResponse).
// The stream's own methods are available through embedding.
type h2muxRespWriter struct {
	*h2mux.MuxedStream
}
 | |
| 
 | |
// AddTrailer is a no-op: both arguments are ignored because trailers are not
// supported over h2mux.
func (rp *h2muxRespWriter) AddTrailer(trailerName, trailerValue string) {
	// do nothing. we don't support trailers over h2mux
}
 | |
| 
 | |
| func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
 | |
| 	headers := H1ResponseToH2ResponseHeaders(status, header)
 | |
| 	headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})
 | |
| 	return rp.WriteHeaders(headers)
 | |
| }
 | |
| 
 | |
| func (rp *h2muxRespWriter) WriteErrorResponse() {
 | |
| 	_ = rp.WriteHeaders([]h2mux.Header{
 | |
| 		{Name: ":status", Value: "502"},
 | |
| 		{Name: ResponseMetaHeader, Value: responseMetaHeaderCfd},
 | |
| 	})
 | |
| 	_, _ = rp.Write([]byte("502 Bad Gateway"))
 | |
| }
 |