AUTH-2588 add DoH to service mode

This commit is contained in:
Dalton 2020-05-01 10:30:50 -05:00 committed by Dalton Cherry
parent 2c878c47ed
commit 2b7fbbb7b7
11 changed files with 406 additions and 81 deletions

View File

@ -0,0 +1,48 @@
package main
import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
)
// ForwardServiceType is used to identify what kind of overwatch service this is
const ForwardServiceType = "forward"
// ForwarderService is used to wrap the access package websocket forwarders
// into a service model for the overwatch package.
// it also holds a reference to the config object that represents its state
type ForwarderService struct {
forwarder config.Forwarder
shutdown chan struct{}
}
// NewForwardService creates a new forwarder service
func NewForwardService(f config.Forwarder) *ForwarderService {
return &ForwarderService{forwarder: f, shutdown: make(chan struct{}, 1)}
}
// Name is used to figure out this service is related to the others (normally the addr it binds to)
// e.g. localhost:78641 or 127.0.0.1:2222 since this is a websocket forwarder
func (s *ForwarderService) Name() string {
return s.forwarder.Listener
}
// Type is used to identify what kind of overwatch service this is
func (s *ForwarderService) Type() string {
return ForwardServiceType
}
// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates
func (s *ForwarderService) Hash() string {
return s.forwarder.Hash()
}
// Shutdown stops the websocket listener
func (s *ForwarderService) Shutdown() {
s.shutdown <- struct{}{}
}
// Run is the run loop that is started by the overwatch service
func (s *ForwarderService) Run() error {
return access.StartForwarder(s.forwarder, s.shutdown)
}

View File

