diff --git a/Gopkg.lock b/Gopkg.lock index fce7e984..306106d5 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -186,12 +186,12 @@ version = "v1.1.0" [[projects]] + branch = "master" digest = "1:582b704bebaa06b48c29b0cec224a6058a09c86883aaddabde889cd1a5f73e1b" name = "github.com/google/uuid" packages = ["."] pruneopts = "UT" revision = "0cd6bf5da1e1c83f8b45653022c74f71af0538a4" - version = "v1.1.1" [[projects]] digest = "1:c79fb010be38a59d657c48c6ba1d003a8aa651fa56b579d959d74573b7dff8e1" diff --git a/Gopkg.toml b/Gopkg.toml index b38e6156..0d006193 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -81,3 +81,7 @@ [[constraint]] branch = "master" name = "github.com/cloudflare/golibs" + +[[constraint]] + name = "github.com/google/uuid" + version = "v1.1.1" diff --git a/connection/connection.go b/connection/connection.go new file mode 100644 index 00000000..220e9b28 --- /dev/null +++ b/connection/connection.go @@ -0,0 +1,159 @@ +package connection + +import ( + "context" + "crypto/tls" + "net" + "sync" + "time" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/tunnelrpc" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + + rpc "zombiezen.com/go/capnproto2/rpc" +) + +const ( + dialTimeout = 5 * time.Second +) + +type dialError struct { + cause error +} + +func (e dialError) Error() string { + return e.cause.Error() +} + +type muxerShutdownError struct{} + +func (e muxerShutdownError) Error() string { + return "muxer shutdown" +} + +type ConnectionConfig struct { + TLSConfig *tls.Config + HeartbeatInterval time.Duration + MaxHeartbeats uint64 + Logger *logrus.Entry +} + +type connectionHandler interface { + serve(ctx context.Context) error + connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) + shutdown() +} + +type h2muxHandler struct { + muxer *h2mux.Muxer + logger *logrus.Entry +} + +type muxedStreamHandler struct { +} + +// Implements MuxedStreamHandler interface +func (h *muxedStreamHandler) ServeStream(stream *h2mux.MuxedStream) error { + return nil +} + +func (h *h2muxHandler) serve(ctx context.Context) error { + // Serve doesn't return until h2mux is shutdown + if err := h.muxer.Serve(ctx); err != nil { + return err + } + return muxerShutdownError{} +} + +// Connect is used to establish connections with cloudflare's edge network +func (h *h2muxHandler) connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) { + conn, err := h.newRPConn() + if err != nil { + return nil, errors.Wrap(err, "Failed to create new RPC connection") + } + defer conn.Close() + tsClient := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)} + return tsClient.Connect(ctx, parameters) +} + +func (h *h2muxHandler) shutdown() { + h.muxer.Shutdown() +} + +func (h *h2muxHandler) newRPConn() (*rpc.Conn, error) { + stream, err := h.muxer.OpenStream([]h2mux.Header{ + {Name: ":method", Value: "RPC"}, + {Name: ":scheme", Value: "capnp"}, + {Name: ":path", Value: "*"}, + }, nil) + if err != nil { + return nil, err + } + return rpc.NewConn( + tunnelrpc.NewTransportLogger(h.logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)), + tunnelrpc.ConnLog(h.logger.WithField("subsystem", "rpc-transport")), + ), nil +} + +// NewConnectionHandler returns a connectionHandler, wrapping h2mux to make RPC calls +func newH2MuxHandler(ctx context.Context, + config *ConnectionConfig, + edgeIP *net.TCPAddr, +) (connectionHandler, error) { + // Inherit from parent context so we can cancel (Ctrl-C) while dialing + dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout) + defer dialCancel() + dialer := net.Dialer{DualStack: true} + plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String()) + if err != nil { + return nil, dialError{cause: errors.Wrap(err, "DialContext error")} + } + edgeConn := tls.Client(plaintextEdgeConn, config.TLSConfig) + edgeConn.SetDeadline(time.Now().Add(dialTimeout)) + err = edgeConn.Handshake() + if err != nil { + return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")} + } + // clear the deadline on the conn; h2mux has its own timeouts + edgeConn.SetDeadline(time.Time{}) + // Establish a muxed connection with the edge + // Client mux handshake with agent server + muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ + Timeout: dialTimeout, + Handler: &muxedStreamHandler{}, + IsClient: true, + HeartbeatInterval: config.HeartbeatInterval, + MaxHeartbeats: config.MaxHeartbeats, + Logger: config.Logger, + }) + if err != nil { + return nil, err + } + return &h2muxHandler{ + muxer: muxer, + logger: config.Logger, + }, nil +} + +// connectionPool is a pool of connection handlers +type connectionPool struct { + sync.Mutex + connectionHandlers []connectionHandler +} + +func (cp *connectionPool) put(h connectionHandler) { + cp.Lock() + defer cp.Unlock() + cp.connectionHandlers = append(cp.connectionHandlers, h) +} + +func (cp *connectionPool) close() { + cp.Lock() + defer cp.Unlock() + for _, h := range cp.connectionHandlers { + h.shutdown() + } +} diff --git a/origin/discovery.go b/connection/discovery.go similarity index 99% rename from origin/discovery.go rename to connection/discovery.go index 6f2fbfad..898b0755 100644 --- a/origin/discovery.go +++ b/connection/discovery.go @@ -1,4 +1,4 @@ -package origin +package connection import ( "context" diff --git a/origin/discovery_test.go b/connection/discovery_test.go similarity index 97% rename from origin/discovery_test.go rename to connection/discovery_test.go index fecf9c23..4e5aeacf 100644 --- a/origin/discovery_test.go +++ b/connection/discovery_test.go @@ -1,4 +1,4 @@ -package origin +package connection import ( "net" diff --git a/connection/supervisor.go b/connection/supervisor.go new file mode 100644 index 00000000..6ba0e331 --- /dev/null +++ b/connection/supervisor.go @@ -0,0 +1,145 @@ +package connection + +import ( + "context" + "net" + "time" + + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + // Waiting time before retrying a failed tunnel connection + reconnectDuration = time.Second * 10 + // SRV record resolution TTL + resolveTTL = time.Hour + // Interval between establishing new connection + connectionInterval = time.Second +) + +type CloudflaredConfig struct { + ConnectionConfig *ConnectionConfig + OriginCert []byte + Tags []tunnelpogs.Tag + EdgeAddrs []string + HAConnections uint + Logger *logrus.Logger +} + +// Supervisor is a stateful object that manages connections with the edge +type Supervisor struct { + config *CloudflaredConfig + state *supervisorState + connErrors chan error +} + +type supervisorState struct { + // IPs to connect to cloudflare's edge network + edgeIPs []*net.TCPAddr + // index of the next element to use in edgeIPs + nextEdgeIPIndex int + // last time edgeIPs were refreshed + lastResolveTime time.Time + // ID of this cloudflared instance + cloudflaredID uuid.UUID + // connectionPool is a pool of connectionHandlers that can be used to make RPCs + connectionPool *connectionPool +} + +func (s *supervisorState) getNextEdgeIP() *net.TCPAddr { + ip := s.edgeIPs[s.nextEdgeIPIndex%len(s.edgeIPs)] + s.nextEdgeIPIndex++ + return ip +} + +func NewSupervisor(config *CloudflaredConfig) *Supervisor { + return &Supervisor{ + config: config, + state: &supervisorState{ + connectionPool: &connectionPool{}, + }, + connErrors: make(chan error), + } +} + +func (s *Supervisor) Run(ctx context.Context) error { + logger := s.config.Logger + if err := s.initialize(); err != nil { + logger.WithError(err).Error("Failed to get edge IPs") + return err + } + defer s.state.connectionPool.close() + + var currentConnectionCount uint + expectedConnectionCount := s.config.HAConnections + if uint(len(s.state.edgeIPs)) < s.config.HAConnections { + logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(s.state.edgeIPs)) + expectedConnectionCount = uint(len(s.state.edgeIPs)) + } + for { + select { + case <-ctx.Done(): + return nil + case connErr := <-s.connErrors: + logger.WithError(connErr).Warnf("Connection dropped unexpectedly") + currentConnectionCount-- + default: + time.Sleep(5 * time.Second) + } + if currentConnectionCount < expectedConnectionCount { + h, err := newH2MuxHandler(ctx, s.config.ConnectionConfig, s.state.getNextEdgeIP()) + if err != nil { + logger.WithError(err).Error("Failed to create new connection handler") + continue + } + go func() { + s.connErrors <- h.serve(ctx) + }() + connResult, err := s.connect(ctx, s.config, s.state.cloudflaredID, h) + if err != nil { + logger.WithError(err).Errorf("Failed to connect to cloudflared's edge network") + h.shutdown() + continue + } + if connErr := connResult.Err; connErr != nil && !connErr.ShouldRetry { + logger.WithError(connErr).Errorf("Server respond with don't retry to connect") + h.shutdown() + return err + } + logger.Infof("Connected to %s", connResult.ServerInfo.LocationName) + s.state.connectionPool.put(h) + currentConnectionCount++ + } + } +} + +func (s *Supervisor) initialize() error { + edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs) + if err != nil { + return errors.Wrapf(err, "Failed to resolve cloudflare edge network address") + } + s.state.edgeIPs = edgeIPs + s.state.lastResolveTime = time.Now() + cloudflaredID, err := uuid.NewRandom() + if err != nil { + return errors.Wrap(err, "Failed to generate cloudflared ID") + } + s.state.cloudflaredID = cloudflaredID + return nil +} + +func (s *Supervisor) connect(ctx context.Context, + config *CloudflaredConfig, + cloudflaredID uuid.UUID, + h connectionHandler, +) (*tunnelpogs.ConnectResult, error) { + connectParameters := &tunnelpogs.ConnectParameters{ + OriginCert: config.OriginCert, + CloudflaredID: cloudflaredID, + NumPreviousAttempts: 0, + } + return h.connect(ctx, connectParameters) +} diff --git a/origin/supervisor.go b/origin/supervisor.go index 53042f70..9f0c352d 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -6,6 +6,7 @@ import ( "net" "time" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/signal" "github.com/google/uuid" @@ -124,7 +125,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { logger := s.config.Logger - edgeIPs, err := ResolveEdgeIPs(logger, s.config.EdgeAddrs) + edgeIPs, err := connection.ResolveEdgeIPs(logger, s.config.EdgeAddrs) if err != nil { logger.Infof("ResolveEdgeIPs err") return err @@ -223,7 +224,7 @@ func (s *Supervisor) refreshEdgeIPs() { } s.resolverC = make(chan resolveResult) go func() { - edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs) + edgeIPs, err := connection.ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs) s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err} }() } diff --git a/origin/tunnel.go b/origin/tunnel.go index 6441e6b4..2596d26c 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc" @@ -36,6 +37,7 @@ const ( lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" TagHeaderNamePrefix = "Cf-Warp-Tag-" DuplicateConnectionError = "EDUPCONN" + isDeclarativeTunnel = false ) type TunnelConfig struct { @@ -151,9 +153,24 @@ func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connecte // If a user specified negative HAConnections, we will treat it as requesting 1 connection if config.HAConnections > 1 { + if isDeclarativeTunnel { + return connection.NewSupervisor(&connection.CloudflaredConfig{ + ConnectionConfig: &connection.ConnectionConfig{ + TLSConfig: config.TlsConfig, + HeartbeatInterval: config.HeartbeatInterval, + MaxHeartbeats: config.MaxHeartbeats, + Logger: config.Logger.WithField("subsystem", "connection_supervisor"), + }, + OriginCert: config.OriginCert, + Tags: config.Tags, + EdgeAddrs: config.EdgeAddrs, + HAConnections: uint(config.HAConnections), + Logger: config.Logger, + }).Run(ctx) + } return NewSupervisor(config).Run(ctx, connectedSignal, u) } else { - addrs, err := ResolveEdgeIPs(config.Logger, config.EdgeAddrs) + addrs, err := connection.ResolveEdgeIPs(config.Logger, config.EdgeAddrs) if err != nil { return err }