This commit is contained in:
Aleksei Sviridkin 2026-02-24 12:03:14 -05:00 committed by GitHub
commit 5cb25841c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1065 additions and 10 deletions

View File

@ -372,7 +372,15 @@ func StartServer(
ctx, cancel := context.WithCancel(c.Context)
defer cancel()
go waitForSignal(graceShutdownC, log)
// reloadC is used to trigger configuration reloads via SIGHUP.
// Channel is created here but waitForSignal is started later, after localWatcher
// is ready to consume from reloadC (to avoid race condition).
var reloadC chan struct{}
configPath := c.String("config")
if configPath != "" && c.String(TunnelTokenFlag) == "" {
// Only enable hot reload for locally configured tunnels (not token-based)
reloadC = make(chan struct{}, 1)
}
connectedSignal := signal.New(make(chan struct{}))
go notifySystemd(connectedSignal)
@ -445,6 +453,20 @@ func StartServer(
return err
}
// Start local config watcher for hot reload if enabled
if reloadC != nil {
localWatcher := orchestration.NewLocalConfigWatcher(orchestrator, configPath, log)
readyC := localWatcher.Run(ctx, reloadC)
<-readyC // Wait until watcher is ready to receive signals
} else if configPath == "" {
log.Debug().Msg("Configuration hot reload disabled: no config file specified")
} else {
log.Debug().Msg("Configuration hot reload disabled: token-based tunnel")
}
// Start signal handler after localWatcher is ready to avoid race condition
go waitForSignal(graceShutdownC, reloadC, log)
metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics"))
if err != nil {
log.Err(err).Msg("Error opening metrics server listener")

View File

@ -8,16 +8,36 @@ import (
"github.com/rs/zerolog"
)
// waitForSignal closes graceShutdownC to indicate that we should start graceful shutdown sequence
func waitForSignal(graceShutdownC chan struct{}, logger *zerolog.Logger) {
// waitForSignal handles OS signals for graceful shutdown and configuration reload.
// It closes graceShutdownC on SIGTERM/SIGINT to trigger graceful shutdown.
// If reloadC is provided, SIGHUP will send a reload signal instead of being ignored.
func waitForSignal(graceShutdownC chan struct{}, reloadC chan<- struct{}, logger *zerolog.Logger) {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
defer signal.Stop(signals)
select {
case s := <-signals:
logger.Info().Msgf("Initiating graceful shutdown due to signal %s ...", s)
close(graceShutdownC)
case <-graceShutdownC:
for {
select {
case s := <-signals:
switch s {
case syscall.SIGHUP:
if reloadC != nil {
logger.Info().Msg("Received SIGHUP, triggering configuration reload")
select {
case reloadC <- struct{}{}:
default:
logger.Warn().Msg("Configuration reload already in progress, skipping")
}
} else {
logger.Info().Msg("Received SIGHUP but hot reload is not enabled for this tunnel")
}
case syscall.SIGTERM, syscall.SIGINT:
logger.Info().Msgf("Initiating graceful shutdown due to signal %s ...", s)
close(graceShutdownC)
return
}
case <-graceShutdownC:
return
}
}
}

View File

@ -52,11 +52,123 @@ func TestSignalShutdown(t *testing.T) {
}
})
waitForSignal(graceShutdownC, &log)
waitForSignal(graceShutdownC, nil, &log)
assert.True(t, channelClosed(graceShutdownC))
}
}
// TestSignalSIGHUP_WithReloadChannel verifies that SIGHUP sends a signal on
// reloadC and that a subsequent SIGTERM makes waitForSignal return.
func TestSignalSIGHUP_WithReloadChannel(t *testing.T) {
	log := zerolog.Nop()
	graceShutdownC := make(chan struct{})
	reloadC := make(chan struct{}, 1)
	go func() {
		// sleep for a tick to prevent sending signal before calling waitForSignal
		time.Sleep(tick)
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGHUP)
		// Give time for signal to be processed
		time.Sleep(tick)
		// Send SIGTERM to exit waitForSignal
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
	}()
	// t.Fatal/FailNow must only run in the test goroutine; use t.Error here
	// and unblock the test by closing graceShutdownC.
	timeout := time.AfterFunc(time.Second, func() {
		select {
		case <-graceShutdownC:
		default:
			close(graceShutdownC)
			t.Error("waitForSignal timed out")
		}
	})
	defer timeout.Stop()
	waitForSignal(graceShutdownC, reloadC, &log)
	// Check that reload signal was received
	select {
	case <-reloadC:
		// Expected - SIGHUP should trigger reload
	default:
		t.Fatal("Expected reload channel to receive signal from SIGHUP")
	}
}
// TestSignalSIGHUP_WithoutReloadChannel verifies that SIGHUP is ignored when
// no reload channel is provided, and SIGTERM still shuts down cleanly.
func TestSignalSIGHUP_WithoutReloadChannel(t *testing.T) {
	log := zerolog.Nop()
	graceShutdownC := make(chan struct{})
	go func() {
		// sleep for a tick to prevent sending signal before calling waitForSignal
		time.Sleep(tick)
		// Send SIGHUP without reload channel - should be ignored
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGHUP)
		time.Sleep(tick)
		// Send SIGTERM to exit waitForSignal
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
	}()
	// t.Fatal/FailNow must only run in the test goroutine; use t.Error here
	// and unblock the test by closing graceShutdownC.
	timeout := time.AfterFunc(time.Second, func() {
		select {
		case <-graceShutdownC:
		default:
			close(graceShutdownC)
			t.Error("waitForSignal timed out")
		}
	})
	defer timeout.Stop()
	// Should complete without panic or deadlock
	waitForSignal(graceShutdownC, nil, &log)
	assert.True(t, channelClosed(graceShutdownC))
}
// TestSignalSIGHUP_ReloadInProgress verifies the non-blocking send: a SIGHUP
// arriving while reloadC is already full is dropped rather than blocking.
func TestSignalSIGHUP_ReloadInProgress(t *testing.T) {
	log := zerolog.Nop()
	graceShutdownC := make(chan struct{})
	// Create buffered channel and fill it
	reloadC := make(chan struct{}, 1)
	reloadC <- struct{}{} // Pre-fill to simulate reload in progress
	go func() {
		// sleep for a tick to prevent sending signal before calling waitForSignal
		time.Sleep(tick)
		// Send SIGHUP while reload is "in progress"
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGHUP)
		time.Sleep(tick)
		// Send SIGTERM to exit waitForSignal
		_ = syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
	}()
	// t.Fatal/FailNow must only run in the test goroutine; use t.Error here
	// and unblock the test by closing graceShutdownC.
	timeout := time.AfterFunc(time.Second, func() {
		select {
		case <-graceShutdownC:
		default:
			close(graceShutdownC)
			t.Error("waitForSignal timed out")
		}
	})
	defer timeout.Stop()
	// Should complete without blocking (non-blocking send)
	waitForSignal(graceShutdownC, reloadC, &log)
	// Channel should still have exactly one signal (the pre-filled one)
	select {
	case <-reloadC:
		// Expected - drain the one signal
	default:
		t.Fatal("Expected reload channel to have signal")
	}
	// Should be empty now
	select {
	case <-reloadC:
		t.Fatal("Expected reload channel to be empty after draining")
	default:
		// Expected - channel is empty
	}
}
func TestWaitForShutdown(t *testing.T) {
log := zerolog.Nop()

View File

@ -0,0 +1,92 @@
package orchestration
import (
"encoding/json"
"os"
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ingress"
)
// LocalConfigJSON represents the JSON format expected by Orchestrator.UpdateConfig.
// It mirrors ingress.RemoteConfigJSON structure.
type LocalConfigJSON struct {
// GlobalOriginRequest holds origin defaults applied across ingress rules.
GlobalOriginRequest *config.OriginRequestConfig `json:"originRequest,omitempty"`
// IngressRules are the unvalidated rules taken straight from the YAML config.
IngressRules []config.UnvalidatedIngressRule `json:"ingress"`
// WarpRouting carries the warp-routing section of the YAML config.
WarpRouting config.WarpRoutingConfig `json:"warp-routing"`
}
// ReadLocalConfig opens the YAML configuration file at configPath and
// decodes it into a config.Configuration.
func ReadLocalConfig(configPath string) (*config.Configuration, error) {
	f, err := os.Open(configPath)
	if err != nil {
		return nil, errors.Wrapf(err, "failed to open config file %s", configPath)
	}
	defer f.Close()

	cfg := new(config.Configuration)
	if err = yaml.NewDecoder(f).Decode(cfg); err != nil {
		return nil, errors.Wrapf(err, "failed to parse YAML config file %s", configPath)
	}
	return cfg, nil
}
// ConvertLocalConfigToJSON serializes a local YAML configuration into the
// JSON shape that Orchestrator.UpdateConfig consumes.
func ConvertLocalConfigToJSON(cfg *config.Configuration) ([]byte, error) {
	if cfg == nil {
		return nil, errors.New("config cannot be nil")
	}

	payload := LocalConfigJSON{
		GlobalOriginRequest: &cfg.OriginRequest,
		IngressRules:        cfg.Ingress,
		WarpRouting:         cfg.WarpRouting,
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return nil, errors.Wrap(err, "failed to marshal config to JSON")
	}
	return encoded, nil
}
// ValidateLocalConfig checks that the local configuration can be converted
// and parsed as ingress rules. A nil return means the config is valid.
func ValidateLocalConfig(cfg *config.Configuration) error {
	if _, err := ConvertAndValidateLocalConfig(cfg); err != nil {
		return err
	}
	return nil
}
// ConvertAndValidateLocalConfig serializes the local config to JSON and runs
// validation in one pass, returning the JSON bytes on success.
func ConvertAndValidateLocalConfig(cfg *config.Configuration) ([]byte, error) {
	data, err := ConvertLocalConfigToJSON(cfg)
	if err != nil {
		return nil, err
	}

	rules := cfg.Ingress
	if len(rules) == 0 {
		// Nothing to validate without ingress rules.
		return data, nil
	}

	// The final rule must be a catch-all: empty hostname or "*".
	if last := rules[len(rules)-1]; last.Hostname != "" && last.Hostname != "*" {
		return nil, errors.New("ingress rules must end with a catch-all rule (empty hostname or '*')")
	}

	// Round-trip through ingress.RemoteConfig so its parsing performs the
	// full ingress validation.
	var remoteConfig ingress.RemoteConfig
	if err := json.Unmarshal(data, &remoteConfig); err != nil {
		return nil, errors.Wrap(err, "invalid ingress configuration")
	}
	return data, nil
}

View File

@ -0,0 +1,187 @@
package orchestration
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ingress"
)
// TestConvertLocalConfigToJSON verifies that a fully populated local config
// serializes to JSON that ingress.RemoteConfig can parse, preserving rule order.
func TestConvertLocalConfigToJSON(t *testing.T) {
connectTimeout := config.CustomDuration{Duration: 30 * time.Second}
tlsTimeout := config.CustomDuration{Duration: 10 * time.Second}
cfg := &config.Configuration{
TunnelID: "test-tunnel-id",
Ingress: []config.UnvalidatedIngressRule{
{
Hostname: "example.com",
Service: "http://localhost:8080",
},
{
Hostname: "*",
Service: "http://localhost:8081",
},
},
WarpRouting: config.WarpRoutingConfig{
ConnectTimeout: &connectTimeout,
},
OriginRequest: config.OriginRequestConfig{
ConnectTimeout: &connectTimeout,
TLSTimeout: &tlsTimeout,
},
}
jsonData, err := ConvertLocalConfigToJSON(cfg)
require.NoError(t, err)
require.NotEmpty(t, jsonData)
// Round-trip through the remote config type to prove compatibility.
var remoteConfig ingress.RemoteConfig
err = json.Unmarshal(jsonData, &remoteConfig)
require.NoError(t, err)
require.Len(t, remoteConfig.Ingress.Rules, 2)
require.Equal(t, "example.com", remoteConfig.Ingress.Rules[0].Hostname)
require.Equal(t, "*", remoteConfig.Ingress.Rules[1].Hostname)
}
// TestConvertLocalConfigToJSON_EmptyIngress checks that a config with no
// ingress rules still serializes and decodes to an empty rule list.
func TestConvertLocalConfigToJSON_EmptyIngress(t *testing.T) {
	cfg := &config.Configuration{
		TunnelID: "test-tunnel-id",
		Ingress:  []config.UnvalidatedIngressRule{},
	}

	data, err := ConvertLocalConfigToJSON(cfg)
	require.NoError(t, err)
	require.NotEmpty(t, data)

	var decoded LocalConfigJSON
	require.NoError(t, json.Unmarshal(data, &decoded))
	require.Empty(t, decoded.IngressRules)
}
// TestValidateLocalConfig_Valid: rules ending in a hostname-less catch-all pass validation.
func TestValidateLocalConfig_Valid(t *testing.T) {
cfg := &config.Configuration{
TunnelID: "test-tunnel-id",
Ingress: []config.UnvalidatedIngressRule{
{
Hostname: "example.com",
Service: "http://localhost:8080",
},
{
Service: "http_status:404",
},
},
}
err := ValidateLocalConfig(cfg)
require.NoError(t, err)
}
// TestValidateLocalConfig_WildcardCatchAll: a trailing "*" hostname also counts as catch-all.
func TestValidateLocalConfig_WildcardCatchAll(t *testing.T) {
cfg := &config.Configuration{
TunnelID: "test-tunnel-id",
Ingress: []config.UnvalidatedIngressRule{
{
Hostname: "example.com",
Service: "http://localhost:8080",
},
{
Hostname: "*",
Service: "http_status:404",
},
},
}
err := ValidateLocalConfig(cfg)
require.NoError(t, err)
}
// TestValidateLocalConfig_MissingCatchAll: rules without a final catch-all are rejected.
func TestValidateLocalConfig_MissingCatchAll(t *testing.T) {
cfg := &config.Configuration{
TunnelID: "test-tunnel-id",
Ingress: []config.UnvalidatedIngressRule{
{
Hostname: "example.com",
Service: "http://localhost:8080",
},
},
}
err := ValidateLocalConfig(cfg)
require.Error(t, err)
require.Contains(t, err.Error(), "catch-all")
}
// TestValidateLocalConfig_EmptyIngress: a config with zero ingress rules is valid.
func TestValidateLocalConfig_EmptyIngress(t *testing.T) {
	cfg := &config.Configuration{
		TunnelID: "test-tunnel-id",
		Ingress:  []config.UnvalidatedIngressRule{},
	}
	require.NoError(t, ValidateLocalConfig(cfg))
}
// TestValidateLocalConfig_InvalidService: a malformed service URL fails validation.
func TestValidateLocalConfig_InvalidService(t *testing.T) {
cfg := &config.Configuration{
TunnelID: "test-tunnel-id",
Ingress: []config.UnvalidatedIngressRule{
{
Hostname: "example.com",
Service: "not-a-valid-url",
},
},
}
err := ValidateLocalConfig(cfg)
require.Error(t, err)
}
// TestReadLocalConfig reads a well-formed YAML config from disk and checks
// tunnel id, ingress rules and warp-routing timeout are decoded correctly.
func TestReadLocalConfig(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
configContent := `
tunnel: test-tunnel-id
ingress:
- hostname: example.com
service: http://localhost:8080
- service: http_status:404
warp-routing:
connectTimeout: 5s
`
err := os.WriteFile(configPath, []byte(configContent), 0o600)
require.NoError(t, err)
cfg, err := ReadLocalConfig(configPath)
require.NoError(t, err)
require.Equal(t, "test-tunnel-id", cfg.TunnelID)
require.Len(t, cfg.Ingress, 2)
require.Equal(t, "example.com", cfg.Ingress[0].Hostname)
require.NotNil(t, cfg.WarpRouting.ConnectTimeout)
require.Equal(t, 5*time.Second, cfg.WarpRouting.ConnectTimeout.Duration)
}
// TestReadLocalConfig_FileNotFound: a missing config file surfaces an error.
func TestReadLocalConfig_FileNotFound(t *testing.T) {
	const missing = "/nonexistent/path/config.yaml"
	_, err := ReadLocalConfig(missing)
	require.Error(t, err)
}
// TestReadLocalConfig_InvalidYAML: malformed YAML produces a parse error.
func TestReadLocalConfig_InvalidYAML(t *testing.T) {
	path := filepath.Join(t.TempDir(), "config.yaml")
	require.NoError(t, os.WriteFile(path, []byte("invalid: yaml: content: ["), 0o600))

	_, err := ReadLocalConfig(path)
	require.Error(t, err)
}

