320 lines
9.9 KiB
Go
320 lines
9.9 KiB
Go
package connection
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
"zombiezen.com/go/capnproto2/rpc"
|
|
|
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
)
|
|
|
|
type tunnelServerClient struct {
|
|
client tunnelpogs.TunnelServer_PogsClient
|
|
transport rpc.Transport
|
|
}
|
|
|
|
// NewTunnelRPCClient creates and returns a new RPC client, which will communicate using a stream on the given muxer.
|
|
// This method is exported for supervisor to call Authenticate RPC
|
|
func NewTunnelServerClient(
|
|
ctx context.Context,
|
|
stream io.ReadWriteCloser,
|
|
log *zerolog.Logger,
|
|
) *tunnelServerClient {
|
|
transport := tunnelrpc.NewTransportLogger(log, rpc.StreamTransport(stream))
|
|
conn := rpc.NewConn(
|
|
transport,
|
|
tunnelrpc.ConnLog(log),
|
|
)
|
|
registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
|
|
return &tunnelServerClient{
|
|
client: tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn},
|
|
transport: transport,
|
|
}
|
|
}
|
|
|
|
func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
|
|
authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return authResp.Outcome(), nil
|
|
}
|
|
|
|
func (tsc *tunnelServerClient) Close() {
|
|
// Closing the client will also close the connection
|
|
_ = tsc.client.Close()
|
|
_ = tsc.transport.Close()
|
|
}
|
|
|
|
type NamedTunnelRPCClient interface {
|
|
RegisterConnection(
|
|
c context.Context,
|
|
config *NamedTunnelConfig,
|
|
options *tunnelpogs.ConnectionOptions,
|
|
connIndex uint8,
|
|
observer *Observer,
|
|
) error
|
|
GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
|
|
Close()
|
|
}
|
|
|
|
type registrationServerClient struct {
|
|
client tunnelpogs.RegistrationServer_PogsClient
|
|
transport rpc.Transport
|
|
}
|
|
|
|
func newRegistrationRPCClient(
|
|
ctx context.Context,
|
|
stream io.ReadWriteCloser,
|
|
log *zerolog.Logger,
|
|
) NamedTunnelRPCClient {
|
|
transport := tunnelrpc.NewTransportLogger(log, rpc.StreamTransport(stream))
|
|
conn := rpc.NewConn(
|
|
transport,
|
|
tunnelrpc.ConnLog(log),
|
|
)
|
|
return ®istrationServerClient{
|
|
client: tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
|
|
transport: transport,
|
|
}
|
|
}
|
|
|
|
func (rsc *registrationServerClient) RegisterConnection(
|
|
ctx context.Context,
|
|
config *NamedTunnelConfig,
|
|
options *tunnelpogs.ConnectionOptions,
|
|
connIndex uint8,
|
|
observer *Observer,
|
|
) error {
|
|
conn, err := rsc.client.RegisterConnection(
|
|
ctx,
|
|
config.Credentials.Auth(),
|
|
config.Credentials.TunnelID,
|
|
connIndex,
|
|
options,
|
|
)
|
|
if err != nil {
|
|
if err.Error() == DuplicateConnectionError {
|
|
observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
|
|
return errDuplicationConnection
|
|
}
|
|
observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
|
|
return serverRegistrationErrorFromRPC(err)
|
|
}
|
|
|
|
observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
|
|
|
|
observer.logServerInfo(connIndex, conn.Location, fmt.Sprintf("Connection %s registered", conn.UUID))
|
|
observer.sendConnectedEvent(connIndex, conn.Location)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
|
|
ctx, cancel := context.WithTimeout(ctx, gracePeriod)
|
|
defer cancel()
|
|
_ = rsc.client.UnregisterConnection(ctx)
|
|
}
|
|
|
|
func (rsc *registrationServerClient) Close() {
|
|
// Closing the client will also close the connection
|
|
_ = rsc.client.Close()
|
|
// Closing the transport also closes the stream
|
|
_ = rsc.transport.Close()
|
|
}
|
|
|
|
// rpcName identifies which registration-related RPC an operation belongs to.
// The string value is emitted verbatim as a metric label (see regFail and
// regSuccess), so it must not carry stray whitespace.
type rpcName string

