250 lines
7.0 KiB
Go
250 lines
7.0 KiB
Go
package connection
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/cloudflare/cloudflared/h2mux"
|
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
"github.com/cloudflare/cloudflared/websocket"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/rs/zerolog"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
const (
|
|
muxerTimeout = 5 * time.Second
|
|
openStreamTimeout = 30 * time.Second
|
|
)
|
|
|
|
type h2muxConnection struct {
|
|
config *Config
|
|
muxerConfig *MuxerConfig
|
|
muxer *h2mux.Muxer
|
|
// connectionID is only used by metrics, and prometheus requires labels to be string
|
|
connIndexStr string
|
|
connIndex uint8
|
|
|
|
observer *Observer
|
|
gracefulShutdownC chan struct{}
|
|
stoppedGracefully bool
|
|
|
|
// newRPCClientFunc allows us to mock RPCs during testing
|
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
|
}
|
|
|
|
type MuxerConfig struct {
|
|
HeartbeatInterval time.Duration
|
|
MaxHeartbeats uint64
|
|
CompressionSetting h2mux.CompressionSetting
|
|
MetricsUpdateFreq time.Duration
|
|
}
|
|
|
|
func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Logger) *h2mux.MuxerConfig {
|
|
return &h2mux.MuxerConfig{
|
|
Timeout: muxerTimeout,
|
|
Handler: h,
|
|
IsClient: true,
|
|
HeartbeatInterval: mc.HeartbeatInterval,
|
|
MaxHeartbeats: mc.MaxHeartbeats,
|
|
Log: log,
|
|
CompressionQuality: mc.CompressionSetting,
|
|
}
|
|
}
|
|
|
|
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
|
func NewH2muxConnection(
|
|
config *Config,
|
|
muxerConfig *MuxerConfig,
|
|
edgeConn net.Conn,
|
|
connIndex uint8,
|
|
observer *Observer,
|
|
gracefulShutdownC chan struct{},
|
|
) (*h2muxConnection, error, bool) {
|
|
h := &h2muxConnection{
|
|
config: config,
|
|
muxerConfig: muxerConfig,
|
|
connIndexStr: uint8ToString(connIndex),
|
|
connIndex: connIndex,
|
|
observer: observer,
|
|
gracefulShutdownC: gracefulShutdownC,
|
|
newRPCClientFunc: newRegistrationRPCClient,
|
|
}
|
|
|
|
// Establish a muxed connection with the edge
|
|
// Client mux handshake with agent server
|
|
muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer.log), h2mux.ActiveStreams)
|
|
if err != nil {
|
|
recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
|
|
return nil, err, recoverable
|
|
}
|
|
h.muxer = muxer
|
|
return h, nil, false
|
|
}
|
|
|
|
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
|
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
|
errGroup.Go(func() error {
|
|
return h.serveMuxer(serveCtx)
|
|
})
|
|
|
|
errGroup.Go(func() error {
|
|
if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil {
|
|
return err
|
|
}
|
|
connectedFuse.Connected()
|
|
return nil
|
|
})
|
|
|
|
errGroup.Go(func() error {
|
|
h.controlLoop(serveCtx, connectedFuse, true)
|
|
return nil
|
|
})
|
|
|
|
err := errGroup.Wait()
|
|
if err == errMuxerStopped {
|
|
if h.stoppedGracefully {
|
|
return nil
|
|
}
|
|
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
|
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
|
errGroup.Go(func() error {
|
|
return h.serveMuxer(serveCtx)
|
|
})
|
|
|
|
errGroup.Go(func() (err error) {
|
|
defer func() {
|
|
if err == nil {
|
|
connectedFuse.Connected()
|
|
}
|
|
}()
|
|
if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
|
|
err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
// log errors and proceed to RegisterTunnel
|
|
h.observer.log.Err(err).
|
|
Uint8(LogFieldConnIndex, h.connIndex).
|
|
Msg("Couldn't reconnect connection. Re-registering it instead.")
|
|
}
|
|
return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
|
|
})
|
|
|
|
errGroup.Go(func() error {
|
|
h.controlLoop(serveCtx, connectedFuse, false)
|
|
return nil
|
|
})
|
|
|
|
err := errGroup.Wait()
|
|
if err == errMuxerStopped {
|
|
if h.stoppedGracefully {
|
|
return nil
|
|
}
|
|
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
|
|
// All routines should stop when muxer finish serving. When muxer is shutdown
|
|
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
|
// here to notify other routines to stop
|
|
err := h.muxer.Serve(ctx)
|
|
if err == nil {
|
|
return errMuxerStopped
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
|
|
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
|
|
for {
|
|
select {
|
|
case <-h.gracefulShutdownC:
|
|
if connectedFuse.IsConnected() {
|
|
h.unregister(isNamedTunnel)
|
|
}
|
|
h.stoppedGracefully = true
|
|
h.gracefulShutdownC = nil
|
|
|
|
case <-ctx.Done():
|
|
// UnregisterTunnel blocks until the RPC call returns
|
|
if !h.stoppedGracefully && connectedFuse.IsConnected() {
|
|
h.unregister(isNamedTunnel)
|
|
}
|
|
h.muxer.Shutdown()
|
|
return
|
|
|
|
case <-updateMetricsTickC:
|
|
h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
|
|
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
|
|
defer openStreamCancel()
|
|
stream, err := h.muxer.OpenRPCStream(openStreamCtx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return stream, nil
|
|
}
|
|
|
|
func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
|
|
respWriter := &h2muxRespWriter{stream}
|
|
|
|
req, reqErr := h.newRequest(stream)
|
|
if reqErr != nil {
|
|
respWriter.WriteErrorResponse()
|
|
return reqErr
|
|
}
|
|
|
|
err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
|
|
if err != nil {
|
|
respWriter.WriteErrorResponse()
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
|
req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
|
}
|
|
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "invalid request received")
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
type h2muxRespWriter struct {
|
|
*h2mux.MuxedStream
|
|
}
|
|
|
|
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
|
|
headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
|
|
headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin})
|
|
return rp.WriteHeaders(headers)
|
|
}
|
|
|
|
func (rp *h2muxRespWriter) WriteErrorResponse() {
|
|
_ = rp.WriteHeaders([]h2mux.Header{
|
|
{Name: ":status", Value: "502"},
|
|
{Name: ResponseMetaHeaderField, Value: responseMetaHeaderCfd},
|
|
})
|
|
_, _ = rp.Write([]byte("502 Bad Gateway"))
|
|
}
|