152 lines
4.9 KiB
Go
152 lines
4.9 KiB
Go
package connection
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/cloudflare/cloudflared/management"
|
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
)
|
|
|
|
// registerClient derives a named tunnel rpc client that can then be used to register and unregister connections.
|
|
type registerClientFunc func(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient
|
|
|
|
type controlStream struct {
|
|
observer *Observer
|
|
|
|
connectedFuse ConnectedFuse
|
|
tunnelProperties *TunnelProperties
|
|
connIndex uint8
|
|
edgeAddress net.IP
|
|
protocol Protocol
|
|
|
|
registerClientFunc registerClientFunc
|
|
registerTimeout time.Duration
|
|
|
|
gracefulShutdownC <-chan struct{}
|
|
gracePeriod time.Duration
|
|
stoppedGracefully bool
|
|
}
|
|
|
|
// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
|
|
type ControlStreamHandler interface {
|
|
// ServeControlStream handles the control plane of the transport in the current goroutine calling this
|
|
ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter) error
|
|
// IsStopped tells whether the method above has finished
|
|
IsStopped() bool
|
|
}
|
|
|
|
// TunnelConfigJSONGetter supplies the current tunnel configuration as JSON.
// ServeControlStream uses it on connection index 0 to push the local
// configuration when the tunnel is not remotely managed.
type TunnelConfigJSONGetter interface {
	// GetConfigJSON returns the serialized configuration, or an error if it
	// cannot be produced.
	GetConfigJSON() ([]byte, error)
}
|
|
|
|
// NewControlStream returns a new instance of ControlStreamHandler
|
|
func NewControlStream(
|
|
observer *Observer,
|
|
connectedFuse ConnectedFuse,
|
|
tunnelProperties *TunnelProperties,
|
|
connIndex uint8,
|
|
edgeAddress net.IP,
|
|
registerClientFunc registerClientFunc,
|
|
registerTimeout time.Duration,
|
|
gracefulShutdownC <-chan struct{},
|
|
gracePeriod time.Duration,
|
|
protocol Protocol,
|
|
) ControlStreamHandler {
|
|
if registerClientFunc == nil {
|
|
registerClientFunc = tunnelrpc.NewRegistrationClient
|
|
}
|
|
return &controlStream{
|
|
observer: observer,
|
|
connectedFuse: connectedFuse,
|
|
tunnelProperties: tunnelProperties,
|
|
registerClientFunc: registerClientFunc,
|
|
registerTimeout: registerTimeout,
|
|
connIndex: connIndex,
|
|
edgeAddress: edgeAddress,
|
|
gracefulShutdownC: gracefulShutdownC,
|
|
gracePeriod: gracePeriod,
|
|
protocol: protocol,
|
|
}
|
|
}
|
|
|
|
func (c *controlStream) ServeControlStream(
|
|
ctx context.Context,
|
|
rw io.ReadWriteCloser,
|
|
connOptions *tunnelpogs.ConnectionOptions,
|
|
tunnelConfigGetter TunnelConfigJSONGetter,
|
|
) error {
|
|
registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout)
|
|
|
|
registrationDetails, err := registrationClient.RegisterConnection(
|
|
ctx,
|
|
c.tunnelProperties.Credentials.Auth(),
|
|
c.tunnelProperties.Credentials.TunnelID,
|
|
connOptions,
|
|
c.connIndex,
|
|
c.edgeAddress)
|
|
if err != nil {
|
|
defer registrationClient.Close()
|
|
if err.Error() == DuplicateConnectionError {
|
|
c.observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
|
|
return errDuplicationConnection
|
|
}
|
|
c.observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
|
|
return serverRegistrationErrorFromRPC(err)
|
|
}
|
|
c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
|
|
|
|
c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol)
|
|
c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location)
|
|
c.connectedFuse.Connected()
|
|
|
|
// if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration
|
|
if c.connIndex == 0 && !registrationDetails.TunnelIsRemotelyManaged {
|
|
if tunnelConfig, err := tunnelConfigGetter.GetConfigJSON(); err == nil {
|
|
if err := registrationClient.SendLocalConfiguration(ctx, tunnelConfig); err != nil {
|
|
c.observer.metrics.localConfigMetrics.pushesErrors.Inc()
|
|
c.observer.log.Err(err).Msg("unable to send local configuration")
|
|
}
|
|
c.observer.metrics.localConfigMetrics.pushes.Inc()
|
|
} else {
|
|
c.observer.log.Err(err).Msg("failed to obtain current configuration")
|
|
}
|
|
}
|
|
|
|
return c.waitForUnregister(ctx, registrationClient)
|
|
}
|
|
|
|
func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) error {
|
|
// wait for connection termination or start of graceful shutdown
|
|
defer registrationClient.Close()
|
|
var shutdownError error
|
|
select {
|
|
case <-ctx.Done():
|
|
shutdownError = ctx.Err()
|
|
break
|
|
case <-c.gracefulShutdownC:
|
|
c.stoppedGracefully = true
|
|
}
|
|
|
|
c.observer.sendUnregisteringEvent(c.connIndex)
|
|
err := registrationClient.GracefulShutdown(ctx, c.gracePeriod)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Error shutting down control stream")
|
|
}
|
|
c.observer.log.Info().
|
|
Int(management.EventTypeKey, int(management.Cloudflared)).
|
|
Uint8(LogFieldConnIndex, c.connIndex).
|
|
IPAddr(LogFieldIPAddress, c.edgeAddress).
|
|
Msg("Unregistered tunnel connection")
|
|
return shutdownError
|
|
}
|
|
|
|
func (c *controlStream) IsStopped() bool {
|
|
return c.stoppedGracefully
|
|
}
|