diff --git a/origin/supervisor.go b/origin/supervisor.go index 6da0e901..36960b19 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "time" + + "github.com/google/uuid" ) const ( @@ -49,9 +51,9 @@ func NewSupervisor(config *TunnelConfig) *Supervisor { } } -func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error { +func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error { logger := s.config.Logger - if err := s.initialize(ctx, connectedSignal); err != nil { + if err := s.initialize(ctx, connectedSignal, u); err != nil { return err } var tunnelsWaiting []int @@ -94,7 +96,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err case <-backoffTimer: backoffTimer = nil for _, index := range tunnelsWaiting { - go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index)) + go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u) } tunnelsActive += len(tunnelsWaiting) tunnelsWaiting = nil @@ -118,7 +120,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err } } -func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { +func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error { logger := s.config.Logger edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) if err != nil { @@ -133,7 +135,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct s.lastResolve = time.Now() // check entitlement and version too old error before attempting to register more tunnels s.nextUnusedEdgeIP = s.config.HAConnections - go s.startFirstTunnel(ctx, connectedSignal) + go s.startFirstTunnel(ctx, connectedSignal, u) select { case <-ctx.Done(): <-s.tunnelErrors @@ -145,7 +147,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct } // At least one successful connection, so start the rest for i := 1; i < s.config.HAConnections; i++ { - go s.startTunnel(ctx, i, make(chan struct{})) + go s.startTunnel(ctx, i, make(chan struct{}), u) time.Sleep(registrationInterval) } return nil @@ -153,8 +155,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct // startTunnel starts the first tunnel connection. The resulting error will be sent on // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed -func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}) { - err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal) +func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) { + err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) defer func() { s.tunnelErrors <- tunnelError{index: 0, err: err} }() @@ -175,14 +177,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan default: return } - err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal) + err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) } } // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. -func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) { - err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal) +func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}, u uuid.UUID) { + err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u) s.tunnelErrors <- tunnelError{index: index, err: err} } diff --git a/origin/tunnel.go b/origin/tunnel.go index 084455ac..57241398 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -132,15 +132,21 @@ func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connecte <-shutdownC cancel() }() + + u, err := uuid.NewRandom() + if err != nil { + return err + } + // If a user specified negative HAConnections, we will treat it as requesting 1 connection if config.HAConnections > 1 { - return NewSupervisor(config).Run(ctx, connectedSignal) + return NewSupervisor(config).Run(ctx, connectedSignal, u) } else { addrs, err := ResolveEdgeIPs(config.EdgeAddrs) if err != nil { return err } - return ServeTunnelLoop(ctx, config, addrs[0], 0, connectedSignal) + return ServeTunnelLoop(ctx, config, addrs[0], 0, connectedSignal, u) } } @@ -149,6 +155,7 @@ func ServeTunnelLoop(ctx context.Context, addr *net.TCPAddr, connectionID uint8, connectedSignal chan struct{}, + u uuid.UUID, ) error { logger := config.Logger config.Metrics.incrementHaConnections() @@ -164,7 +171,7 @@ func ServeTunnelLoop(ctx context.Context, // Ensure the above goroutine will terminate if we return without connecting defer connectedFuse.Fuse(false) for { - err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff) + err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff, u) if recoverable { if duration, ok := backoff.GetBackoffDuration(ctx); ok { logger.Infof("Retrying in %s seconds", duration) @@ -183,6 +190,7 @@ func ServeTunnel( connectionID uint8, connectedFuse *h2mux.BooleanFuse, backoff *BackoffHandler, + u uuid.UUID, ) (err error, recoverable bool) { // Treat panics as recoverable errors defer func() { @@ -222,7 +230,7 @@ func ServeTunnel( errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { - err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP) + err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP, u) if err == nil { connectedFuse.Fuse(true) backoff.SetGracePeriod() @@ -302,6 +310,7 @@ func RegisterTunnel( config *TunnelConfig, connectionID uint8, originLocalIP string, + uuid uuid.UUID, ) error { config.TransportLogger.Debug("initiating RPC stream to register") stream, err := muxer.OpenStream([]h2mux.Header{ @@ -328,10 +337,6 @@ func RegisterTunnel( serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { return nil }) - uuid, err := uuid.NewRandom() - if err != nil { - return clientRegisterTunnelError{cause: err} - } registration, err := ts.RegisterTunnel( ctx, config.OriginCert,