View File

@ -0,0 +1,319 @@
package orchestration
import (
"context"
"os"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/watcher"
)
// Tunables for the local configuration watcher.
const (
// debounceInterval is the time to wait after a file change before reloading.
// This prevents multiple rapid reloads when editors save files multiple times.
debounceInterval = 500 * time.Millisecond
// pollInterval is the interval for polling file changes as a fallback.
// This handles cases where fsnotify stops working (e.g., file replaced via
// symlink rotation, Kubernetes ConfigMap updates).
pollInterval = 30 * time.Second
// localConfigVersionStart is the starting version for local config updates.
// Local config uses high positive versions (1_000_000+) to avoid conflicts with
// remote config versions (0, 1, 2, ...). At typical change rates (<100/day),
// collision would require decades of continuous operation.
localConfigVersionStart int32 = 1_000_000
// maxReloadRetries limits consecutive reloads when config keeps changing.
// This prevents infinite loops if the file is constantly being modified.
maxReloadRetries = 3
)
// LocalConfigWatcher watches a local configuration file for changes and updates
// the Orchestrator when changes are detected. It supports both automatic file
// watching via fsnotify and manual reload via SIGHUP signal.
//
// The watcher uses a hybrid approach: fsnotify for immediate notifications plus
// periodic polling as a fallback. This ensures config changes are detected even
// when fsnotify fails (e.g., file replaced via symlink, Kubernetes ConfigMap).
type LocalConfigWatcher struct {
orchestrator *Orchestrator
configPath string
log *zerolog.Logger
// mu protects version, lastModTime and serializes reload operations
mu sync.Mutex
// version is the last config version applied via UpdateConfig (starts at localConfigVersionStart).
version int32
// lastModTime is the config file mtime as of the last observed change.
lastModTime time.Time
// reloadChan receives fsnotify change notifications (buffered, size 1, non-blocking sends).
reloadChan chan struct{}
}
// NewLocalConfigWatcher constructs a LocalConfigWatcher for the given
// orchestrator and config file path.
// Panics if orchestrator is nil (programming error, not recoverable).
func NewLocalConfigWatcher(
	orchestrator *Orchestrator,
	configPath string,
	log *zerolog.Logger,
) *LocalConfigWatcher {
	if orchestrator == nil {
		panic("orchestrator cannot be nil")
	}
	w := &LocalConfigWatcher{
		orchestrator: orchestrator,
		configPath:   configPath,
		log:          log,
		version:      localConfigVersionStart,
		reloadChan:   make(chan struct{}, 1),
	}
	return w
}
// Run starts the config watcher. It watches for file changes and listens
// for manual reload signals on reloadC.
//
// Returns a channel that is closed when the watcher is ready to receive signals.
// Callers should wait on this channel before starting the signal handler to avoid
// race conditions where signals arrive before the watcher is listening.
func (w *LocalConfigWatcher) Run(ctx context.Context, reloadC <-chan struct{}) <-chan struct{} {
	readyC := make(chan struct{})

	// startSighupOnly launches the degraded mode used when file watching is
	// unavailable: only manual reload signals (SIGHUP) are honored. Shared by
	// both fallback paths below to avoid duplicated goroutine bodies.
	startSighupOnly := func() {
		go func() {
			w.log.Info().Str("config", w.configPath).Msg("Configuration reload available via SIGHUP signal")
			close(readyC)
			w.runWithoutFileWatcher(ctx, reloadC)
		}()
	}

	fileWatcher, err := watcher.NewFile()
	if err != nil {
		w.log.Warn().Err(err).Msg("Failed to create file watcher, falling back to SIGHUP only")
		startSighupOnly()
		return readyC
	}
	if err := fileWatcher.Add(w.configPath); err != nil {
		w.log.Warn().Err(err).Str("config", w.configPath).Msg("Failed to watch config file, falling back to SIGHUP only")
		startSighupOnly()
		return readyC
	}

	w.log.Info().Str("config", w.configPath).Msg("Started watching configuration file for changes")
	go fileWatcher.Start(w)
	// Initialize lastModTime before signaling ready to avoid race with early SIGHUP.
	w.initLastModTime()
	go func() {
		close(readyC)
		w.runLoop(ctx, reloadC, fileWatcher)
	}()
	return readyC
}
// runWithoutFileWatcher services reload requests when no file watcher is
// available; only manual (SIGHUP-driven) reloads are honored. Returns when
// ctx is cancelled.
func (w *LocalConfigWatcher) runWithoutFileWatcher(ctx context.Context, reloadC <-chan struct{}) {
	done := ctx.Done()
	for {
		select {
		case <-reloadC:
			w.doReload()
		case <-done:
			return
		}
	}
}
// runLoop is the main event loop that handles file changes and reload signals.
// It multiplexes four inputs: context cancellation, manual reload requests
// (reloadC, e.g. SIGHUP), fsnotify change notifications (w.reloadChan, which
// are debounced), and a periodic poll that catches changes fsnotify missed.
func (w *LocalConfigWatcher) runLoop(ctx context.Context, reloadC <-chan struct{}, fileWatcher *watcher.File) {
// Use a stopped timer initially; we'll reset it when file changes occur.
// NewTimer(0) fires immediately, so Stop+drain leaves a clean, unarmed timer.
debounceTimer := time.NewTimer(0)
if !debounceTimer.Stop() {
<-debounceTimer.C
}
// debounceActive tracks whether the timer is armed and its channel may hold
// an undrained value; needed to apply the Stop/drain idiom correctly.
debounceActive := false
// Poll timer as fallback for when fsnotify misses changes
pollTicker := time.NewTicker(pollInterval)
defer func() {
debounceTimer.Stop()
pollTicker.Stop()
fileWatcher.Shutdown()
}()
for {
select {
case <-ctx.Done():
return
case <-reloadC:
// Manual reload requests bypass debouncing entirely.
w.log.Info().Msg("Received reload signal")
w.doReload()
case <-w.reloadChan:
// Stop existing timer and drain if necessary.
// If Stop() returns false, timer already expired and channel has value.
// Only drain when debounceActive: otherwise the expired value was already
// consumed by the <-debounceTimer.C case below.
if !debounceTimer.Stop() && debounceActive {
<-debounceTimer.C
}
debounceTimer.Reset(debounceInterval)
debounceActive = true
case <-debounceTimer.C:
// Quiet period elapsed after the last fsnotify event; reload now.
debounceActive = false
w.doReload()
case <-pollTicker.C:
// Fallback polling for when fsnotify misses changes (e.g., symlink rotation)
if w.checkFileChanged() {
w.log.Debug().Msg("Poll detected config file change")
w.doReload()
}
}
}
}
// initLastModTime seeds lastModTime from the config file's current mtime.
// Stat failures are ignored; lastModTime stays at its zero value.
func (w *LocalConfigWatcher) initLastModTime() {
	if info, err := os.Stat(w.configPath); err == nil {
		w.mu.Lock()
		w.lastModTime = info.ModTime()
		w.mu.Unlock()
	}
}
// checkFileChanged reports whether the config file's mtime has advanced since
// the last observation, updating the stored mtime when it has. A stat failure
// is treated as "no change".
func (w *LocalConfigWatcher) checkFileChanged() bool {
	info, err := os.Stat(w.configPath)
	if err != nil {
		return false
	}
	current := info.ModTime()

	w.mu.Lock()
	defer w.mu.Unlock()
	if !current.After(w.lastModTime) {
		return false
	}
	w.lastModTime = current
	return true
}
// getModTime returns the config file's modification time, or the zero
// time.Time when the file cannot be stat'd.
// Note: No lock needed - this reads from disk, not from struct fields.
// The lastModTime field is protected by mu where it's accessed.
func (w *LocalConfigWatcher) getModTime() time.Time {
	if info, err := os.Stat(w.configPath); err == nil {
		return info.ModTime()
	}
	return time.Time{}
}
// WatcherItemDidChange implements watcher.Notification interface.
// Invoked by the file watcher when the config file is modified; it queues a
// debounced reload.
func (w *LocalConfigWatcher) WatcherItemDidChange(filepath string) {
	w.log.Debug().Str("file", filepath).Msg("Config file changed, scheduling reload")
	// Non-blocking send: a pending notification already covers this change.
	select {
	case w.reloadChan <- struct{}{}:
	default:
	}
}
// WatcherDidError implements watcher.Notification interface.
// Logs watcher errors; the current configuration is always kept.
//
// Note: If the config file is deleted and recreated (e.g., during deployment via symlink
// rotation), the file watcher may stop working. In this case, SIGHUP can still be used
// for manual reloads, or cloudflared can be restarted.
func (w *LocalConfigWatcher) WatcherDidError(err error) {
	if !os.IsNotExist(err) {
		w.log.Error().Err(err).Str("config", w.configPath).
			Msg("Config file watcher error, keeping current configuration")
		return
	}
	w.log.Warn().Str("config", w.configPath).
		Msg("Config file was deleted or moved, keeping current configuration")
}
// doReload performs the actual configuration reload.
// Uses TryLock to skip if another reload is already in progress.
// If the config file changes during reload, it will retry up to maxReloadRetries times.
// Holding mu for the full reload also protects version and lastModTime,
// which are written below without further locking.
func (w *LocalConfigWatcher) doReload() {
if !w.mu.TryLock() {
w.log.Info().Msg("Reload already in progress, skipping")
return
}
defer w.mu.Unlock()
// range over an int requires Go 1.22+; iterates i = 0..maxReloadRetries-1.
for i := range maxReloadRetries {
// Snapshot the mtime before reading so we can detect writes that land
// while the reload is in flight.
startModTime := w.getModTime()
cfg, err := ReadLocalConfig(w.configPath)
if err != nil {
w.log.Error().Err(err).Str("config", w.configPath).
Msg("Failed to read config file, keeping current configuration")
return
}
configJSON, err := ConvertAndValidateLocalConfig(cfg)
if err != nil {
w.log.Error().Err(err).Msg("Invalid configuration, keeping current configuration")
return
}
nextVersion := w.version + 1
resp := w.orchestrator.UpdateConfig(nextVersion, configJSON)
if resp.Err != nil {
w.log.Error().Err(resp.Err).Int32("version", nextVersion).
Msg("Orchestrator rejected configuration update")
return
}
w.version = resp.LastAppliedVersion
// Get mtime once to avoid TOCTOU race
currentModTime := w.getModTime()
w.lastModTime = currentModTime
w.log.Info().Int32("version", resp.LastAppliedVersion).
Msg("Configuration reloaded successfully")
// Check if file changed during reload (using same mtime value)
if !currentModTime.After(startModTime) {
return // No changes during reload, done
}
if i < maxReloadRetries-1 {
w.log.Debug().Msg("Config file changed during reload, reloading again")
}
}
// Reached only when every retry observed a newer mtime after reloading.
w.log.Warn().Int("retries", maxReloadRetries).
Msg("Config file keeps changing, giving up after max retries")
}
// ReloadConfig triggers a manual configuration reload.
// This is useful for programmatic reloads without SIGHUP.
// Safe to call concurrently: doReload uses TryLock and skips the call if a
// reload is already in progress.
func (w *LocalConfigWatcher) ReloadConfig() {
w.doReload()
}
// Version returns the current config version (thread-safe).
func (w *LocalConfigWatcher) Version() int32 {
	w.mu.Lock()
	v := w.version
	w.mu.Unlock()
	return v
}