@ -0,0 +1,73 @@
package main
import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/sirupsen/logrus"
)
// ResolverServiceType is used to identify what kind of overwatch service this is
const ResolverServiceType = "resolver"
// ResolverService is used to wrap the tunneldns package's DNS over HTTP
// into a service model for the overwatch package.
// it also holds a reference to the config object that represents its state
type ResolverService struct {
resolver config.DNSResolver
shutdown chan struct{}
logger *logrus.Logger
}
// NewResolverService creates a new resolver service
func NewResolverService(r config.DNSResolver, logger *logrus.Logger) *ResolverService {
return &ResolverService{resolver: r,
shutdown: make(chan struct{}),
logger: logger,
}
}
// Name is used to figure out this service is related to the others (normally the addr it binds to)
// this is just "resolver" since there can only be one DNS resolver running
func (s *ResolverService) Name() string {
return ResolverServiceType
}
// Type is used to identify what kind of overwatch service this is
func (s *ResolverService) Type() string {
return ResolverServiceType
}
// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates
func (s *ResolverService) Hash() string {
return s.resolver.Hash()
}
// Shutdown stops the tunneldns listener
func (s *ResolverService) Shutdown() {
s.shutdown <- struct{}{}
}
// Run is the run loop that is started by the overwatch service
func (s *ResolverService) Run() error {
// create a listener
l, err := tunneldns.CreateListener(s.resolver.AddressOrDefault(), s.resolver.PortOrDefault(),
s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault())
if err != nil {
return err
}
// start the listener.
readySignal := make(chan struct{})
err = l.Start(readySignal)
if err != nil {
l.Stop()
return err
}
<-readySignal
s.logger.Infof("start resolver on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault())
// wait for shutdown signal
<-s.shutdown
s.logger.Infof("shutdown on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault())
return l.Stop()
}

View File

@ -1,36 +1,27 @@
package main
import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/overwatch"
"github.com/sirupsen/logrus"
)
type forwarderState struct {
forwarder config.Forwarder
shutdown chan struct{}
}
func (s *forwarderState) Shutdown() {
s.shutdown <- struct{}{}
}
// AppService is the main service that runs when no command lines flags are passed to cloudflared
// it manages all the running services such as tunnels, forwarders, DNS resolver, etc
type AppService struct {
configManager config.Manager
serviceManager overwatch.Manager
shutdownC chan struct{}
forwarders map[string]forwarderState
configUpdateChan chan config.Root
logger *logrus.Logger
}
// NewAppService creates a new AppService with needed supporting services
func NewAppService(configManager config.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService {
func NewAppService(configManager config.Manager, serviceManager overwatch.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService {
return &AppService{
configManager: configManager,
serviceManager: serviceManager,
shutdownC: shutdownC,
forwarders: make(map[string]forwarderState),
configUpdateChan: make(chan config.Root),
logger: logger,
}
@ -45,6 +36,7 @@ func (s *AppService) Run() error {
// Shutdown kills all the running services
func (s *AppService) Shutdown() error {
s.configManager.Shutdown()
s.shutdownC <- struct{}{}
return nil
}
@ -62,8 +54,8 @@ func (s *AppService) actionLoop() {
case c := <-s.configUpdateChan:
s.handleConfigUpdate(c)
case <-s.shutdownC:
for _, state := range s.forwarders {
state.Shutdown()
for _, service := range s.serviceManager.Services() {
service.Shutdown()
}
return
}
@ -72,41 +64,27 @@ func (s *AppService) actionLoop() {
func (s *AppService) handleConfigUpdate(c config.Root) {
// handle the client forward listeners
activeListeners := map[string]struct{}{}
activeServices := map[string]struct{}{}
for _, f := range c.Forwarders {
s.handleForwarderUpdate(f)
activeListeners[f.Listener] = struct{}{}
service := NewForwardService(f)
s.serviceManager.Add(service)
activeServices[service.Name()] = struct{}{}
}
// remove any listeners that are no longer active
for key, state := range s.forwarders {
if _, ok := activeListeners[key]; !ok {
state.Shutdown()
delete(s.forwarders, key)
// handle resolver changes
if c.Resolver.Enabled {
service := NewResolverService(c.Resolver, s.logger)
s.serviceManager.Add(service)
activeServices[service.Name()] = struct{}{}
}
// TODO: TUN-1451 - tunnels
// remove any services that are no longer active
for _, service := range s.serviceManager.Services() {
if _, ok := activeServices[service.Name()]; !ok {
s.serviceManager.Remove(service.Name())
}
}
// TODO: AUTH-2588, TUN-1451 - tunnels and dns proxy
}
// handle managing a forwarder service
func (s *AppService) handleForwarderUpdate(f config.Forwarder) {
// check if we need to start a new listener or stop an old one
if state, ok := s.forwarders[f.Listener]; ok {
if state.forwarder.Hash() == f.Hash() {
return // the exact same listener, no changes, so move along
}
state.Shutdown() //shutdown the listener since a new one is starting
}
// add a new forwarder to the list
state := forwarderState{forwarder: f, shutdown: make(chan struct{}, 1)}
s.forwarders[f.Listener] = state
// start the forwarder
go func(f forwarderState) {
err := access.StartForwarder(f.forwarder, f.shutdown)
if err != nil {
s.logger.WithError(err).Errorf("Forwarder at address: %s", f.forwarder)
}
}(state)
}

View File

@ -28,14 +28,16 @@ type FileManager struct {
notifier Notifier
configPath string
logger *logrus.Logger
ReadConfig func(string) (Root, error)
}
// NewFileManager creates a config manager
func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (Manager, error) {
func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (*FileManager, error) {
m := &FileManager{
watcher: watcher,
configPath: configPath,
logger: logger,
ReadConfig: readConfigFromPath,
}
err := watcher.Add(configPath)
return m, err
@ -58,11 +60,20 @@ func (m *FileManager) Start(notifier Notifier) error {
// GetConfig reads the yaml file from the disk
func (m *FileManager) GetConfig() (Root, error) {
if m.configPath == "" {
return m.ReadConfig(m.configPath)
}
// Shutdown stops the watcher
func (m *FileManager) Shutdown() {
m.watcher.Shutdown()
}
func readConfigFromPath(configPath string) (Root, error) {
if configPath == "" {
return Root{}, errors.New("unable to find config file")
}
file, err := os.Open(m.configPath)
file, err := os.Open(configPath)
if err != nil {
return Root{}, err
}
@ -76,11 +87,6 @@ func (m *FileManager) GetConfig() (Root, error) {
return config, nil
}
// Shutdown stops the watcher
func (m *FileManager) Shutdown() {
m.watcher.Shutdown()
}
// File change notifications from the watcher
// WatcherItemDidChange triggers when the yaml config is updated

View File

@ -1,15 +1,12 @@
package config
import (
"bufio"
"os"
"testing"
"time"
"github.com/cloudflare/cloudflared/log"
"github.com/cloudflare/cloudflared/watcher"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
type mockNotifier struct {
@ -20,17 +17,27 @@ func (n *mockNotifier) ConfigDidUpdate(c Root) {
n.configs = append(n.configs, c)
}
func writeConfig(t *testing.T, f *os.File, c *Root) {
f.Sync()
b, err := yaml.Marshal(c)
assert.NoError(t, err)
type mockFileWatcher struct {
path string
notifier watcher.Notification
ready chan struct{}
}
w := bufio.NewWriter(f)
_, err = w.Write(b)
assert.NoError(t, err)
func (w *mockFileWatcher) Start(n watcher.Notification) {
w.notifier = n
w.ready <- struct{}{}
}
err = w.Flush()
assert.NoError(t, err)
func (w *mockFileWatcher) Add(string) error {
return nil
}
func (w *mockFileWatcher) Shutdown() {
}
func (w *mockFileWatcher) TriggerChange() {
w.notifier.WatcherItemDidChange(w.path)
}
func TestConfigChanged(t *testing.T) {
@ -52,22 +59,24 @@ func TestConfigChanged(t *testing.T) {
},
},
}
writeConfig(t, f, c)
configRead := func(configPath string) (Root, error) {
return *c, nil
}
wait := make(chan struct{})
w := &mockFileWatcher{path: filePath, ready: wait}
w, err := watcher.NewFile()
assert.NoError(t, err)
logger := log.CreateLogger()
service, err := NewFileManager(w, filePath, logger)
service.ReadConfig = configRead
assert.NoError(t, err)
n := &mockNotifier{}
go service.Start(n)
<-wait
c.Forwarders = append(c.Forwarders, Forwarder{URL: "add.daltoniam.com", Listener: "127.0.0.1:8081"})
writeConfig(t, f, c)
w.TriggerChange()
// give it time to trigger
time.Sleep(10 * time.Millisecond)
service.Shutdown()
assert.Len(t, n.configs, 2, "did not get 2 config updates as expected")

View File

@ -4,6 +4,7 @@ import (
"crypto/md5"
"fmt"
"io"
"strings"
)
// Forwarder represents a client side listener to forward traffic to the edge
@ -19,6 +20,15 @@ type Tunnel struct {
ProtocolType string `json:"type"`
}
// DNSResolver represents a client side DNS resolver
type DNSResolver struct {
Enabled bool `json:"enabled"`
Address string `json:"address"`
Port uint16 `json:"port"`
Upstreams []string `json:"upstreams"`
Bootstraps []string `json:"bootstraps"`
}
// Root is the base options to configure the service
type Root struct {
OrgKey string `json:"org_key"`
@ -26,6 +36,7 @@ type Root struct {
CheckinInterval int `json:"checkin_interval"`
Forwarders []Forwarder `json:"forwarders,omitempty"`
Tunnels []Tunnel `json:"tunnels,omitempty"`
Resolver DNSResolver `json:"resolver"`
}
// Hash returns the computed values to see if the forwarder values change
@ -35,3 +46,51 @@ func (f *Forwarder) Hash() string {
io.WriteString(h, f.Listener)
return fmt.Sprintf("%x", h.Sum(nil))
}
// Hash returns the computed values to see if the forwarder values change
func (r *DNSResolver) Hash() string {
h := md5.New()
io.WriteString(h, r.Address)
io.WriteString(h, strings.Join(r.Bootstraps, ","))
io.WriteString(h, strings.Join(r.Upstreams, ","))
io.WriteString(h, fmt.Sprintf("%d", r.Port))
io.WriteString(h, fmt.Sprintf("%v", r.Enabled))
return fmt.Sprintf("%x", h.Sum(nil))
}
// EnabledOrDefault returns the enabled property
func (r *DNSResolver) EnabledOrDefault() bool {
return r.Enabled
}
// AddressOrDefault returns the address or returns the default if empty
func (r *DNSResolver) AddressOrDefault() string {
if r.Address != "" {
return r.Address
}
return "localhost"
}
// PortOrDefault return the port or returns the default if 0
func (r *DNSResolver) PortOrDefault() uint16 {
if r.Port > 0 {
return r.Port
}
return 53
}
// UpstreamsOrDefault returns the upstreams or returns the default if empty
func (r *DNSResolver) UpstreamsOrDefault() []string {
if len(r.Upstreams) > 0 {
return r.Upstreams
}
return []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"}
}
// BootstrapsOrDefault returns the bootstraps or returns the default if empty
func (r *DNSResolver) BootstrapsOrDefault() []string {
if len(r.Bootstraps) > 0 {
return r.Bootstraps
}
return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/log"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/overwatch"
"github.com/cloudflare/cloudflared/watcher"
raven "github.com/getsentry/raven-go"
@ -180,7 +181,9 @@ func handleServiceMode(shutdownC chan struct{}) error {
return err
}
appService := NewAppService(configManager, shutdownC, logger)
serviceManager := overwatch.NewAppManager(nil)
appService := NewAppService(configManager, serviceManager, shutdownC, logger)
if err := appService.Run(); err != nil {
logger.WithError(err).Error("Failed to start app service")
return err

53
overwatch/app_manager.go Normal file
View File

@ -0,0 +1,53 @@
package overwatch
// AppManager is the default implementation of overwatch service management
type AppManager struct {
services map[string]Service
errorChan chan error
}
// NewAppManager creates a new overwatch manager
func NewAppManager(errorChan chan error) Manager {
return &AppManager{services: make(map[string]Service), errorChan: errorChan}
}
// Add takes in a new service to manage.
// It stops the service if it already exist in the manager and is running
// It then starts the newly added service
func (m *AppManager) Add(service Service) {
// check for existing service
if currentService, ok := m.services[service.Name()]; ok {
if currentService.Hash() == service.Hash() {
return // the exact same service, no changes, so move along
}
currentService.Shutdown() //shutdown the listener since a new one is starting
}
m.services[service.Name()] = service
//start the service!
go m.serviceRun(service)
}
// Remove shutdowns the service by name and removes it from its current management list
func (m *AppManager) Remove(name string) {
if currentService, ok := m.services[name]; ok {
currentService.Shutdown()
}
delete(m.services, name)
}
// Services returns all the current Services being managed
func (m *AppManager) Services() []Service {
values := []Service{}
for _, value := range m.services {
values = append(values, value)
}
return values
}
func (m *AppManager) serviceRun(service Service) {
err := service.Run()
if err != nil && m.errorChan != nil {
m.errorChan <- err
}
}

17
overwatch/manager.go Normal file
View File

@ -0,0 +1,17 @@
package overwatch
// Service is the required functions for an object to be managed by the overwatch Manager
type Service interface {
Name() string
Type() string
Hash() string
Shutdown()
Run() error
}
// Manager is based type to manage running services
type Manager interface {
Add(Service)
Remove(string)
Services() []Service
}

74
overwatch/manager_test.go Normal file
View File

@ -0,0 +1,74 @@
package overwatch
import (
"crypto/md5"
"errors"
"fmt"
"io"
"testing"
"github.com/stretchr/testify/assert"
)
type mockService struct {
serviceName string
serviceType string
runError error
}
func (s *mockService) Name() string {
return s.serviceName
}
func (s *mockService) Type() string {
return s.serviceType
}
func (s *mockService) Hash() string {
h := md5.New()
io.WriteString(h, s.serviceName)
io.WriteString(h, s.serviceType)
return fmt.Sprintf("%x", h.Sum(nil))
}
func (s *mockService) Shutdown() {
}
func (s *mockService) Run() error {
return s.runError
}
func TestManagerAddAndRemove(t *testing.T) {
m := NewAppManager(nil)
first := &mockService{serviceName: "first", serviceType: "mock"}
second := &mockService{serviceName: "second", serviceType: "mock"}
m.Add(first)
m.Add(second)
assert.Len(t, m.Services(), 2, "expected 2 services in the list")
m.Remove(first.Name())
services := m.Services()
assert.Len(t, services, 1, "expected 1 service in the list")
assert.Equal(t, second.Hash(), services[0].Hash(), "hashes should match. Wrong service was removed")
}
func TestManagerDuplicate(t *testing.T) {
m := NewAppManager(nil)
first := &mockService{serviceName: "first", serviceType: "mock"}
m.Add(first)
m.Add(first)
assert.Len(t, m.Services(), 1, "expected 1 service in the list")
}
func TestManagerErrorChannel(t *testing.T) {
errChan := make(chan error)
m := NewAppManager(errChan)
err := errors.New("test error")
first := &mockService{serviceName: "first", serviceType: "mock", runError: err}
m.Add(first)
respErr := <-errChan
assert.Equal(t, err, respErr, "errors don't match")
}

View File

@ -2,6 +2,7 @@ package tunneldns
import (
"context"
"sync"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
@ -12,6 +13,8 @@ import (
"github.com/prometheus/client_golang/prometheus"
)
var once sync.Once
// MetricsPlugin is an adapter for CoreDNS and built-in metrics
type MetricsPlugin struct {
Next plugin.Handler
@ -19,13 +22,15 @@ type MetricsPlugin struct {
// NewMetricsPlugin creates a plugin with configured metrics
func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
prometheus.MustRegister(vars.RequestCount)
prometheus.MustRegister(vars.RequestDuration)
prometheus.MustRegister(vars.RequestSize)
prometheus.MustRegister(vars.RequestDo)
prometheus.MustRegister(vars.RequestType)
prometheus.MustRegister(vars.ResponseSize)
prometheus.MustRegister(vars.ResponseRcode)
once.Do(func() {
prometheus.MustRegister(vars.RequestCount)
prometheus.MustRegister(vars.RequestDuration)
prometheus.MustRegister(vars.RequestSize)
prometheus.MustRegister(vars.RequestDo)
prometheus.MustRegister(vars.RequestType)
prometheus.MustRegister(vars.ResponseSize)
prometheus.MustRegister(vars.ResponseRcode)
})
return &MetricsPlugin{Next: next}
}