cloudflared-mirror/connection/supervisor.go

146 lines
4.0 KiB
Go

package connection
import (
"context"
"net"
"time"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
// Waiting time before retrying a failed tunnel connection
reconnectDuration = time.Second * 10
// SRV record resolution TTL
resolveTTL = time.Hour
// Interval between establishing new connection
connectionInterval = time.Second
)
type CloudflaredConfig struct {
ConnectionConfig *ConnectionConfig
OriginCert []byte
Tags []tunnelpogs.Tag
EdgeAddrs []string
HAConnections uint
Logger *logrus.Logger
}
// Supervisor is a stateful object that manages connections with the edge
type Supervisor struct {
config *CloudflaredConfig
state *supervisorState
connErrors chan error
}
type supervisorState struct {
// IPs to connect to cloudflare's edge network
edgeIPs []*net.TCPAddr
// index of the next element to use in edgeIPs
nextEdgeIPIndex int
// last time edgeIPs were refreshed
lastResolveTime time.Time
// ID of this cloudflared instance
cloudflaredID uuid.UUID
// connectionPool is a pool of connectionHandlers that can be used to make RPCs
connectionPool *connectionPool
}
func (s *supervisorState) getNextEdgeIP() *net.TCPAddr {
ip := s.edgeIPs[s.nextEdgeIPIndex%len(s.edgeIPs)]
s.nextEdgeIPIndex++
return ip
}
func NewSupervisor(config *CloudflaredConfig) *Supervisor {
return &Supervisor{
config: config,
state: &supervisorState{
connectionPool: &connectionPool{},
},
connErrors: make(chan error),
}
}
func (s *Supervisor) Run(ctx context.Context) error {
logger := s.config.Logger
if err := s.initialize(); err != nil {
logger.WithError(err).Error("Failed to get edge IPs")
return err
}
defer s.state.connectionPool.close()
var currentConnectionCount uint
expectedConnectionCount := s.config.HAConnections
if uint(len(s.state.edgeIPs)) < s.config.HAConnections {
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(s.state.edgeIPs))
expectedConnectionCount = uint(len(s.state.edgeIPs))
}
for {
select {
case <-ctx.Done():
return nil
case connErr := <-s.connErrors:
logger.WithError(connErr).Warnf("Connection dropped unexpectedly")
currentConnectionCount--
default:
time.Sleep(5 * time.Second)
}
if currentConnectionCount < expectedConnectionCount {
h, err := newH2MuxHandler(ctx, s.config.ConnectionConfig, s.state.getNextEdgeIP())
if err != nil {
logger.WithError(err).Error("Failed to create new connection handler")
continue
}
go func() {
s.connErrors <- h.serve(ctx)
}()
connResult, err := s.connect(ctx, s.config, s.state.cloudflaredID, h)
if err != nil {
logger.WithError(err).Errorf("Failed to connect to cloudflared's edge network")
h.shutdown()
continue
}
if connErr := connResult.Err; connErr != nil && !connErr.ShouldRetry {
logger.WithError(connErr).Errorf("Server respond with don't retry to connect")
h.shutdown()
return err
}
logger.Infof("Connected to %s", connResult.ServerInfo.LocationName)
s.state.connectionPool.put(h)
currentConnectionCount++
}
}
}
func (s *Supervisor) initialize() error {
edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs)
if err != nil {
return errors.Wrapf(err, "Failed to resolve cloudflare edge network address")
}
s.state.edgeIPs = edgeIPs
s.state.lastResolveTime = time.Now()
cloudflaredID, err := uuid.NewRandom()
if err != nil {
return errors.Wrap(err, "Failed to generate cloudflared ID")
}
s.state.cloudflaredID = cloudflaredID
return nil
}
func (s *Supervisor) connect(ctx context.Context,
config *CloudflaredConfig,
cloudflaredID uuid.UUID,
h connectionHandler,
) (*tunnelpogs.ConnectResult, error) {
connectParameters := &tunnelpogs.ConnectParameters{
OriginCert: config.OriginCert,
CloudflaredID: cloudflaredID,
NumPreviousAttempts: 0,
}
return h.connect(ctx, connectParameters)
}