package tunnelrpc

import (
	"context"
	"io"
	"net"
	"time"

	"github.com/google/uuid"
	"zombiezen.com/go/capnproto2/rpc"

	"github.com/cloudflare/cloudflared/tunnelrpc/metrics"
	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// RegistrationClient is the client side of the registration RPC. It registers a
// connection, sends the local configuration, unregisters the connection on
// shutdown, and releases the underlying RPC resources.
type RegistrationClient interface {
	// RegisterConnection registers connection connIndex of tunnelID and returns
	// the connection details returned by the server.
	RegisterConnection(
		ctx context.Context,
		auth pogs.TunnelAuth,
		tunnelID uuid.UUID,
		options *pogs.ConnectionOptions,
		connIndex uint8,
		edgeAddress net.IP,
	) (*pogs.ConnectionDetails, error)
	// SendLocalConfiguration sends the serialized local configuration to the server.
	SendLocalConfiguration(ctx context.Context, config []byte) error
	// GracefulShutdown asks the server to unregister the connection, waiting at
	// most gracePeriod. It reports no error to the caller.
	GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
	// Close releases the RPC client and its underlying transport.
	Close()
}

// Compile-time check that registrationClient implements RegistrationClient.
var _ RegistrationClient = (*registrationClient)(nil)

// registrationClient implements RegistrationClient on top of a capnp RPC
// connection built over a single stream.
type registrationClient struct {
	client         pogs.RegistrationServer_PogsClient
	transport      rpc.Transport // kept so Close can shut down the stream
	requestTimeout time.Duration // upper bound for RegisterConnection and SendLocalConfiguration
}

// NewRegistrationClient builds a RegistrationClient that speaks the
// registration RPC over stream.
//
// The stream is wrapped with SafeTransport and a client connection is created
// on it. ctx is used for bootstrapping the remote capability.
// requestTimeout bounds each RegisterConnection and SendLocalConfiguration
// call. The returned client owns stream; call Close to release it.
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser, requestTimeout time.Duration) RegistrationClient {
	transport := SafeTransport(stream)
	conn := NewClientConn(transport)
	client := pogs.NewRegistrationServer_PogsClient(conn.Bootstrap(ctx), conn)
	return &registrationClient{
		client:         client,
		transport:      transport,
		requestTimeout: requestTimeout,
	}
}

// RegisterConnection registers connection connIndex of tunnelID, bounding the
// call by the client's requestTimeout.
//
// It always counts the operation and observes its latency. On error it also
// increments the failure counter. The server's error is returned unwrapped so
// callers can inspect it directly. edgeAddress is not forwarded to the server
// call.
func (r *registrationClient) RegisterConnection(
	ctx context.Context,
	auth pogs.TunnelAuth,
	tunnelID uuid.UUID,
	options *pogs.ConnectionOptions,
	connIndex uint8,
	edgeAddress net.IP,
) (*pogs.ConnectionDetails, error) {
	ctx, cancel := context.WithTimeout(ctx, r.requestTimeout)
	defer cancel()
	defer metrics.CapnpMetrics.ClientOperations.WithLabelValues(metrics.Registration, metrics.OperationRegisterConnection).Inc()
	timer := metrics.NewClientOperationLatencyObserver(metrics.Registration, metrics.OperationRegisterConnection)
	defer timer.ObserveDuration()

	conn, err := r.client.RegisterConnection(ctx, auth, tunnelID, connIndex, options)
	if err != nil {
		metrics.CapnpMetrics.ClientFailures.WithLabelValues(metrics.Registration, metrics.OperationRegisterConnection).Inc()
	}
	return conn, err
}

// SendLocalConfiguration sends config to the server, bounding the call by the
// client's requestTimeout.
//
// The operation is counted and timed under OperationUpdateLocalConfiguration.
// On error the failure counter is incremented and the error is returned
// unchanged.
func (r *registrationClient) SendLocalConfiguration(ctx context.Context, config []byte) error {
	ctx, cancel := context.WithTimeout(ctx, r.requestTimeout)
	defer cancel()
	defer metrics.CapnpMetrics.ClientOperations.WithLabelValues(metrics.Registration, metrics.OperationUpdateLocalConfiguration).Inc()
	timer := metrics.NewClientOperationLatencyObserver(metrics.Registration, metrics.OperationUpdateLocalConfiguration)
	defer timer.ObserveDuration()

	err := r.client.SendLocalConfiguration(ctx, config)
	if err != nil {
		metrics.CapnpMetrics.ClientFailures.WithLabelValues(metrics.Registration, metrics.OperationUpdateLocalConfiguration).Inc()
	}
	return err
}

// GracefulShutdown asks the server to unregister the connection.
//
// The call is bounded by gracePeriod, not by requestTimeout. Because the method
// has no error result, a failed unregister is visible only through the
// ClientFailures metric.
func (r *registrationClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
	ctx, cancel := context.WithTimeout(ctx, gracePeriod)
	defer cancel()
	defer metrics.CapnpMetrics.ClientOperations.WithLabelValues(metrics.Registration, metrics.OperationUnregisterConnection).Inc()
	timer := metrics.NewClientOperationLatencyObserver(metrics.Registration, metrics.OperationUnregisterConnection)
	defer timer.ObserveDuration()

	if err := r.client.UnregisterConnection(ctx); err != nil {
		metrics.CapnpMetrics.ClientFailures.WithLabelValues(metrics.Registration, metrics.OperationUnregisterConnection).Inc()
	}
}

// Close releases the RPC client and then the transport. Close errors are
// deliberately discarded because there is no error result to report them
// through.
func (r *registrationClient) Close() {
	// Closing the client will also close the connection
	_ = r.client.Close()
	// Closing the transport also closes the stream
	_ = r.transport.Close()
}