// cloudflared-mirror/connection/h2mux.go
//
// 221 lines
// 6.3 KiB
// Go

package connection
import (
"context"
"net"
"net/http"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
)
const (
// muxerTimeout is the Timeout value placed in the h2mux.MuxerConfig built by H2MuxerConfig.
muxerTimeout = 5 * time.Second
// openStreamTimeout is the deadline applied to the context used by newRPCStream
// when opening an RPC stream on the muxer.
openStreamTimeout = 30 * time.Second
)
// h2muxConnection is a muxed connection with the edge, backed by an h2mux.Muxer.
type h2muxConnection struct {
config *Config
muxerConfig *MuxerConfig
muxer *h2mux.Muxer
// connIndexStr is the string form of connIndex. It is only used by metrics,
// and prometheus requires labels to be strings.
connIndexStr string
connIndex uint8
observer *Observer
}
// MuxerConfig holds the settings that H2MuxerConfig copies into an h2mux.MuxerConfig,
// plus the interval at which controlLoop refreshes the muxer metrics.
type MuxerConfig struct {
HeartbeatInterval time.Duration
MaxHeartbeats uint64
CompressionSetting h2mux.CompressionSetting
// MetricsUpdateFreq is the tick interval controlLoop uses to update muxer metrics.
MetricsUpdateFreq time.Duration
}
// H2MuxerConfig builds a client-side h2mux.MuxerConfig from mc, using h as the
// stream handler, logger as the muxer's logger, and muxerTimeout as its timeout.
func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, logger logger.Service) *h2mux.MuxerConfig {
	cfg := h2mux.MuxerConfig{
		IsClient: true,
		Timeout:  muxerTimeout,
		Handler:  h,
		Logger:   logger,
	}
	// Remaining settings are taken verbatim from the receiver.
	cfg.HeartbeatInterval = mc.HeartbeatInterval
	cfg.MaxHeartbeats = mc.MaxHeartbeats
	cfg.CompressionQuality = mc.CompressionSetting
	return &cfg
}
// NewH2muxConnection performs the client-side h2mux handshake over edgeConn and
// returns the resulting connection. When the handshake fails it returns a nil
// connection, the handshake error, and whether isHandshakeErrRecoverable
// considers that error recoverable. On success the error is nil and the bool
// is false.
func NewH2muxConnection(ctx context.Context,
	config *Config,
	muxerConfig *MuxerConfig,
	edgeConn net.Conn,
	connIndex uint8,
	observer *Observer,
) (*h2muxConnection, error, bool) {
	conn := &h2muxConnection{
		config:       config,
		muxerConfig:  muxerConfig,
		connIndexStr: uint8ToString(connIndex),
		connIndex:    connIndex,
		observer:     observer,
	}

	// The connection itself is the stream handler, and the observer is the logger.
	h2Config := muxerConfig.H2MuxerConfig(conn, observer)

	// Establish a muxed connection with the edge: edgeConn is used for both
	// directions of the client mux handshake.
	muxer, err := h2mux.Handshake(edgeConn, edgeConn, *h2Config, h2mux.ActiveStreams)
	if err != nil {
		return nil, err, isHandshakeErrRecoverable(err, connIndex, observer)
	}

	conn.muxer = muxer
	return conn, nil, false
}
// ServeNamedTunnel runs three routines in one errgroup: serving the muxer,
// registering this connection for namedTunnel over a new RPC stream, and the
// control loop. It returns the result of waiting on that group.
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
	group, serveCtx := errgroup.WithContext(ctx)

	group.Go(func() error {
		return h.serveMuxer(serveCtx)
	})

	group.Go(func() error {
		rpcStream, err := h.newRPCStream(serveCtx, register)
		if err != nil {
			return err
		}

		// The RPC client is created with the caller's ctx, while the stream and
		// the registration call use the group's derived context.
		client := newRegistrationRPCClient(ctx, rpcStream, h.observer)
		defer client.Close()

		err = client.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer)
		if err == nil {
			// Mark the connection as connected only after a successful registration.
			connectedFuse.Connected()
		}
		return err
	})

	group.Go(func() error {
		h.controlLoop(serveCtx, connectedFuse, true)
		return nil
	})

	return group.Wait()
}
// ServeClassicTunnel runs three routines in one errgroup: serving the muxer,
// reconnecting or registering the classic tunnel, and the control loop. It
// returns the result of waiting on that group.
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
	group, serveCtx := errgroup.WithContext(ctx)

	group.Go(func() error {
		return h.serveMuxer(serveCtx)
	})

	group.Go(func() (err error) {
		// Whichever path below produces the result, a nil result marks the
		// connection as connected.
		defer func() {
			if err == nil {
				connectedFuse.Connected()
			}
		}()

		// A reconnect is attempted only when reconnect tokens are enabled and the
		// fuse already reports a previous successful connection.
		tryReconnect := classicTunnel.UseReconnectToken && connectedFuse.IsConnected()
		if tryReconnect {
			reconnectErr := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
			if reconnectErr == nil {
				return nil
			}
			// Log the failure and fall through to a full registration.
			h.observer.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", h.connIndex, reconnectErr)
		}

		return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
	})

	group.Go(func() error {
		h.controlLoop(serveCtx, connectedFuse, false)
		return nil
	})

	return group.Wait()
}
// serveMuxer serves the muxer until it stops and always returns a non-nil error.
// A muxer that shuts down gracefully returns no error of its own, so
// muxerShutdownError{} is returned in that case to notify the other routines
// that they should stop as well.
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
	if err := h.muxer.Serve(ctx); err != nil {
		return err
	}
	return muxerShutdownError{}
}
// controlLoop periodically pushes the muxer's metrics to the observer until ctx
// is done. On cancellation it unregisters the connection if connectedFuse
// reports it as connected, shuts the muxer down, and returns.
//
// isNamedTunnel is forwarded to unregister.
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
	// Use an explicit ticker that is stopped on return instead of time.Tick,
	// whose ticker can never be stopped and therefore leaks each time this
	// loop exits. A non-positive interval leaves the channel nil, which blocks
	// forever in the select below; that is what time.Tick did, whereas
	// time.NewTicker would panic.
	var updateMetricsTickC <-chan time.Time
	if freq := h.muxerConfig.MetricsUpdateFreq; freq > 0 {
		ticker := time.NewTicker(freq)
		defer ticker.Stop()
		updateMetricsTickC = ticker.C
	}

	for {
		select {
		case <-ctx.Done():
			// UnregisterTunnel blocks until the RPC call returns
			if connectedFuse.IsConnected() {
				h.unregister(isNamedTunnel)
			}
			h.muxer.Shutdown()
			return
		case <-updateMetricsTickC:
			h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
		}
	}
}
// newRPCStream opens an RPC stream on the muxer, giving the open attempt at most
// openStreamTimeout on top of ctx. It returns a nil stream whenever opening
// fails. The rpcName argument is not used by this implementation.
func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
	timeoutCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
	defer cancel()

	rpcStream, openErr := h.muxer.OpenRPCStream(timeoutCtx)
	if openErr == nil {
		return rpcStream, nil
	}
	return nil, openErr
}
// ServeStream converts the muxed stream into an HTTP request and hands it to the
// configured origin client to proxy. If building the request or proxying fails,
// an error response is written to the stream and the error is returned.
func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
	writer := &h2muxRespWriter{stream}

	req, err := h.newRequest(stream)
	if err != nil {
		writer.WriteErrorResponse()
		return err
	}

	if err := h.config.OriginClient.Proxy(writer, req, websocket.IsWebSocketUpgrade(req)); err != nil {
		writer.WriteErrorResponse()
		return err
	}
	return nil
}
// newRequest builds an *http.Request whose body reads from stream, then applies
// the stream's HTTP/2 headers to it via h2mux.H2RequestHeadersToH1Request.
// The request starts out as a GET to a fixed placeholder URL before the stream
// headers are applied.
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
	body := h2mux.MuxedStreamReader{MuxedStream: stream}

	req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", body)
	if err != nil {
		return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
	}

	if err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req); err != nil {
		return nil, errors.Wrap(err, "invalid request received")
	}
	return req, nil
}
// h2muxRespWriter embeds a MuxedStream and adds the response helpers
// WriteRespHeaders and WriteErrorResponse on top of it.
type h2muxRespWriter struct {
*h2mux.MuxedStream
}
// WriteRespHeaders converts resp's headers to HTTP/2 form, appends the response
// meta header with the origin value, and writes the result to the stream.
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
	originMeta := h2mux.Header{Name: responseMetaHeaderField, Value: responseMetaHeaderOrigin}
	return rp.WriteHeaders(append(h2mux.H1ResponseToH2ResponseHeaders(resp), originMeta))
}
// WriteErrorResponse writes a 502 status, tagged with the responseMetaHeaderCfd
// meta header value, followed by a short plain-text body. The results of both
// writes are discarded.
func (rp *h2muxRespWriter) WriteErrorResponse() {
	errHeaders := []h2mux.Header{
		{Name: ":status", Value: "502"},
		{Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},
	}
	rp.WriteHeaders(errHeaders)

	errBody := []byte("502 Bad Gateway")
	rp.Write(errBody)
}