diff --git a/cmd/cloudflared/app_forward_service.go b/cmd/cloudflared/app_forward_service.go new file mode 100644 index 00000000..4c8446e2 --- /dev/null +++ b/cmd/cloudflared/app_forward_service.go @@ -0,0 +1,48 @@ +package main + +import ( + "github.com/cloudflare/cloudflared/cmd/cloudflared/access" + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" +) + +// ForwardServiceType is used to identify what kind of overwatch service this is +const ForwardServiceType = "forward" + +// ForwarderService is used to wrap the access package websocket forwarders +// into a service model for the overwatch package. +// it also holds a reference to the config object that represents its state +type ForwarderService struct { + forwarder config.Forwarder + shutdown chan struct{} +} + +// NewForwardService creates a new forwarder service +func NewForwardService(f config.Forwarder) *ForwarderService { + return &ForwarderService{forwarder: f, shutdown: make(chan struct{}, 1)} +} + +// Name is used to figure out this service is related to the others (normally the addr it binds to) +// e.g. localhost:78641 or 127.0.0.1:2222 since this is a websocket forwarder +func (s *ForwarderService) Name() string { + return s.forwarder.Listener +} + +// Type is used to identify what kind of overwatch service this is +func (s *ForwarderService) Type() string { + return ForwardServiceType +} + +// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates +func (s *ForwarderService) Hash() string { + return s.forwarder.Hash() +} + +// Shutdown stops the websocket listener +func (s *ForwarderService) Shutdown() { + s.shutdown <- struct{}{} +} + +// Run is the run loop that is started by the overwatch service +func (s *ForwarderService) Run() error { + return access.StartForwarder(s.forwarder, s.shutdown) +} diff --git a/cmd/cloudflared/app_resolver_service.go b/cmd/cloudflared/app_resolver_service.go new file mode 100644 index 00000000..f98ca4c8 --- /dev/null +++ b/cmd/cloudflared/app_resolver_service.go @@ -0,0 +1,73 @@ +package main + +import ( + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/tunneldns" + "github.com/sirupsen/logrus" +) + +// ResolverServiceType is used to identify what kind of overwatch service this is +const ResolverServiceType = "resolver" + +// ResolverService is used to wrap the tunneldns package's DNS over HTTP +// into a service model for the overwatch package. +// it also holds a reference to the config object that represents its state +type ResolverService struct { + resolver config.DNSResolver + shutdown chan struct{} + logger *logrus.Logger +} + +// NewResolverService creates a new resolver service +func NewResolverService(r config.DNSResolver, logger *logrus.Logger) *ResolverService { + return &ResolverService{resolver: r, + shutdown: make(chan struct{}), + logger: logger, + } +} + +// Name is used to figure out this service is related to the others (normally the addr it binds to) +// this is just "resolver" since there can only be one DNS resolver running +func (s *ResolverService) Name() string { + return ResolverServiceType +} + +// Type is used to identify what kind of overwatch service this is +func (s *ResolverService) Type() string { + return ResolverServiceType +} + +// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates +func (s *ResolverService) Hash() string { + return s.resolver.Hash() +} + +// Shutdown stops the tunneldns listener +func (s *ResolverService) Shutdown() { + s.shutdown <- struct{}{} +} + +// Run is the run loop that is started by the overwatch service +func (s *ResolverService) Run() error { + // create a listener + l, err := tunneldns.CreateListener(s.resolver.AddressOrDefault(), s.resolver.PortOrDefault(), + s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault()) + if err != nil { + return err + } + + // start the listener. + readySignal := make(chan struct{}) + err = l.Start(readySignal) + if err != nil { + l.Stop() + return err + } + <-readySignal + s.logger.Infof("start resolver on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault()) + + // wait for shutdown signal + <-s.shutdown + s.logger.Infof("shutdown on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault()) + return l.Stop() +} diff --git a/cmd/cloudflared/app_service.go b/cmd/cloudflared/app_service.go index 42cfa4eb..9407d6e2 100644 --- a/cmd/cloudflared/app_service.go +++ b/cmd/cloudflared/app_service.go @@ -1,36 +1,27 @@ package main import ( - "github.com/cloudflare/cloudflared/cmd/cloudflared/access" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/overwatch" "github.com/sirupsen/logrus" ) -type forwarderState struct { - forwarder config.Forwarder - shutdown chan struct{} -} - -func (s *forwarderState) Shutdown() { - s.shutdown <- struct{}{} -} - // AppService is the main service that runs when no command lines flags are passed to cloudflared // it manages all the running services such as tunnels, forwarders, DNS resolver, etc type AppService struct { configManager config.Manager + serviceManager overwatch.Manager shutdownC chan struct{} - forwarders map[string]forwarderState configUpdateChan chan config.Root logger *logrus.Logger } // NewAppService creates a new AppService with needed supporting services -func NewAppService(configManager config.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService { +func NewAppService(configManager config.Manager, serviceManager overwatch.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService { return &AppService{ configManager: configManager, + serviceManager: serviceManager, shutdownC: shutdownC, - forwarders: make(map[string]forwarderState), configUpdateChan: make(chan config.Root), logger: logger, } @@ -45,6 +36,7 @@ func (s *AppService) Run() error { // Shutdown kills all the running services func (s *AppService) Shutdown() error { s.configManager.Shutdown() + s.shutdownC <- struct{}{} return nil } @@ -62,8 +54,8 @@ func (s *AppService) actionLoop() { case c := <-s.configUpdateChan: s.handleConfigUpdate(c) case <-s.shutdownC: - for _, state := range s.forwarders { - state.Shutdown() + for _, service := range s.serviceManager.Services() { + service.Shutdown() } return } @@ -72,41 +64,27 @@ func (s *AppService) actionLoop() { func (s *AppService) handleConfigUpdate(c config.Root) { // handle the client forward listeners - activeListeners := map[string]struct{}{} + activeServices := map[string]struct{}{} for _, f := range c.Forwarders { - s.handleForwarderUpdate(f) - activeListeners[f.Listener] = struct{}{} + service := NewForwardService(f) + s.serviceManager.Add(service) + activeServices[service.Name()] = struct{}{} } - // remove any listeners that are no longer active - for key, state := range s.forwarders { - if _, ok := activeListeners[key]; !ok { - state.Shutdown() - delete(s.forwarders, key) + // handle resolver changes + if c.Resolver.Enabled { + service := NewResolverService(c.Resolver, s.logger) + s.serviceManager.Add(service) + activeServices[service.Name()] = struct{}{} + + } + + // TODO: TUN-1451 - tunnels + + // remove any services that are no longer active + for _, service := range s.serviceManager.Services() { + if _, ok := activeServices[service.Name()]; !ok { + s.serviceManager.Remove(service.Name()) } } - - // TODO: AUTH-2588, TUN-1451 - tunnels and dns proxy -} - -// handle managing a forwarder service -func (s *AppService) handleForwarderUpdate(f config.Forwarder) { - // check if we need to start a new listener or stop an old one - if state, ok := s.forwarders[f.Listener]; ok { - if state.forwarder.Hash() == f.Hash() { - return // the exact same listener, no changes, so move along - } - state.Shutdown() //shutdown the listener since a new one is starting - } - // add a new forwarder to the list - state := forwarderState{forwarder: f, shutdown: make(chan struct{}, 1)} - s.forwarders[f.Listener] = state - - // start the forwarder - go func(f forwarderState) { - err := access.StartForwarder(f.forwarder, f.shutdown) - if err != nil { - s.logger.WithError(err).Errorf("Forwarder at address: %s", f.forwarder) - } - }(state) } diff --git a/cmd/cloudflared/config/manager.go b/cmd/cloudflared/config/manager.go index c02f4b85..8723b70d 100644 --- a/cmd/cloudflared/config/manager.go +++ b/cmd/cloudflared/config/manager.go @@ -28,14 +28,16 @@ type FileManager struct { notifier Notifier configPath string logger *logrus.Logger + ReadConfig func(string) (Root, error) } // NewFileManager creates a config manager -func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (Manager, error) { +func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (*FileManager, error) { m := &FileManager{ watcher: watcher, configPath: configPath, logger: logger, + ReadConfig: readConfigFromPath, } err := watcher.Add(configPath) return m, err @@ -58,11 +60,20 @@ func (m *FileManager) Start(notifier Notifier) error { // GetConfig reads the yaml file from the disk func (m *FileManager) GetConfig() (Root, error) { - if m.configPath == "" { + return m.ReadConfig(m.configPath) +} + +// Shutdown stops the watcher +func (m *FileManager) Shutdown() { + m.watcher.Shutdown() +} + +func readConfigFromPath(configPath string) (Root, error) { + if configPath == "" { return Root{}, errors.New("unable to find config file") } - file, err := os.Open(m.configPath) + file, err := os.Open(configPath) if err != nil { return Root{}, err } @@ -76,11 +87,6 @@ func (m *FileManager) GetConfig() (Root, error) { return config, nil } -// Shutdown stops the watcher -func (m *FileManager) Shutdown() { - m.watcher.Shutdown() -} - // File change notifications from the watcher // WatcherItemDidChange triggers when the yaml config is updated diff --git a/cmd/cloudflared/config/manager_test.go b/cmd/cloudflared/config/manager_test.go index b9ed82c8..419976aa 100644 --- a/cmd/cloudflared/config/manager_test.go +++ b/cmd/cloudflared/config/manager_test.go @@ -1,15 +1,12 @@ package config import ( - "bufio" "os" "testing" - "time" "github.com/cloudflare/cloudflared/log" "github.com/cloudflare/cloudflared/watcher" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v2" ) type mockNotifier struct { @@ -20,17 +17,27 @@ func (n *mockNotifier) ConfigDidUpdate(c Root) { n.configs = append(n.configs, c) } -func writeConfig(t *testing.T, f *os.File, c *Root) { - f.Sync() - b, err := yaml.Marshal(c) - assert.NoError(t, err) +type mockFileWatcher struct { + path string + notifier watcher.Notification + ready chan struct{} +} - w := bufio.NewWriter(f) - _, err = w.Write(b) - assert.NoError(t, err) +func (w *mockFileWatcher) Start(n watcher.Notification) { + w.notifier = n + w.ready <- struct{}{} +} - err = w.Flush() - assert.NoError(t, err) +func (w *mockFileWatcher) Add(string) error { + return nil +} + +func (w *mockFileWatcher) Shutdown() { + +} + +func (w *mockFileWatcher) TriggerChange() { + w.notifier.WatcherItemDidChange(w.path) } func TestConfigChanged(t *testing.T) { @@ -52,22 +59,24 @@ func TestConfigChanged(t *testing.T) { }, }, } - writeConfig(t, f, c) + configRead := func(configPath string) (Root, error) { + return *c, nil + } + wait := make(chan struct{}) + w := &mockFileWatcher{path: filePath, ready: wait} - w, err := watcher.NewFile() - assert.NoError(t, err) logger := log.CreateLogger() service, err := NewFileManager(w, filePath, logger) + service.ReadConfig = configRead assert.NoError(t, err) n := &mockNotifier{} go service.Start(n) + <-wait c.Forwarders = append(c.Forwarders, Forwarder{URL: "add.daltoniam.com", Listener: "127.0.0.1:8081"}) - writeConfig(t, f, c) + w.TriggerChange() - // give it time to trigger - time.Sleep(10 * time.Millisecond) service.Shutdown() assert.Len(t, n.configs, 2, "did not get 2 config updates as expected") diff --git a/cmd/cloudflared/config/model.go b/cmd/cloudflared/config/model.go index 85312538..95d266e8 100644 --- a/cmd/cloudflared/config/model.go +++ b/cmd/cloudflared/config/model.go @@ -4,6 +4,7 @@ import ( "crypto/md5" "fmt" "io" + "strings" ) // Forwarder represents a client side listener to forward traffic to the edge @@ -19,6 +20,15 @@ type Tunnel struct { ProtocolType string `json:"type"` } +// DNSResolver represents a client side DNS resolver +type DNSResolver struct { + Enabled bool `json:"enabled"` + Address string `json:"address"` + Port uint16 `json:"port"` + Upstreams []string `json:"upstreams"` + Bootstraps []string `json:"bootstraps"` +} + // Root is the base options to configure the service type Root struct { OrgKey string `json:"org_key"` @@ -26,6 +36,7 @@ type Root struct { CheckinInterval int `json:"checkin_interval"` Forwarders []Forwarder `json:"forwarders,omitempty"` Tunnels []Tunnel `json:"tunnels,omitempty"` + Resolver DNSResolver `json:"resolver"` } // Hash returns the computed values to see if the forwarder values change @@ -35,3 +46,51 @@ func (f *Forwarder) Hash() string { io.WriteString(h, f.Listener) return fmt.Sprintf("%x", h.Sum(nil)) } + +// Hash returns the computed values to see if the forwarder values change +func (r *DNSResolver) Hash() string { + h := md5.New() + io.WriteString(h, r.Address) + io.WriteString(h, strings.Join(r.Bootstraps, ",")) + io.WriteString(h, strings.Join(r.Upstreams, ",")) + io.WriteString(h, fmt.Sprintf("%d", r.Port)) + io.WriteString(h, fmt.Sprintf("%v", r.Enabled)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// EnabledOrDefault returns the enabled property +func (r *DNSResolver) EnabledOrDefault() bool { + return r.Enabled +} + +// AddressOrDefault returns the address or returns the default if empty +func (r *DNSResolver) AddressOrDefault() string { + if r.Address != "" { + return r.Address + } + return "localhost" +} + +// PortOrDefault return the port or returns the default if 0 +func (r *DNSResolver) PortOrDefault() uint16 { + if r.Port > 0 { + return r.Port + } + return 53 +} + +// UpstreamsOrDefault returns the upstreams or returns the default if empty +func (r *DNSResolver) UpstreamsOrDefault() []string { + if len(r.Upstreams) > 0 { + return r.Upstreams + } + return []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"} +} + +// BootstrapsOrDefault returns the bootstraps or returns the default if empty +func (r *DNSResolver) BootstrapsOrDefault() []string { + if len(r.Bootstraps) > 0 { + return r.Bootstraps + } + return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"} +} diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index 33a857cb..97687984 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -11,6 +11,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/log" "github.com/cloudflare/cloudflared/metrics" + "github.com/cloudflare/cloudflared/overwatch" "github.com/cloudflare/cloudflared/watcher" raven "github.com/getsentry/raven-go" @@ -180,7 +181,9 @@ func handleServiceMode(shutdownC chan struct{}) error { return err } - appService := NewAppService(configManager, shutdownC, logger) + serviceManager := overwatch.NewAppManager(nil) + + appService := NewAppService(configManager, serviceManager, shutdownC, logger) if err := appService.Run(); err != nil { logger.WithError(err).Error("Failed to start app service") return err diff --git a/overwatch/app_manager.go b/overwatch/app_manager.go new file mode 100644 index 00000000..bd341551 --- /dev/null +++ b/overwatch/app_manager.go @@ -0,0 +1,53 @@ +package overwatch + +// AppManager is the default implementation of overwatch service management +type AppManager struct { + services map[string]Service + errorChan chan error +} + +// NewAppManager creates a new overwatch manager +func NewAppManager(errorChan chan error) Manager { + return &AppManager{services: make(map[string]Service), errorChan: errorChan} +} + +// Add takes in a new service to manage. +// It stops the service if it already exist in the manager and is running +// It then starts the newly added service +func (m *AppManager) Add(service Service) { + // check for existing service + if currentService, ok := m.services[service.Name()]; ok { + if currentService.Hash() == service.Hash() { + return // the exact same service, no changes, so move along + } + currentService.Shutdown() //shutdown the listener since a new one is starting + } + m.services[service.Name()] = service + + //start the service! + go m.serviceRun(service) +} + +// Remove shutdowns the service by name and removes it from its current management list +func (m *AppManager) Remove(name string) { + if currentService, ok := m.services[name]; ok { + currentService.Shutdown() + } + delete(m.services, name) +} + +// Services returns all the current Services being managed +func (m *AppManager) Services() []Service { + values := []Service{} + for _, value := range m.services { + values = append(values, value) + } + return values +} + +func (m *AppManager) serviceRun(service Service) { + err := service.Run() + if err != nil && m.errorChan != nil { + m.errorChan <- err + } +} diff --git a/overwatch/manager.go b/overwatch/manager.go new file mode 100644 index 00000000..58aff6f2 --- /dev/null +++ b/overwatch/manager.go @@ -0,0 +1,17 @@ +package overwatch + +// Service is the required functions for an object to be managed by the overwatch Manager +type Service interface { + Name() string + Type() string + Hash() string + Shutdown() + Run() error +} + +// Manager is based type to manage running services +type Manager interface { + Add(Service) + Remove(string) + Services() []Service +} diff --git a/overwatch/manager_test.go b/overwatch/manager_test.go new file mode 100644 index 00000000..d07e6cf0 --- /dev/null +++ b/overwatch/manager_test.go @@ -0,0 +1,74 @@ +package overwatch + +import ( + "crypto/md5" + "errors" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +type mockService struct { + serviceName string + serviceType string + runError error +} + +func (s *mockService) Name() string { + return s.serviceName +} + +func (s *mockService) Type() string { + return s.serviceType +} + +func (s *mockService) Hash() string { + h := md5.New() + io.WriteString(h, s.serviceName) + io.WriteString(h, s.serviceType) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func (s *mockService) Shutdown() { +} + +func (s *mockService) Run() error { + return s.runError +} + +func TestManagerAddAndRemove(t *testing.T) { + m := NewAppManager(nil) + + first := &mockService{serviceName: "first", serviceType: "mock"} + second := &mockService{serviceName: "second", serviceType: "mock"} + m.Add(first) + m.Add(second) + assert.Len(t, m.Services(), 2, "expected 2 services in the list") + + m.Remove(first.Name()) + services := m.Services() + assert.Len(t, services, 1, "expected 1 service in the list") + assert.Equal(t, second.Hash(), services[0].Hash(), "hashes should match. Wrong service was removed") +} + +func TestManagerDuplicate(t *testing.T) { + m := NewAppManager(nil) + + first := &mockService{serviceName: "first", serviceType: "mock"} + m.Add(first) + m.Add(first) + assert.Len(t, m.Services(), 1, "expected 1 service in the list") +} + +func TestManagerErrorChannel(t *testing.T) { + errChan := make(chan error) + m := NewAppManager(errChan) + + err := errors.New("test error") + first := &mockService{serviceName: "first", serviceType: "mock", runError: err} + m.Add(first) + respErr := <-errChan + assert.Equal(t, err, respErr, "errors don't match") +} diff --git a/tunneldns/metrics.go b/tunneldns/metrics.go index 33f9eeae..bab53472 100644 --- a/tunneldns/metrics.go +++ b/tunneldns/metrics.go @@ -2,6 +2,7 @@ package tunneldns import ( "context" + "sync" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/metrics/vars" @@ -12,6 +13,8 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +var once sync.Once + // MetricsPlugin is an adapter for CoreDNS and built-in metrics type MetricsPlugin struct { Next plugin.Handler @@ -19,13 +22,15 @@ type MetricsPlugin struct { // NewMetricsPlugin creates a plugin with configured metrics func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin { - prometheus.MustRegister(vars.RequestCount) - prometheus.MustRegister(vars.RequestDuration) - prometheus.MustRegister(vars.RequestSize) - prometheus.MustRegister(vars.RequestDo) - prometheus.MustRegister(vars.RequestType) - prometheus.MustRegister(vars.ResponseSize) - prometheus.MustRegister(vars.ResponseRcode) + once.Do(func() { + prometheus.MustRegister(vars.RequestCount) + prometheus.MustRegister(vars.RequestDuration) + prometheus.MustRegister(vars.RequestSize) + prometheus.MustRegister(vars.RequestDo) + prometheus.MustRegister(vars.RequestType) + prometheus.MustRegister(vars.ResponseSize) + prometheus.MustRegister(vars.ResponseRcode) + }) return &MetricsPlugin{Next: next} }