package connection

import (
	"context"
	"io"
	"net"
	"time"

	"github.com/cloudflare/cloudflared/management"
	"github.com/cloudflare/cloudflared/tunnelrpc"
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// registerClientFunc builds a tunnelrpc.RegistrationClient on top of the given
// stream. The resulting client is what registers and unregisters connections.
type registerClientFunc func(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient

// controlStream is the ControlStreamHandler implementation returned by
// NewControlStream.
type controlStream struct {
	observer *Observer

	connectedFuse    ConnectedFuse
	tunnelProperties *TunnelProperties
	// connIndex identifies this connection; only index 0 pushes the local
	// configuration after registering.
	connIndex   uint8
	edgeAddress net.IP
	protocol    Protocol

	// registerClientFunc creates the registration client; registerTimeout is
	// handed to it unchanged.
	registerClientFunc registerClientFunc
	registerTimeout    time.Duration

	// gracefulShutdownC is received from to learn that graceful shutdown has
	// started; gracePeriod is forwarded to the client's GracefulShutdown call.
	gracefulShutdownC <-chan struct{}
	gracePeriod       time.Duration
	// stoppedGracefully is set once gracefulShutdownC fires while the stream
	// is waiting to unregister.
	stoppedGracefully bool
}

// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
type ControlStreamHandler interface {
	// ServeControlStream runs the control plane of the transport on the
	// calling goroutine and returns once it is done.
	ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter) error
	// IsStopped reports whether the control stream has stopped. The
	// controlStream implementation returns true only after graceful shutdown
	// was signalled.
	IsStopped() bool
}

// TunnelConfigJSONGetter supplies the current tunnel configuration as JSON.
type TunnelConfigJSONGetter interface {
	GetConfigJSON() ([]byte, error)
}

// NewControlStream returns a new instance of ControlStreamHandler.
// A nil registerClientFunc falls back to tunnelrpc.NewRegistrationClient.
func NewControlStream(
	observer *Observer,
	connectedFuse ConnectedFuse,
	tunnelProperties *TunnelProperties,
	connIndex uint8,
	edgeAddress net.IP,
	registerClientFunc registerClientFunc,
	registerTimeout time.Duration,
	gracefulShutdownC <-chan struct{},
	gracePeriod time.Duration,
	protocol Protocol,
) ControlStreamHandler {
	clientFactory := registerClientFunc
	if clientFactory == nil {
		clientFactory = tunnelrpc.NewRegistrationClient
	}

	stream := &controlStream{
		observer:           observer,
		connectedFuse:      connectedFuse,
		tunnelProperties:   tunnelProperties,
		connIndex:          connIndex,
		edgeAddress:        edgeAddress,
		protocol:           protocol,
		registerClientFunc: clientFactory,
		registerTimeout:    registerTimeout,
		gracefulShutdownC:  gracefulShutdownC,
		gracePeriod:        gracePeriod,
	}
	return stream
}

// ServeControlStream registers this connection over rw, optionally pushes the
// local configuration, and then blocks until ctx is done or graceful shutdown
// begins. It returns a non-nil error only when registration fails.
func (c *controlStream) ServeControlStream(
	ctx context.Context,
	rw io.ReadWriteCloser,
	connOptions *tunnelpogs.ConnectionOptions,
	tunnelConfigGetter TunnelConfigJSONGetter,
) error {
	client := c.registerClientFunc(ctx, rw, c.registerTimeout)
	props := c.tunnelProperties
	metrics := c.observer.metrics

	details, err := client.RegisterConnection(
		ctx,
		props.Credentials.Auth(),
		props.Credentials.TunnelID,
		connOptions,
		c.connIndex,
		c.edgeAddress,
	)
	if err != nil {
		// Registration failed, so nothing else will use the client.
		defer client.Close()

		switch err.Error() {
		case DuplicateConnectionError:
			metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
			return errDuplicationConnection
		default:
			metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
			return serverRegistrationErrorFromRPC(err)
		}
	}

	metrics.regSuccess.WithLabelValues("registerConnection").Inc()
	c.observer.logConnected(details.UUID, c.connIndex, details.Location, c.edgeAddress, c.protocol)
	c.observer.sendConnectedEvent(c.connIndex, c.protocol, details.Location)
	c.connectedFuse.Connected()

	// Only the first connection of a tunnel that is not remotely managed
	// sends the local ingress rules configuration.
	pushLocalConfig := c.connIndex == 0 && !details.TunnelIsRemotelyManaged
	if pushLocalConfig {
		tunnelConfig, cfgErr := tunnelConfigGetter.GetConfigJSON()
		if cfgErr != nil {
			c.observer.log.Err(cfgErr).Msg("failed to obtain current configuration")
		} else {
			sendErr := client.SendLocalConfiguration(ctx, tunnelConfig)
			if sendErr != nil {
				metrics.localConfigMetrics.pushesErrors.Inc()
				c.observer.log.Err(sendErr).Msg("unable to send local configuration")
			}
			// Counted for every push attempt, whether or not it succeeded.
			metrics.localConfigMetrics.pushes.Inc()
		}
	}

	c.waitForUnregister(ctx, client)
	return nil
}

// waitForUnregister blocks until ctx is done or graceful shutdown is
// signalled, then emits the unregistering event, asks the client to shut down
// gracefully, and logs the unregistration. The client is closed on return.
func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) {
	defer registrationClient.Close()

	// Whichever signal arrives first ends the wait; only the graceful
	// shutdown signal is remembered.
	select {
	case <-c.gracefulShutdownC:
		c.stoppedGracefully = true
	case <-ctx.Done():
	}

	c.observer.sendUnregisteringEvent(c.connIndex)
	registrationClient.GracefulShutdown(ctx, c.gracePeriod)

	event := c.observer.log.Info().
		Int(management.EventTypeKey, int(management.Cloudflared)).
		Uint8(LogFieldConnIndex, c.connIndex).
		IPAddr(LogFieldIPAddress, c.edgeAddress)
	event.Msg("Unregistered tunnel connection")
}

// IsStopped reports whether graceful shutdown was signalled while this control
// stream was waiting to unregister.
func (c *controlStream) IsStopped() bool {
	return c.stoppedGracefully
}