diff --git a/Makefile b/Makefile index da63e656..402f3af6 100644 --- a/Makefile +++ b/Makefile @@ -78,6 +78,6 @@ tunnelrpc/tunnelrpc.capnp.go: tunnelrpc/tunnelrpc.capnp .PHONY: vet vet: - go vet ./... + go vet -composites=false ./... which go-sumtype # go get github.com/BurntSushi/go-sumtype go-sumtype $$(go list ./...) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 447af863..4550acfc 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -12,8 +12,14 @@ import ( "syscall" "time" - "github.com/getsentry/raven-go" + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/supervisor" "github.com/google/uuid" + + "github.com/getsentry/raven-go" "golang.org/x/crypto/ssh/terminal" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" @@ -239,8 +245,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan } buildInfo := buildinfo.GetBuildInfo(version) - logger.Infof("Build info: %+v", *buildInfo) - logger.Infof("Version %s", version) + buildInfo.Log(logger) logClientOptions(c) if c.IsSet("proxy-dns") { @@ -256,16 +261,6 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan // Wait for proxy-dns to come up (if used) <-dnsReadySignal - // update needs to be after DNS proxy is up to resolve equinox server address - if updater.IsAutoupdateEnabled(c) { - logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) - wg.Add(1) - go func() { - defer wg.Done() - errC <- updater.Autoupdate(c.Duration("autoupdate-freq"), &listeners, shutdownC) - }() - } - metricsListener, err := listeners.Listen("tcp", c.String("metrics")) if err != nil { logger.WithError(err).Error("Error opening metrics server listener") @@ -285,7 +280,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan cloudflaredID, err := uuid.NewRandom() if err != nil { - logger.WithError(err).Error("cannot generate cloudflared ID") + logger.WithError(err).Error("Cannot generate cloudflared ID") return err } @@ -295,6 +290,21 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan cancel() }() + if c.IsSet("use-declarative-tunnels") { + return startDeclarativeTunnel(ctx, c, cloudflaredID, buildInfo, &listeners) + } + + // update needs to be after DNS proxy is up to resolve equinox server address + if updater.IsAutoupdateEnabled(c) { + logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) + wg.Add(1) + go func() { + defer wg.Done() + autoupdater := updater.NewAutoUpdater(c.Duration("autoupdate-freq"), &listeners) + errC <- autoupdater.Run(ctx) + }() + } + // Serve DNS proxy stand-alone if no hostname or tag or app is going to run if dnsProxyStandAlone(c) { connectedSignal.Notify() @@ -303,6 +313,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan } if c.IsSet("hello-world") { + logger.Infof("hello-world set") helloListener, err := hello.CreateTLSListener("127.0.0.1:") if err != nil { logger.WithError(err).Error("Cannot start Hello World Server") @@ -364,6 +375,114 @@ func Before(c *cli.Context) error { return nil } +func startDeclarativeTunnel(ctx context.Context, + c *cli.Context, + cloudflaredID uuid.UUID, + buildInfo *buildinfo.BuildInfo, + listeners *gracenet.Net, +) error { + reverseProxyOrigin, err := defaultOriginConfig(c) + if err != nil { + logger.WithError(err) + return err + } + defaultClientConfig := &pogs.ClientConfig{ + Version: pogs.InitVersion(), + SupervisorConfig: &pogs.SupervisorConfig{ + AutoUpdateFrequency: c.Duration("autoupdate-freq"), + MetricsUpdateFrequency: c.Duration("metrics-update-freq"), + GracePeriod: c.Duration("grace-period"), + }, + EdgeConnectionConfig: &pogs.EdgeConnectionConfig{ + NumHAConnections: uint8(c.Int("ha-connections")), + HeartbeatInterval: c.Duration("heartbeat-interval"), + Timeout: c.Duration("dial-edge-timeout"), + MaxFailedHeartbeats: c.Uint64("heartbeat-count"), + }, + DoHProxyConfigs: []*pogs.DoHProxyConfig{}, + ReverseProxyConfigs: []*pogs.ReverseProxyConfig{ + { + TunnelHostname: h2mux.TunnelHostname(c.String("hostname")), + Origin: reverseProxyOrigin, + }, + }, + } + + autoupdater := updater.NewAutoUpdater(defaultClientConfig.SupervisorConfig.AutoUpdateFrequency, listeners) + + originCert, err := getOriginCert(c) + if err != nil { + logger.WithError(err).Error("error getting origin cert") + return err + } + toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c) + if err != nil { + logger.WithError(err).Error("unable to create TLS config to connect with edge") + return err + } + + tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) + if err != nil { + logger.WithError(err).Error("unable to parse tag") + return err + } + + cloudflaredConfig := &connection.CloudflaredConfig{ + CloudflaredID: cloudflaredID, + Tags: tags, + BuildInfo: buildInfo, + } + + serviceDiscoverer, err := serviceDiscoverer(c, logger) + if err != nil { + logger.WithError(err).Error("unable to create service discoverer") + return err + } + supervisor, err := supervisor.NewSupervisor(defaultClientConfig, originCert, toEdgeTLSConfig, + serviceDiscoverer, cloudflaredConfig, autoupdater, updater.SupportAutoUpdate(), logger) + if err != nil { + logger.WithError(err).Error("unable to create Supervisor") + return err + } + return supervisor.Run(ctx) +} + +func defaultOriginConfig(c *cli.Context) (pogs.OriginConfig, error) { + if c.IsSet("hello-world") { + return &pogs.HelloWorldOriginConfig{}, nil + } + originConfig := &pogs.HTTPOriginConfig{ + TCPKeepAlive: c.Duration("proxy-tcp-keepalive"), + DialDualStack: !c.Bool("proxy-no-happy-eyeballs"), + TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), + TLSVerify: !c.Bool("no-tls-verify"), + OriginCAPool: c.String("origin-ca-pool"), + OriginServerName: c.String("origin-server-name"), + MaxIdleConnections: c.Uint64("proxy-keepalive-connections"), + IdleConnectionTimeout: c.Duration("proxy-keepalive-timeout"), + ProxyConnectTimeout: c.Duration("proxy-connection-timeout"), + ExpectContinueTimeout: c.Duration("proxy-expect-continue-timeout"), + ChunkedEncoding: c.Bool("no-chunked-encoding"), + } + if c.IsSet("unix-socket") { + unixSocket, err := config.ValidateUnixSocket(c) + if err != nil { + return nil, errors.Wrap(err, "error validating --unix-socket") + } + originConfig.URL = &pogs.UnixPath{Path: unixSocket} + } + originAddr, err := config.ValidateUrl(c) + if err != nil { + return nil, errors.Wrap(err, "error validating origin URL") + } + originURL, err := url.Parse(originAddr) + if err != nil { + return nil, errors.Wrapf(err, "%s is not a valid URL", originAddr) + } + originConfig.URL = &pogs.HTTPURL{URL: originURL} + return originConfig, nil +} + func waitToShutdown(wg *sync.WaitGroup, errC chan error, shutdownC, graceShutdownC chan struct{}, @@ -437,8 +556,8 @@ func tunnelFlags(shouldHide bool) []cli.Flag { }, altsrc.NewDurationFlag(&cli.DurationFlag{ Name: "autoupdate-freq", - Usage: "Autoupdate frequency. Default is 24h.", - Value: time.Hour * 24, + Usage: fmt.Sprintf("Autoupdate frequency. Default is %v.", updater.DefaultCheckUpdateFreq), + Value: updater.DefaultCheckUpdateFreq, Hidden: shouldHide, }), altsrc.NewBoolFlag(&cli.BoolFlag{ @@ -652,6 +771,18 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Value: time.Second * 90, Hidden: shouldHide, }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "proxy-connection-timeout", + Usage: "HTTP proxy timeout for closing an idle connection", + Value: time.Second * 90, + Hidden: shouldHide, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "proxy-expect-continue-timeout", + Usage: "HTTP proxy timeout for closing an idle connection", + Value: time.Second * 90, + Hidden: shouldHide, + }), altsrc.NewBoolFlag(&cli.BoolFlag{ Name: "proxy-dns", Usage: "Run a DNS over HTTPS proxy server.", @@ -711,5 +842,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_USE_DECLARATIVE"}, Hidden: true, }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "dial-edge-timeout", + Usage: "Maximum wait time to set up a connection with the edge", + Value: time.Second * 15, + EnvVars: []string{"DIAL_EDGE_TIMEOUT"}, + Hidden: true, + }), } } diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 56f33f43..31c3e43b 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -14,6 +14,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tlsconfig" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -273,6 +274,15 @@ func prepareTunnelConfig( }, nil } +func serviceDiscoverer(c *cli.Context, logger *logrus.Logger) (connection.EdgeServiceDiscoverer, error) { + // If --edge is specfied, resolve edge server addresses + if len(c.StringSlice("edge")) > 0 { + return connection.NewEdgeHostnameResolver(c.StringSlice("edge")) + } + // Otherwise lookup edge server addresses through service discovery + return connection.NewEdgeAddrResolver(logger) +} + func isRunningFromTerminal() bool { return terminal.IsTerminal(int(os.Stdout.Fd())) } diff --git a/cmd/cloudflared/updater/update.go b/cmd/cloudflared/updater/update.go index e34fb0cc..96cdc95b 100644 --- a/cmd/cloudflared/updater/update.go +++ b/cmd/cloudflared/updater/update.go @@ -1,6 +1,7 @@ package updater import ( + "context" "os" "runtime" "time" @@ -14,6 +15,7 @@ import ( ) const ( + DefaultCheckUpdateFreq = time.Hour * 24 appID = "app_idCzgxYerVD" noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/" noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems." @@ -75,30 +77,6 @@ func Update(_ *cli.Context) error { return updateOutcome.Error } -func Autoupdate(freq time.Duration, listeners *gracenet.Net, shutdownC chan struct{}) error { - tickC := time.Tick(freq) - for { - updateOutcome := loggedUpdate() - if updateOutcome.Updated { - os.Args = append(os.Args, "--is-autoupdated=true") - pid, err := listeners.StartProcess() - if err != nil { - logger.WithError(err).Error("Unable to restart server automatically") - return err - } - // stop old process after autoupdate. Otherwise we create a new process - // after each update - logger.Infof("PID of the new process is %d", pid) - return nil - } - select { - case <-tickC: - case <-shutdownC: - return nil - } - } -} - // Checks for an update and applies it if one is available func loggedUpdate() UpdateOutcome { updateOutcome := checkForUpdateAndApply() @@ -112,7 +90,88 @@ func loggedUpdate() UpdateOutcome { return updateOutcome } +// AutoUpdater periodically checks for new version of cloudflared. +type AutoUpdater struct { + configurable *configurable + listeners *gracenet.Net + updateConfigChan chan *configurable +} + +// AutoUpdaterConfigurable is the attributes of AutoUpdater that can be reconfigured during runtime +type configurable struct { + enabled bool + freq time.Duration +} + +func NewAutoUpdater(freq time.Duration, listeners *gracenet.Net) *AutoUpdater { + updaterConfigurable := &configurable{ + enabled: true, + freq: freq, + } + if freq == 0 { + updaterConfigurable.enabled = false + updaterConfigurable.freq = DefaultCheckUpdateFreq + } + return &AutoUpdater{ + configurable: updaterConfigurable, + listeners: listeners, + updateConfigChan: make(chan *configurable), + } +} + +func (a *AutoUpdater) Run(ctx context.Context) error { + ticker := time.NewTicker(a.configurable.freq) + for { + if a.configurable.enabled { + updateOutcome := loggedUpdate() + if updateOutcome.Updated { + os.Args = append(os.Args, "--is-autoupdated=true") + pid, err := a.listeners.StartProcess() + if err != nil { + logger.WithError(err).Error("Unable to restart server automatically") + return err + } + // stop old process after autoupdate. Otherwise we create a new process + // after each update + logger.Infof("PID of the new process is %d", pid) + return nil + } + } + select { + case <-ctx.Done(): + return ctx.Err() + case newConfigurable := <-a.updateConfigChan: + ticker.Stop() + a.configurable = newConfigurable + ticker = time.NewTicker(a.configurable.freq) + // Check if there is new version of cloudflared after receiving new AutoUpdaterConfigurable + case <-ticker.C: + } + } +} + +// Update is the method to pass new AutoUpdaterConfigurable to a running AutoUpdater. It is safe to be called concurrently +func (a *AutoUpdater) Update(newFreq time.Duration) { + newConfigurable := &configurable{ + enabled: true, + freq: newFreq, + } + // A ero duration means autoupdate is disabled + if newFreq == 0 { + newConfigurable.enabled = false + newConfigurable.freq = DefaultCheckUpdateFreq + } + a.updateConfigChan <- newConfigurable +} + func IsAutoupdateEnabled(c *cli.Context) bool { + if !SupportAutoUpdate() { + return false + } + return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 +} + +func SupportAutoUpdate() bool { if runtime.GOOS == "windows" { logger.Info(noUpdateOnWindowsMessage) return false @@ -122,8 +181,7 @@ func IsAutoupdateEnabled(c *cli.Context) bool { logger.Info(noUpdateInShellMessage) return false } - - return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 + return true } func isRunningFromTerminal() bool { diff --git a/cmd/cloudflared/updater/update_test.go b/cmd/cloudflared/updater/update_test.go new file mode 100644 index 00000000..218b22b4 --- /dev/null +++ b/cmd/cloudflared/updater/update_test.go @@ -0,0 +1,26 @@ +package updater + +import ( + "context" + "testing" + + "github.com/facebookgo/grace/gracenet" + "github.com/stretchr/testify/assert" +) + +func TestDisabledAutoUpdater(t *testing.T) { + listeners := &gracenet.Net{} + autoupdater := NewAutoUpdater(0, listeners) + ctx, cancel := context.WithCancel(context.Background()) + errC := make(chan error) + go func() { + errC <- autoupdater.Run(ctx) + }() + + assert.False(t, autoupdater.configurable.enabled) + assert.Equal(t, DefaultCheckUpdateFreq, autoupdater.configurable.freq) + + cancel() + // Make sure that autoupdater terminates after canceling the context + assert.Equal(t, context.Canceled, <-errC) +} diff --git a/originservice/originservice.go b/originservice/originservice.go index 9bbf4e02..5e7fff8a 100644 --- a/originservice/originservice.go +++ b/originservice/originservice.go @@ -23,6 +23,7 @@ import ( type OriginService interface { Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error) OriginAddr() string + Summary() string Shutdown() } @@ -78,6 +79,10 @@ func (hc *HTTPService) OriginAddr() string { return hc.originAddr } +func (hc *HTTPService) Summary() string { + return fmt.Sprintf("HTTP service listening on %s", hc.originAddr) +} + func (hc *HTTPService) Shutdown() {} // WebsocketService talks to origin using WS/WSS @@ -126,6 +131,10 @@ func (wsc *WebsocketService) OriginAddr() string { return wsc.originAddr } +func (wsc *WebsocketService) Summary() string { + return fmt.Sprintf("Websocket listening on %ss", wsc.originAddr) +} + func (wsc *WebsocketService) Shutdown() { close(wsc.shutdownC) } @@ -181,6 +190,10 @@ func (hwc *HelloWorldService) OriginAddr() string { return hwc.originAddr } +func (hwc *HelloWorldService) Summary() string { + return fmt.Sprintf("Hello World service listening on %s", hwc.originAddr) +} + func (hwc *HelloWorldService) Shutdown() { hwc.listener.Close() } diff --git a/streamhandler/stream_handler.go b/streamhandler/stream_handler.go index d350bffb..955ccca9 100644 --- a/streamhandler/stream_handler.go +++ b/streamhandler/stream_handler.go @@ -1,13 +1,16 @@ package streamhandler import ( + "context" "fmt" "net/http" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/tunnelhostnamemapper" + "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/sirupsen/logrus" + "zombiezen.com/go/capnproto2/rpc" ) // StreamHandler handles new stream opened by the edge. The streams can be used to proxy requests or make RPC. @@ -34,14 +37,66 @@ func NewStreamHandler(newConfigChan chan<- *pogs.ClientConfig, } } +// UseConfiguration implements ClientService +func (s *StreamHandler) UseConfiguration(ctx context.Context, config *pogs.ClientConfig) (*pogs.UseConfigurationResult, error) { + select { + case <-ctx.Done(): + err := fmt.Errorf("Timeout while sending new config to Supervisor") + s.logger.Error(err) + return nil, err + case s.newConfigChan <- config: + } + select { + case <-ctx.Done(): + err := fmt.Errorf("Timeout applying new configuration") + s.logger.Error(err) + return nil, err + case result := <-s.useConfigResultChan: + return result, nil + } +} + +// UpdateConfig replaces current originmapper mapping with mappings from newConfig +func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (failedConfigs []*pogs.FailedConfig) { + // TODO: TUN-1968: Gracefully apply new config + s.tunnelHostnameMapper.DeleteAll() + for _, tunnelConfig := range newConfig { + tunnelHostname := tunnelConfig.TunnelHostname + originSerice, err := tunnelConfig.Origin.Service() + if err != nil { + s.logger.WithField("tunnelHostname", tunnelHostname).WithError(err).Error("Invalid origin service config") + failedConfigs = append(failedConfigs, &pogs.FailedConfig{ + Config: tunnelConfig, + Reason: tunnelConfig.FailReason(err), + }) + continue + } + s.tunnelHostnameMapper.Add(tunnelConfig.TunnelHostname, originSerice) + s.logger.WithField("tunnelHostname", tunnelHostname).Infof("New origin service config: %v", originSerice.Summary()) + } + return +} + // ServeStream implements MuxedStreamHandler interface func (s *StreamHandler) ServeStream(stream *h2mux.MuxedStream) error { if stream.IsRPCStream() { - return fmt.Errorf("serveRPC not implemented") + return s.serveRPC(stream) } return s.serveRequest(stream) } +func (s *StreamHandler) serveRPC(stream *h2mux.MuxedStream) error { + stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "200"}}) + main := pogs.ClientService_ServerToClient(s) + rpcLogger := s.logger.WithField("subsystem", "clientserver-rpc") + rpcConn := rpc.NewConn( + tunnelrpc.NewTransportLogger(rpcLogger, rpc.StreamTransport(stream)), + rpc.MainInterface(main.Client), + tunnelrpc.ConnLog(s.logger.WithField("subsystem", "clientserver-rpc-transport")), + ) + return rpcConn.Wait() +} + func (s *StreamHandler) serveRequest(stream *h2mux.MuxedStream) error { tunnelHostname := stream.TunnelHostname() if !tunnelHostname.IsSet() { diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go new file mode 100644 index 00000000..a4fcee22 --- /dev/null +++ b/supervisor/supervisor.go @@ -0,0 +1,179 @@ +package supervisor + +import ( + "context" + "crypto/tls" + "fmt" + "os" + "os/signal" + "sync" + "syscall" + + "golang.org/x/sync/errgroup" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/streamhandler" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/sirupsen/logrus" +) + +type Supervisor struct { + connManager *connection.EdgeManager + streamHandler *streamhandler.StreamHandler + autoupdater *updater.AutoUpdater + supportAutoupdate bool + newConfigChan <-chan *pogs.ClientConfig + useConfigResultChan chan<- *pogs.UseConfigurationResult + state *state + logger *logrus.Entry +} + +func NewSupervisor( + defaultClientConfig *pogs.ClientConfig, + userCredential []byte, + tlsConfig *tls.Config, + serviceDiscoverer connection.EdgeServiceDiscoverer, + cloudflaredConfig *connection.CloudflaredConfig, + autoupdater *updater.AutoUpdater, + supportAutoupdate bool, + logger *logrus.Logger, +) (*Supervisor, error) { + newConfigChan := make(chan *pogs.ClientConfig) + useConfigResultChan := make(chan *pogs.UseConfigurationResult) + streamHandler := streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, logger) + invalidConfigs := streamHandler.UpdateConfig(defaultClientConfig.ReverseProxyConfigs) + + if len(invalidConfigs) > 0 { + for _, invalidConfig := range invalidConfigs { + logger.Errorf("Tunnel %+v is invalid, reason: %s", invalidConfig.Config, invalidConfig.Reason) + } + return nil, fmt.Errorf("At least 1 Tunnel config is invalid") + } + + tunnelHostnames := make([]h2mux.TunnelHostname, len(defaultClientConfig.ReverseProxyConfigs)) + for i, reverseProxyConfig := range defaultClientConfig.ReverseProxyConfigs { + tunnelHostnames[i] = reverseProxyConfig.TunnelHostname + } + defaultEdgeMgrConfigurable := &connection.EdgeManagerConfigurable{ + tunnelHostnames, + defaultClientConfig.EdgeConnectionConfig, + } + return &Supervisor{ + connManager: connection.NewEdgeManager(streamHandler, defaultEdgeMgrConfigurable, userCredential, tlsConfig, + serviceDiscoverer, cloudflaredConfig, logger), + streamHandler: streamHandler, + autoupdater: autoupdater, + supportAutoupdate: supportAutoupdate, + newConfigChan: newConfigChan, + useConfigResultChan: useConfigResultChan, + state: newState(defaultClientConfig), + logger: logger.WithField("subsystem", "supervisor"), + }, nil +} + +func (s *Supervisor) Run(ctx context.Context) error { + errGroup, groupCtx := errgroup.WithContext(ctx) + + errGroup.Go(func() error { + return s.connManager.Run(groupCtx) + }) + + errGroup.Go(func() error { + return s.listenToNewConfig(groupCtx) + }) + + errGroup.Go(func() error { + return s.listenToShutdownSignal(groupCtx) + }) + + if s.supportAutoupdate { + errGroup.Go(func() error { + return s.autoupdater.Run(groupCtx) + }) + } + + err := errGroup.Wait() + s.logger.Warnf("Supervisor terminated, reason: %v", err) + return err +} + +func (s *Supervisor) listenToShutdownSignal(serveCtx context.Context) error { + signals := make(chan os.Signal, 10) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signals) + + select { + case <-serveCtx.Done(): + return serveCtx.Err() + case sig := <-signals: + return fmt.Errorf("received %v signal", sig) + } +} + +func (s *Supervisor) listenToNewConfig(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case newConfig := <-s.newConfigChan: + s.useConfigResultChan <- s.notifySubsystemsNewConfig(newConfig) + } + } +} + +func (s *Supervisor) notifySubsystemsNewConfig(newConfig *pogs.ClientConfig) *pogs.UseConfigurationResult { + s.logger.Infof("Received configuration %v", newConfig.Version) + if s.state.hasAppliedVersion(newConfig.Version) { + s.logger.Infof("%v has been applied", newConfig.Version) + return &pogs.UseConfigurationResult{ + Success: true, + } + } + + s.state.updateConfig(newConfig) + var tunnelHostnames []h2mux.TunnelHostname + for _, tunnelConfig := range newConfig.ReverseProxyConfigs { + tunnelHostnames = append(tunnelHostnames, tunnelConfig.TunnelHostname) + } + // Update connManager configurable + s.connManager.UpdateConfigurable(&connection.EdgeManagerConfigurable{ + tunnelHostnames, + newConfig.EdgeConnectionConfig, + }) + // Update streamHandler tunnelHostnameMapper mapping + failedConfigs := s.streamHandler.UpdateConfig(newConfig.ReverseProxyConfigs) + + if s.supportAutoupdate { + s.autoupdater.Update(newConfig.SupervisorConfig.AutoUpdateFrequency) + } + + return &pogs.UseConfigurationResult{ + Success: len(failedConfigs) == 0, + FailedConfigs: failedConfigs, + } +} + +type state struct { + sync.RWMutex + currentConfig *pogs.ClientConfig +} + +func newState(currentConfig *pogs.ClientConfig) *state { + return &state{ + currentConfig: currentConfig, + } +} + +func (s *state) hasAppliedVersion(incomingVersion pogs.Version) bool { + s.RLock() + defer s.RUnlock() + return s.currentConfig.Version.IsNewerOrEqual(incomingVersion) +} + +func (s *state) updateConfig(newConfig *pogs.ClientConfig) { + s.Lock() + defer s.Unlock() + s.currentConfig = newConfig +}