View File

@ -0,0 +1,303 @@
package orchestration
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ingress"
)
// TestNewLocalConfigWatcher checks constructor wiring: path stored and
// version starts at localConfigVersionStart.
func TestNewLocalConfigWatcher(t *testing.T) {
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, "/tmp/config.yaml", &log)
require.NotNil(t, watcher)
require.Equal(t, "/tmp/config.yaml", watcher.configPath)
require.Equal(t, int32(localConfigVersionStart), watcher.Version())
}
// TestLocalConfigWatcher_ReloadConfig: a successful manual reload bumps the version by one.
func TestLocalConfigWatcher_ReloadConfig(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
configContent := `
tunnel: test-tunnel-id
ingress:
- hostname: example.com
service: http://localhost:8080
- service: http_status:404
`
err := os.WriteFile(configPath, []byte(configContent), 0o600)
require.NoError(t, err)
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, configPath, &log)
watcher.ReloadConfig()
require.Equal(t, int32(localConfigVersionStart+1), watcher.Version())
}
// TestLocalConfigWatcher_ReloadConfig_InvalidYAML: unparseable YAML leaves the version unchanged.
func TestLocalConfigWatcher_ReloadConfig_InvalidYAML(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
err := os.WriteFile(configPath, []byte("invalid: yaml: ["), 0o600)
require.NoError(t, err)
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, configPath, &log)
watcher.ReloadConfig()
require.Equal(t, int32(localConfigVersionStart), watcher.Version())
}
// TestLocalConfigWatcher_ReloadConfig_InvalidIngress: a config failing ingress
// validation is rejected and the version stays put.
func TestLocalConfigWatcher_ReloadConfig_InvalidIngress(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
// Missing catch-all rule (no empty hostname at end)
configContent := `
tunnel: test-tunnel-id
ingress:
- hostname: example.com
service: http://localhost:8080
`
err := os.WriteFile(configPath, []byte(configContent), 0o600)
require.NoError(t, err)
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, configPath, &log)
watcher.ReloadConfig()
require.Equal(t, int32(localConfigVersionStart), watcher.Version())
}
// TestLocalConfigWatcher_WatcherItemDidChange: a change notification lands on reloadChan.
func TestLocalConfigWatcher_WatcherItemDidChange(t *testing.T) {
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, "/tmp/config.yaml", &log)
watcher.WatcherItemDidChange("/tmp/config.yaml")
select {
case <-watcher.reloadChan:
default:
t.Fatal("Expected reload channel to receive signal")
}
}
// TestLocalConfigWatcher_WatcherItemDidChange_NonBlocking: when reloadChan is
// already full the notification is dropped instead of blocking.
func TestLocalConfigWatcher_WatcherItemDidChange_NonBlocking(t *testing.T) {
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, "/tmp/config.yaml", &log)
watcher.reloadChan <- struct{}{}
watcher.WatcherItemDidChange("/tmp/config.yaml")
select {
case <-watcher.reloadChan:
default:
t.Fatal("Expected reload channel to have signal")
}
}
// TestLocalConfigWatcher_Run_ManualReload: a signal on reloadC while Run is
// active triggers a reload (version eventually increments).
func TestLocalConfigWatcher_Run_ManualReload(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
configContent := `
tunnel: test-tunnel-id
ingress:
- hostname: example.com
service: http://localhost:8080
- service: http_status:404
`
err := os.WriteFile(configPath, []byte(configContent), 0o600)
require.NoError(t, err)
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, configPath, &log)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
reloadC := make(chan struct{}, 1)
readyC := watcher.Run(ctx, reloadC)
<-readyC // Wait until watcher is ready
// Send reload signal
reloadC <- struct{}{}
// Wait for version to increment
require.Eventually(t, func() bool {
return watcher.Version() >= localConfigVersionStart+1
}, 2*time.Second, 10*time.Millisecond, "version should be incremented after reload")
}
// TestLocalConfigWatcher_Run_FileChange: rewriting the config file on disk is
// picked up by the file watcher (with retries, since fsnotify startup is async).
func TestLocalConfigWatcher_Run_FileChange(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
configContent := `
tunnel: test-tunnel-id
ingress:
- hostname: example.com
service: http://localhost:8080
- service: http_status:404
`
err := os.WriteFile(configPath, []byte(configContent), 0o600)
require.NoError(t, err)
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, configPath, &log)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
reloadC := make(chan struct{}, 1)
readyC := watcher.Run(ctx, reloadC)
<-readyC // Wait until watcher is ready
newConfigContent := `
tunnel: test-tunnel-id
ingress:
- hostname: new-example.com
service: http://localhost:9090
- service: http_status:404
`
// Write the config file. We may need to write multiple times because fsnotify
// may not have started watching yet. We write with increasing delays to allow
// the debounce timer (500ms) to fire between writes.
written := false
for range 5 {
err = os.WriteFile(configPath, []byte(newConfigContent), 0o600)
require.NoError(t, err)
written = true
// Wait longer than debounce interval to allow reload to happen
time.Sleep(600 * time.Millisecond)
if watcher.Version() >= localConfigVersionStart+1 {
break
}
}
require.True(t, written, "should have written config file")
require.GreaterOrEqual(t, watcher.Version(), int32(localConfigVersionStart+1), "version should be incremented after file change")
}
// TestLocalConfigWatcher_ConcurrentReloads: parallel ReloadConfig calls must
// not race; TryLock lets overlapping calls be skipped safely.
func TestLocalConfigWatcher_ConcurrentReloads(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.yaml")
configContent := `
tunnel: test-tunnel-id
ingress:
- hostname: example.com
service: http://localhost:8080
- service: http_status:404
`
err := os.WriteFile(configPath, []byte(configContent), 0o600)
require.NoError(t, err)
log := zerolog.Nop()
orchestrator := createTestOrchestrator(t)
watcher := NewLocalConfigWatcher(orchestrator, configPath, &log)
// Run multiple concurrent reloads
const numGoroutines = 10
var wg sync.WaitGroup
wg.Add(numGoroutines)
for range numGoroutines {
go func() {
defer wg.Done()
watcher.ReloadConfig()
}()
}
wg.Wait()
// With TryLock, concurrent reloads are skipped if one is already in progress.
// At least one reload should succeed (version >= start+1).
// Due to TryLock skipping, version likely won't reach start+numGoroutines.
finalVersion := watcher.Version()
require.GreaterOrEqual(t, finalVersion, int32(localConfigVersionStart+1),
"At least one reload should have succeeded")
require.LessOrEqual(t, finalVersion, int32(localConfigVersionStart+numGoroutines),
"Version should not exceed expected reloads")
}
// TestLocalConfigWatcher_Run_ContextCancellation: cancelling the context must
// stop the watcher without panic or hang.
func TestLocalConfigWatcher_Run_ContextCancellation(t *testing.T) {
	tempDir := t.TempDir()
	configPath := filepath.Join(tempDir, "config.yaml")
	configContent := `
tunnel: test-tunnel-id
ingress:
- service: http_status:404
`
	err := os.WriteFile(configPath, []byte(configContent), 0o600)
	require.NoError(t, err)
	log := zerolog.Nop()
	orchestrator := createTestOrchestrator(t)
	watcher := NewLocalConfigWatcher(orchestrator, configPath, &log)
	// Derive from t.Context() for consistency with the sibling Run tests.
	ctx, cancel := context.WithCancel(t.Context())
	reloadC := make(chan struct{}, 1)
	readyC := watcher.Run(ctx, reloadC)
	<-readyC
	// Cancel context and verify watcher stops without panic or hang
	cancel()
	time.Sleep(50 * time.Millisecond)
}
// createTestOrchestrator builds a minimal Orchestrator backed by an in-process
// origin dialer and an empty ingress, suitable for watcher unit tests.
func createTestOrchestrator(t *testing.T) *Orchestrator {
t.Helper()
log := zerolog.Nop()
originDialer := ingress.NewOriginDialer(ingress.OriginConfig{
DefaultDialer: ingress.NewDialer(ingress.WarpRoutingConfig{
ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second},
TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second},
MaxActiveFlows: 0,
}),
TCPWriteTimeout: 1 * time.Second,
}, &log)
initConfig := &Config{
Ingress: &ingress.Ingress{},
OriginDialerService: originDialer,
}
orchestrator, err := NewOrchestrator(t.Context(), initConfig, nil, []ingress.Rule{}, &log)
require.NoError(t, err)
return orchestrator
}