const (
	register     rpcName = "register"
	reconnect    rpcName = "reconnect"
	unregister   rpcName = "unregister"
	authenticate rpcName = "authenticate" // previously " authenticate" (leading space)
)
|
|
|
|
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
|
|
h.observer.sendRegisteringEvent(registrationOptions.ConnectionID)
|
|
|
|
stream, err := h.newRPCStream(ctx, register)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rpcClient := NewTunnelServerClient(ctx, stream, h.observer.log)
|
|
defer rpcClient.Close()
|
|
|
|
_ = h.logServerInfo(ctx, rpcClient)
|
|
registration := rpcClient.client.RegisterTunnel(
|
|
ctx,
|
|
classicTunnel.OriginCert,
|
|
classicTunnel.Hostname,
|
|
registrationOptions,
|
|
)
|
|
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
|
// RegisterTunnel RPC failure
|
|
return h.processRegisterTunnelError(registrationErr, register)
|
|
}
|
|
|
|
// Send free tunnel URL to UI
|
|
h.observer.sendURL(registration.Url)
|
|
credentialSetter.SetEventDigest(h.connIndex, registration.EventDigest)
|
|
return h.processRegistrationSuccess(registration, register, credentialSetter, classicTunnel)
|
|
}
|
|
|
|
// CredentialManager stores and serves the credentials a classic tunnel needs to
// reconnect: a reconnect token, plus event and connection digests kept per
// connection index.
type CredentialManager interface {
	// ReconnectToken returns the token presented on ReconnectTunnel.
	ReconnectToken() ([]byte, error)

	// EventDigest returns the stored event digest for connection connID.
	EventDigest(connID uint8) ([]byte, error)
	// SetEventDigest records the event digest for connection connID.
	SetEventDigest(connID uint8, digest []byte)

	// ConnDigest returns the stored connection digest for connection connID.
	ConnDigest(connID uint8) ([]byte, error)
	// SetConnDigest records the connection digest for connection connID.
	SetConnDigest(connID uint8, digest []byte)
}
|
|
|
|
func (h *h2muxConnection) processRegistrationSuccess(
|
|
registration *tunnelpogs.TunnelRegistration,
|
|
name rpcName,
|
|
credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig,
|
|
) error {
|
|
for _, logLine := range registration.LogLines {
|
|
h.observer.log.Info().Msg(logLine)
|
|
}
|
|
|
|
if registration.TunnelID != "" {
|
|
h.observer.metrics.tunnelsHA.AddTunnelID(h.connIndex, registration.TunnelID)
|
|
h.observer.log.Info().Msgf("Each HA connection's tunnel IDs: %v", h.observer.metrics.tunnelsHA.String())
|
|
}
|
|
|
|
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
|
|
if classicTunnel.IsTrialZone() {
|
|
err := h.observer.logTrialHostname(registration)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
credentialManager.SetConnDigest(h.connIndex, registration.ConnDigest)
|
|
h.observer.metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
|
|
|
|
h.observer.log.Info().Msgf("Route propagating, it may take up to 1 minute for your new route to become functional")
|
|
h.observer.metrics.regSuccess.WithLabelValues(string(name)).Inc()
|
|
return nil
|
|
}
|
|
|
|
func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, name rpcName) error {
|
|
if err.Error() == DuplicateConnectionError {
|
|
h.observer.metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
|
|
return errDuplicationConnection
|
|
}
|
|
h.observer.metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
|
|
return ServerRegisterTunnelError{
|
|
Cause: err,
|
|
Permanent: err.IsPermanent(),
|
|
}
|
|
}
|
|
|
|
func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
|
|
token, err := credentialManager.ReconnectToken()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
eventDigest, err := credentialManager.EventDigest(h.connIndex)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
connDigest, err := credentialManager.ConnDigest(h.connIndex)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
h.observer.log.Debug().Msg("initiating RPC stream to reconnect")
|
|
stream, err := h.newRPCStream(ctx, register)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rpcClient := NewTunnelServerClient(ctx, stream, h.observer.log)
|
|
defer rpcClient.Close()
|
|
|
|
_ = h.logServerInfo(ctx, rpcClient)
|
|
registration := rpcClient.client.ReconnectTunnel(
|
|
ctx,
|
|
token,
|
|
eventDigest,
|
|
connDigest,
|
|
classicTunnel.Hostname,
|
|
registrationOptions,
|
|
)
|
|
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
|
// ReconnectTunnel RPC failure
|
|
return h.processRegisterTunnelError(registrationErr, reconnect)
|
|
}
|
|
return h.processRegistrationSuccess(registration, reconnect, credentialManager, classicTunnel)
|
|
}
|
|
|
|
func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelServerClient) error {
|
|
// Request server info without blocking tunnel registration; must use capnp library directly.
|
|
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.client.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
|
return nil
|
|
})
|
|
serverInfoMessage, err := serverInfoPromise.Result().Struct()
|
|
if err != nil {
|
|
h.observer.log.Err(err).Msg("Failed to retrieve server information")
|
|
return err
|
|
}
|
|
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
|
|
if err != nil {
|
|
h.observer.log.Err(err).Msg("Failed to retrieve server information")
|
|
return err
|
|
}
|
|
h.observer.logServerInfo(h.connIndex, serverInfo.LocationName, "Connection established")
|
|
return nil
|
|
}
|
|
|
|
func (h *h2muxConnection) registerNamedTunnel(
|
|
ctx context.Context,
|
|
namedTunnel *NamedTunnelConfig,
|
|
connOptions *tunnelpogs.ConnectionOptions,
|
|
) error {
|
|
stream, err := h.newRPCStream(ctx, register)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log)
|
|
defer rpcClient.Close()
|
|
|
|
if err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
|
|
h.observer.sendUnregisteringEvent(h.connIndex)
|
|
|
|
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
|
|
defer cancel()
|
|
|
|
stream, err := h.newRPCStream(unregisterCtx, unregister)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer stream.Close()
|
|
|
|
if isNamedTunnel {
|
|
rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
|
|
defer rpcClient.Close()
|
|
|
|
rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
|
|
} else {
|
|
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log)
|
|
defer rpcClient.Close()
|
|
|
|
// gracePeriod is encoded in int64 using capnproto
|
|
_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
|
|
}
|
|
|
|
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")
|
|
}
|