AUTH-2588 add DoH to service mode
This commit is contained in:
parent
2c878c47ed
commit
2b7fbbb7b7
|
@ -0,0 +1,48 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
)
|
||||
|
||||
// ForwardServiceType is used to identify what kind of overwatch service this is
|
||||
const ForwardServiceType = "forward"
|
||||
|
||||
// ForwarderService is used to wrap the access package websocket forwarders
|
||||
// into a service model for the overwatch package.
|
||||
// it also holds a reference to the config object that represents its state
|
||||
type ForwarderService struct {
|
||||
forwarder config.Forwarder
|
||||
shutdown chan struct{}
|
||||
}
|
||||
|
||||
// NewForwardService creates a new forwarder service
|
||||
func NewForwardService(f config.Forwarder) *ForwarderService {
|
||||
return &ForwarderService{forwarder: f, shutdown: make(chan struct{}, 1)}
|
||||
}
|
||||
|
||||
// Name is used to figure out this service is related to the others (normally the addr it binds to)
|
||||
// e.g. localhost:78641 or 127.0.0.1:2222 since this is a websocket forwarder
|
||||
func (s *ForwarderService) Name() string {
|
||||
return s.forwarder.Listener
|
||||
}
|
||||
|
||||
// Type is used to identify what kind of overwatch service this is
|
||||
func (s *ForwarderService) Type() string {
|
||||
return ForwardServiceType
|
||||
}
|
||||
|
||||
// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates
|
||||
func (s *ForwarderService) Hash() string {
|
||||
return s.forwarder.Hash()
|
||||
}
|
||||
|
||||
// Shutdown stops the websocket listener
|
||||
func (s *ForwarderService) Shutdown() {
|
||||
s.shutdown <- struct{}{}
|
||||
}
|
||||
|
||||
// Run is the run loop that is started by the overwatch service
|
||||
func (s *ForwarderService) Run() error {
|
||||
return access.StartForwarder(s.forwarder, s.shutdown)
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ResolverServiceType is used to identify what kind of overwatch service this is
|
||||
const ResolverServiceType = "resolver"
|
||||
|
||||
// ResolverService is used to wrap the tunneldns package's DNS over HTTP
|
||||
// into a service model for the overwatch package.
|
||||
// it also holds a reference to the config object that represents its state
|
||||
type ResolverService struct {
|
||||
resolver config.DNSResolver
|
||||
shutdown chan struct{}
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewResolverService creates a new resolver service
|
||||
func NewResolverService(r config.DNSResolver, logger *logrus.Logger) *ResolverService {
|
||||
return &ResolverService{resolver: r,
|
||||
shutdown: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Name is used to figure out this service is related to the others (normally the addr it binds to)
|
||||
// this is just "resolver" since there can only be one DNS resolver running
|
||||
func (s *ResolverService) Name() string {
|
||||
return ResolverServiceType
|
||||
}
|
||||
|
||||
// Type is used to identify what kind of overwatch service this is
|
||||
func (s *ResolverService) Type() string {
|
||||
return ResolverServiceType
|
||||
}
|
||||
|
||||
// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates
|
||||
func (s *ResolverService) Hash() string {
|
||||
return s.resolver.Hash()
|
||||
}
|
||||
|
||||
// Shutdown stops the tunneldns listener
|
||||
func (s *ResolverService) Shutdown() {
|
||||
s.shutdown <- struct{}{}
|
||||
}
|
||||
|
||||
// Run is the run loop that is started by the overwatch service
|
||||
func (s *ResolverService) Run() error {
|
||||
// create a listener
|
||||
l, err := tunneldns.CreateListener(s.resolver.AddressOrDefault(), s.resolver.PortOrDefault(),
|
||||
s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// start the listener.
|
||||
readySignal := make(chan struct{})
|
||||
err = l.Start(readySignal)
|
||||
if err != nil {
|
||||
l.Stop()
|
||||
return err
|
||||
}
|
||||
<-readySignal
|
||||
s.logger.Infof("start resolver on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault())
|
||||
|
||||
// wait for shutdown signal
|
||||
<-s.shutdown
|
||||
s.logger.Infof("shutdown on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault())
|
||||
return l.Stop()
|
||||
}
|
|
@ -1,36 +1,27 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/overwatch"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type forwarderState struct {
|
||||
forwarder config.Forwarder
|
||||
shutdown chan struct{}
|
||||
}
|
||||
|
||||
func (s *forwarderState) Shutdown() {
|
||||
s.shutdown <- struct{}{}
|
||||
}
|
||||
|
||||
// AppService is the main service that runs when no command lines flags are passed to cloudflared
|
||||
// it manages all the running services such as tunnels, forwarders, DNS resolver, etc
|
||||
type AppService struct {
|
||||
configManager config.Manager
|
||||
serviceManager overwatch.Manager
|
||||
shutdownC chan struct{}
|
||||
forwarders map[string]forwarderState
|
||||
configUpdateChan chan config.Root
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewAppService creates a new AppService with needed supporting services
|
||||
func NewAppService(configManager config.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService {
|
||||
func NewAppService(configManager config.Manager, serviceManager overwatch.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService {
|
||||
return &AppService{
|
||||
configManager: configManager,
|
||||
serviceManager: serviceManager,
|
||||
shutdownC: shutdownC,
|
||||
forwarders: make(map[string]forwarderState),
|
||||
configUpdateChan: make(chan config.Root),
|
||||
logger: logger,
|
||||
}
|
||||
|
@ -45,6 +36,7 @@ func (s *AppService) Run() error {
|
|||
// Shutdown kills all the running services
|
||||
func (s *AppService) Shutdown() error {
|
||||
s.configManager.Shutdown()
|
||||
s.shutdownC <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -62,8 +54,8 @@ func (s *AppService) actionLoop() {
|
|||
case c := <-s.configUpdateChan:
|
||||
s.handleConfigUpdate(c)
|
||||
case <-s.shutdownC:
|
||||
for _, state := range s.forwarders {
|
||||
state.Shutdown()
|
||||
for _, service := range s.serviceManager.Services() {
|
||||
service.Shutdown()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -72,41 +64,27 @@ func (s *AppService) actionLoop() {
|
|||
|
||||
func (s *AppService) handleConfigUpdate(c config.Root) {
|
||||
// handle the client forward listeners
|
||||
activeListeners := map[string]struct{}{}
|
||||
activeServices := map[string]struct{}{}
|
||||
for _, f := range c.Forwarders {
|
||||
s.handleForwarderUpdate(f)
|
||||
activeListeners[f.Listener] = struct{}{}
|
||||
service := NewForwardService(f)
|
||||
s.serviceManager.Add(service)
|
||||
activeServices[service.Name()] = struct{}{}
|
||||
}
|
||||
|
||||
// remove any listeners that are no longer active
|
||||
for key, state := range s.forwarders {
|
||||
if _, ok := activeListeners[key]; !ok {
|
||||
state.Shutdown()
|
||||
delete(s.forwarders, key)
|
||||
}
|
||||
// handle resolver changes
|
||||
if c.Resolver.Enabled {
|
||||
service := NewResolverService(c.Resolver, s.logger)
|
||||
s.serviceManager.Add(service)
|
||||
activeServices[service.Name()] = struct{}{}
|
||||
|
||||
}
|
||||
|
||||
// TODO: AUTH-2588, TUN-1451 - tunnels and dns proxy
|
||||
}
|
||||
|
||||
// handle managing a forwarder service
|
||||
func (s *AppService) handleForwarderUpdate(f config.Forwarder) {
|
||||
// check if we need to start a new listener or stop an old one
|
||||
if state, ok := s.forwarders[f.Listener]; ok {
|
||||
if state.forwarder.Hash() == f.Hash() {
|
||||
return // the exact same listener, no changes, so move along
|
||||
}
|
||||
state.Shutdown() //shutdown the listener since a new one is starting
|
||||
}
|
||||
// add a new forwarder to the list
|
||||
state := forwarderState{forwarder: f, shutdown: make(chan struct{}, 1)}
|
||||
s.forwarders[f.Listener] = state
|
||||
|
||||
// start the forwarder
|
||||
go func(f forwarderState) {
|
||||
err := access.StartForwarder(f.forwarder, f.shutdown)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Forwarder at address: %s", f.forwarder)
|
||||
}
|
||||
}(state)
|
||||
// TODO: TUN-1451 - tunnels
|
||||
|
||||
// remove any services that are no longer active
|
||||
for _, service := range s.serviceManager.Services() {
|
||||
if _, ok := activeServices[service.Name()]; !ok {
|
||||
s.serviceManager.Remove(service.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,14 +28,16 @@ type FileManager struct {
|
|||
notifier Notifier
|
||||
configPath string
|
||||
logger *logrus.Logger
|
||||
ReadConfig func(string) (Root, error)
|
||||
}
|
||||
|
||||
// NewFileManager creates a config manager
|
||||
func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (Manager, error) {
|
||||
func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (*FileManager, error) {
|
||||
m := &FileManager{
|
||||
watcher: watcher,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
ReadConfig: readConfigFromPath,
|
||||
}
|
||||
err := watcher.Add(configPath)
|
||||
return m, err
|
||||
|
@ -58,11 +60,20 @@ func (m *FileManager) Start(notifier Notifier) error {
|
|||
|
||||
// GetConfig reads the yaml file from the disk
|
||||
func (m *FileManager) GetConfig() (Root, error) {
|
||||
if m.configPath == "" {
|
||||
return m.ReadConfig(m.configPath)
|
||||
}
|
||||
|
||||
// Shutdown stops the watcher
|
||||
func (m *FileManager) Shutdown() {
|
||||
m.watcher.Shutdown()
|
||||
}
|
||||
|
||||
func readConfigFromPath(configPath string) (Root, error) {
|
||||
if configPath == "" {
|
||||
return Root{}, errors.New("unable to find config file")
|
||||
}
|
||||
|
||||
file, err := os.Open(m.configPath)
|
||||
file, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
return Root{}, err
|
||||
}
|
||||
|
@ -76,11 +87,6 @@ func (m *FileManager) GetConfig() (Root, error) {
|
|||
return config, nil
|
||||
}
|
||||
|
||||
// Shutdown stops the watcher
|
||||
func (m *FileManager) Shutdown() {
|
||||
m.watcher.Shutdown()
|
||||
}
|
||||
|
||||
// File change notifications from the watcher
|
||||
|
||||
// WatcherItemDidChange triggers when the yaml config is updated
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/log"
|
||||
"github.com/cloudflare/cloudflared/watcher"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type mockNotifier struct {
|
||||
|
@ -20,17 +17,27 @@ func (n *mockNotifier) ConfigDidUpdate(c Root) {
|
|||
n.configs = append(n.configs, c)
|
||||
}
|
||||
|
||||
func writeConfig(t *testing.T, f *os.File, c *Root) {
|
||||
f.Sync()
|
||||
b, err := yaml.Marshal(c)
|
||||
assert.NoError(t, err)
|
||||
type mockFileWatcher struct {
|
||||
path string
|
||||
notifier watcher.Notification
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
w := bufio.NewWriter(f)
|
||||
_, err = w.Write(b)
|
||||
assert.NoError(t, err)
|
||||
func (w *mockFileWatcher) Start(n watcher.Notification) {
|
||||
w.notifier = n
|
||||
w.ready <- struct{}{}
|
||||
}
|
||||
|
||||
err = w.Flush()
|
||||
assert.NoError(t, err)
|
||||
func (w *mockFileWatcher) Add(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) Shutdown() {
|
||||
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) TriggerChange() {
|
||||
w.notifier.WatcherItemDidChange(w.path)
|
||||
}
|
||||
|
||||
func TestConfigChanged(t *testing.T) {
|
||||
|
@ -52,22 +59,24 @@ func TestConfigChanged(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
writeConfig(t, f, c)
|
||||
configRead := func(configPath string) (Root, error) {
|
||||
return *c, nil
|
||||
}
|
||||
wait := make(chan struct{})
|
||||
w := &mockFileWatcher{path: filePath, ready: wait}
|
||||
|
||||
w, err := watcher.NewFile()
|
||||
assert.NoError(t, err)
|
||||
logger := log.CreateLogger()
|
||||
service, err := NewFileManager(w, filePath, logger)
|
||||
service.ReadConfig = configRead
|
||||
assert.NoError(t, err)
|
||||
|
||||
n := &mockNotifier{}
|
||||
go service.Start(n)
|
||||
|
||||
<-wait
|
||||
c.Forwarders = append(c.Forwarders, Forwarder{URL: "add.daltoniam.com", Listener: "127.0.0.1:8081"})
|
||||
writeConfig(t, f, c)
|
||||
w.TriggerChange()
|
||||
|
||||
// give it time to trigger
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
service.Shutdown()
|
||||
|
||||
assert.Len(t, n.configs, 2, "did not get 2 config updates as expected")
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/md5"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Forwarder represents a client side listener to forward traffic to the edge
|
||||
|
@ -19,6 +20,15 @@ type Tunnel struct {
|
|||
ProtocolType string `json:"type"`
|
||||
}
|
||||
|
||||
// DNSResolver represents a client side DNS resolver
|
||||
type DNSResolver struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Address string `json:"address"`
|
||||
Port uint16 `json:"port"`
|
||||
Upstreams []string `json:"upstreams"`
|
||||
Bootstraps []string `json:"bootstraps"`
|
||||
}
|
||||
|
||||
// Root is the base options to configure the service
|
||||
type Root struct {
|
||||
OrgKey string `json:"org_key"`
|
||||
|
@ -26,6 +36,7 @@ type Root struct {
|
|||
CheckinInterval int `json:"checkin_interval"`
|
||||
Forwarders []Forwarder `json:"forwarders,omitempty"`
|
||||
Tunnels []Tunnel `json:"tunnels,omitempty"`
|
||||
Resolver DNSResolver `json:"resolver"`
|
||||
}
|
||||
|
||||
// Hash returns the computed values to see if the forwarder values change
|
||||
|
@ -35,3 +46,51 @@ func (f *Forwarder) Hash() string {
|
|||
io.WriteString(h, f.Listener)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// Hash returns the computed values to see if the forwarder values change
|
||||
func (r *DNSResolver) Hash() string {
|
||||
h := md5.New()
|
||||
io.WriteString(h, r.Address)
|
||||
io.WriteString(h, strings.Join(r.Bootstraps, ","))
|
||||
io.WriteString(h, strings.Join(r.Upstreams, ","))
|
||||
io.WriteString(h, fmt.Sprintf("%d", r.Port))
|
||||
io.WriteString(h, fmt.Sprintf("%v", r.Enabled))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// EnabledOrDefault returns the enabled property
|
||||
func (r *DNSResolver) EnabledOrDefault() bool {
|
||||
return r.Enabled
|
||||
}
|
||||
|
||||
// AddressOrDefault returns the address or returns the default if empty
|
||||
func (r *DNSResolver) AddressOrDefault() string {
|
||||
if r.Address != "" {
|
||||
return r.Address
|
||||
}
|
||||
return "localhost"
|
||||
}
|
||||
|
||||
// PortOrDefault return the port or returns the default if 0
|
||||
func (r *DNSResolver) PortOrDefault() uint16 {
|
||||
if r.Port > 0 {
|
||||
return r.Port
|
||||
}
|
||||
return 53
|
||||
}
|
||||
|
||||
// UpstreamsOrDefault returns the upstreams or returns the default if empty
|
||||
func (r *DNSResolver) UpstreamsOrDefault() []string {
|
||||
if len(r.Upstreams) > 0 {
|
||||
return r.Upstreams
|
||||
}
|
||||
return []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"}
|
||||
}
|
||||
|
||||
// BootstrapsOrDefault returns the bootstraps or returns the default if empty
|
||||
func (r *DNSResolver) BootstrapsOrDefault() []string {
|
||||
if len(r.Bootstraps) > 0 {
|
||||
return r.Bootstraps
|
||||
}
|
||||
return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/log"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
"github.com/cloudflare/cloudflared/overwatch"
|
||||
"github.com/cloudflare/cloudflared/watcher"
|
||||
|
||||
raven "github.com/getsentry/raven-go"
|
||||
|
@ -180,7 +181,9 @@ func handleServiceMode(shutdownC chan struct{}) error {
|
|||
return err
|
||||
}
|
||||
|
||||
appService := NewAppService(configManager, shutdownC, logger)
|
||||
serviceManager := overwatch.NewAppManager(nil)
|
||||
|
||||
appService := NewAppService(configManager, serviceManager, shutdownC, logger)
|
||||
if err := appService.Run(); err != nil {
|
||||
logger.WithError(err).Error("Failed to start app service")
|
||||
return err
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
package overwatch
|
||||
|
||||
// AppManager is the default implementation of overwatch service management
|
||||
type AppManager struct {
|
||||
services map[string]Service
|
||||
errorChan chan error
|
||||
}
|
||||
|
||||
// NewAppManager creates a new overwatch manager
|
||||
func NewAppManager(errorChan chan error) Manager {
|
||||
return &AppManager{services: make(map[string]Service), errorChan: errorChan}
|
||||
}
|
||||
|
||||
// Add takes in a new service to manage.
|
||||
// It stops the service if it already exist in the manager and is running
|
||||
// It then starts the newly added service
|
||||
func (m *AppManager) Add(service Service) {
|
||||
// check for existing service
|
||||
if currentService, ok := m.services[service.Name()]; ok {
|
||||
if currentService.Hash() == service.Hash() {
|
||||
return // the exact same service, no changes, so move along
|
||||
}
|
||||
currentService.Shutdown() //shutdown the listener since a new one is starting
|
||||
}
|
||||
m.services[service.Name()] = service
|
||||
|
||||
//start the service!
|
||||
go m.serviceRun(service)
|
||||
}
|
||||
|
||||
// Remove shutdowns the service by name and removes it from its current management list
|
||||
func (m *AppManager) Remove(name string) {
|
||||
if currentService, ok := m.services[name]; ok {
|
||||
currentService.Shutdown()
|
||||
}
|
||||
delete(m.services, name)
|
||||
}
|
||||
|
||||
// Services returns all the current Services being managed
|
||||
func (m *AppManager) Services() []Service {
|
||||
values := []Service{}
|
||||
for _, value := range m.services {
|
||||
values = append(values, value)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func (m *AppManager) serviceRun(service Service) {
|
||||
err := service.Run()
|
||||
if err != nil && m.errorChan != nil {
|
||||
m.errorChan <- err
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
package overwatch
|
||||
|
||||
// Service is the required functions for an object to be managed by the overwatch Manager
|
||||
type Service interface {
|
||||
Name() string
|
||||
Type() string
|
||||
Hash() string
|
||||
Shutdown()
|
||||
Run() error
|
||||
}
|
||||
|
||||
// Manager is based type to manage running services
|
||||
type Manager interface {
|
||||
Add(Service)
|
||||
Remove(string)
|
||||
Services() []Service
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package overwatch
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockService struct {
|
||||
serviceName string
|
||||
serviceType string
|
||||
runError error
|
||||
}
|
||||
|
||||
func (s *mockService) Name() string {
|
||||
return s.serviceName
|
||||
}
|
||||
|
||||
func (s *mockService) Type() string {
|
||||
return s.serviceType
|
||||
}
|
||||
|
||||
func (s *mockService) Hash() string {
|
||||
h := md5.New()
|
||||
io.WriteString(h, s.serviceName)
|
||||
io.WriteString(h, s.serviceType)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
func (s *mockService) Shutdown() {
|
||||
}
|
||||
|
||||
func (s *mockService) Run() error {
|
||||
return s.runError
|
||||
}
|
||||
|
||||
func TestManagerAddAndRemove(t *testing.T) {
|
||||
m := NewAppManager(nil)
|
||||
|
||||
first := &mockService{serviceName: "first", serviceType: "mock"}
|
||||
second := &mockService{serviceName: "second", serviceType: "mock"}
|
||||
m.Add(first)
|
||||
m.Add(second)
|
||||
assert.Len(t, m.Services(), 2, "expected 2 services in the list")
|
||||
|
||||
m.Remove(first.Name())
|
||||
services := m.Services()
|
||||
assert.Len(t, services, 1, "expected 1 service in the list")
|
||||
assert.Equal(t, second.Hash(), services[0].Hash(), "hashes should match. Wrong service was removed")
|
||||
}
|
||||
|
||||
func TestManagerDuplicate(t *testing.T) {
|
||||
m := NewAppManager(nil)
|
||||
|
||||
first := &mockService{serviceName: "first", serviceType: "mock"}
|
||||
m.Add(first)
|
||||
m.Add(first)
|
||||
assert.Len(t, m.Services(), 1, "expected 1 service in the list")
|
||||
}
|
||||
|
||||
func TestManagerErrorChannel(t *testing.T) {
|
||||
errChan := make(chan error)
|
||||
m := NewAppManager(errChan)
|
||||
|
||||
err := errors.New("test error")
|
||||
first := &mockService{serviceName: "first", serviceType: "mock", runError: err}
|
||||
m.Add(first)
|
||||
respErr := <-errChan
|
||||
assert.Equal(t, err, respErr, "errors don't match")
|
||||
}
|
|
@ -2,6 +2,7 @@ package tunneldns
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics/vars"
|
||||
|
@ -12,6 +13,8 @@ import (
|
|||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
var once sync.Once
|
||||
|
||||
// MetricsPlugin is an adapter for CoreDNS and built-in metrics
|
||||
type MetricsPlugin struct {
|
||||
Next plugin.Handler
|
||||
|
@ -19,6 +22,7 @@ type MetricsPlugin struct {
|
|||
|
||||
// NewMetricsPlugin creates a plugin with configured metrics
|
||||
func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
|
||||
once.Do(func() {
|
||||
prometheus.MustRegister(vars.RequestCount)
|
||||
prometheus.MustRegister(vars.RequestDuration)
|
||||
prometheus.MustRegister(vars.RequestSize)
|
||||
|
@ -26,6 +30,7 @@ func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
|
|||
prometheus.MustRegister(vars.RequestType)
|
||||
prometheus.MustRegister(vars.ResponseSize)
|
||||
prometheus.MustRegister(vars.ResponseRcode)
|
||||
})
|
||||
return &MetricsPlugin{Next: next}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue