diff --git a/streamhandler/stream_handler.go b/streamhandler/stream_handler.go index 99c49657..9f664301 100644 --- a/streamhandler/stream_handler.go +++ b/streamhandler/stream_handler.go @@ -82,9 +82,16 @@ func (s *StreamHandler) UseConfiguration(ctx context.Context, config *pogs.Clien // UpdateConfig replaces current originmapper mapping with mappings from newConfig func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (failedConfigs []*pogs.FailedConfig) { - // TODO: TUN-1968: Gracefully apply new config - s.tunnelHostnameMapper.DeleteAll() - for _, tunnelConfig := range newConfig { + + // Delete old configs that aren't in the `newConfig` + toRemove := s.tunnelHostnameMapper.ToRemove(newConfig) + for _, hostnameToRemove := range toRemove { + s.tunnelHostnameMapper.Delete(hostnameToRemove) + } + + // Add new configs that weren't in the old mapper + toAdd := s.tunnelHostnameMapper.ToAdd(newConfig) + for _, tunnelConfig := range toAdd { tunnelHostname := tunnelConfig.TunnelHostname originSerice, err := tunnelConfig.OriginConfigJSONHandler.OriginConfig.Service() if err != nil { diff --git a/tunnelhostnamemapper/tunnelhostnamemapper.go b/tunnelhostnamemapper/tunnelhostnamemapper.go index bb8f70f1..554da985 100644 --- a/tunnelhostnamemapper/tunnelhostnamemapper.go +++ b/tunnelhostnamemapper/tunnelhostnamemapper.go @@ -5,6 +5,7 @@ import ( "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/originservice" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) // TunnelHostnameMapper maps TunnelHostname to an OriginService @@ -38,12 +39,55 @@ func (om *TunnelHostnameMapper) Add(key h2mux.TunnelHostname, os originservice.O om.tunnelHostnameToOrigin[key] = os } -// DeleteAll mappings, and shutdown all OriginService -func (om *TunnelHostnameMapper) DeleteAll() { +// Delete a mapping, and shutdown its OriginService +func (om *TunnelHostnameMapper) Delete(key h2mux.TunnelHostname) (keyFound bool) { om.Lock() defer om.Unlock() - for key, os := range om.tunnelHostnameToOrigin { + if os, ok := om.tunnelHostnameToOrigin[key]; ok { os.Shutdown() delete(om.tunnelHostnameToOrigin, key) + return true } + return false +} + +// ToRemove finds all keys that should be removed from the TunnelHostnameMapper. +func (om *TunnelHostnameMapper) ToRemove(newConfigs []*pogs.ReverseProxyConfig) (toRemove []h2mux.TunnelHostname) { + om.Lock() + defer om.Unlock() + + // Convert into a set, for O(1) lookups instead of O(n) + newConfigSet := toSet(newConfigs) + + // If a config in `om` isn't in `newConfigs`, it must be removed. + for hostname := range om.tunnelHostnameToOrigin { + if _, ok := newConfigSet[hostname]; !ok { + toRemove = append(toRemove, hostname) + } + } + + return +} + +// ToAdd filters the given configs, keeping those that should be added to the TunnelHostnameMapper. +func (om *TunnelHostnameMapper) ToAdd(newConfigs []*pogs.ReverseProxyConfig) (toAdd []*pogs.ReverseProxyConfig) { + om.Lock() + defer om.Unlock() + + // If a config in `newConfigs` isn't in `om`, it must be added. + for _, config := range newConfigs { + if _, ok := om.tunnelHostnameToOrigin[config.TunnelHostname]; !ok { + toAdd = append(toAdd, config) + } + } + + return +} + +func toSet(configs []*pogs.ReverseProxyConfig) map[h2mux.TunnelHostname]*pogs.ReverseProxyConfig { + m := make(map[h2mux.TunnelHostname]*pogs.ReverseProxyConfig) + for _, config := range configs { + m[config.TunnelHostname] = config + } + return m } diff --git a/tunnelhostnamemapper/tunnelhostnamemapper_test.go b/tunnelhostnamemapper/tunnelhostnamemapper_test.go index e38d0611..685556fd 100644 --- a/tunnelhostnamemapper/tunnelhostnamemapper_test.go +++ b/tunnelhostnamemapper/tunnelhostnamemapper_test.go @@ -4,11 +4,14 @@ import ( "fmt" "net/http" "net/url" + "reflect" "sync" "testing" + "time" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/originservice" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/stretchr/testify/assert" ) @@ -52,9 +55,6 @@ func TestTunnelHostnameMapperConcurrentAccess(t *testing.T) { assert.True(t, ok) assert.Equal(t, secondHTTPOS, os) }) - - thm.DeleteAll() - assert.Empty(t, thm.tunnelHostnameToOrigin) } func concurrentOps(t *testing.T, f func(i int)) { @@ -72,3 +72,141 @@ func concurrentOps(t *testing.T, f func(i int)) { func tunnelHostname(i int) h2mux.TunnelHostname { return h2mux.TunnelHostname(fmt.Sprintf("%d.cftunnel.com", i)) } + +func Test_toSet(t *testing.T) { + + type args struct { + configs []*pogs.ReverseProxyConfig + } + tests := []struct { + name string + args args + want map[h2mux.TunnelHostname]*pogs.ReverseProxyConfig + }{ + { + name: "empty slice should yield empty map", + args: args{}, + want: map[h2mux.TunnelHostname]*pogs.ReverseProxyConfig{}, + }, + { + name: "multiple elements", + args: args{[]*pogs.ReverseProxyConfig{sampleConfig1(), sampleConfig2()}}, + want: map[h2mux.TunnelHostname]*pogs.ReverseProxyConfig{ + "mock.example.com": sampleConfig1(), + "mock2.example.com": sampleConfig2(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := toSet(tt.args.configs); !reflect.DeepEqual(got, tt.want) { + t.Errorf("toSet() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTunnelHostnameMapper_ToAdd(t *testing.T) { + type fields struct { + tunnelHostnameToOrigin map[h2mux.TunnelHostname]originservice.OriginService + } + type args struct { + newConfigs []*pogs.ReverseProxyConfig + } + tests := []struct { + name string + fields fields + args args + wantToAdd []*pogs.ReverseProxyConfig + }{ + { + name: "Mapper={}, NewConfig={}, toAdd={}", + }, + { + name: "Mapper={}, NewConfig={x}, toAdd={x}", + args: args{newConfigs: []*pogs.ReverseProxyConfig{sampleConfig1()}}, + wantToAdd: []*pogs.ReverseProxyConfig{sampleConfig1()}, + }, + { + name: "Mapper={x}, NewConfig={x,y}, toAdd={y}", + args: args{newConfigs: []*pogs.ReverseProxyConfig{sampleConfig2()}}, + wantToAdd: []*pogs.ReverseProxyConfig{sampleConfig2()}, + fields: fields{tunnelHostnameToOrigin: map[h2mux.TunnelHostname]originservice.OriginService{ + h2mux.TunnelHostname(sampleConfig1().TunnelHostname): &originservice.HelloWorldService{}, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + om := &TunnelHostnameMapper{ + tunnelHostnameToOrigin: tt.fields.tunnelHostnameToOrigin, + } + if gotToAdd := om.ToAdd(tt.args.newConfigs); !reflect.DeepEqual(gotToAdd, tt.wantToAdd) { + t.Errorf("TunnelHostnameMapper.ToAdd() = %v, want %v", gotToAdd, tt.wantToAdd) + } + }) + } +} + +func TestTunnelHostnameMapper_ToRemove(t *testing.T) { + type fields struct { + tunnelHostnameToOrigin map[h2mux.TunnelHostname]originservice.OriginService + } + type args struct { + newConfigs []*pogs.ReverseProxyConfig + } + tests := []struct { + name string + fields fields + args args + wantToRemove []h2mux.TunnelHostname + }{ + { + name: "Mapper={}, NewConfig={}, toRemove={}", + }, + { + name: "Mapper={x}, NewConfig={}, toRemove={x}", + wantToRemove: []h2mux.TunnelHostname{sampleConfig1().TunnelHostname}, + fields: fields{tunnelHostnameToOrigin: map[h2mux.TunnelHostname]originservice.OriginService{ + h2mux.TunnelHostname(sampleConfig1().TunnelHostname): &originservice.HelloWorldService{}, + }}, + }, + { + name: "Mapper={x}, NewConfig={x}, toRemove={}", + args: args{newConfigs: []*pogs.ReverseProxyConfig{sampleConfig1()}}, + fields: fields{tunnelHostnameToOrigin: map[h2mux.TunnelHostname]originservice.OriginService{ + h2mux.TunnelHostname(sampleConfig1().TunnelHostname): &originservice.HelloWorldService{}, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + om := &TunnelHostnameMapper{ + tunnelHostnameToOrigin: tt.fields.tunnelHostnameToOrigin, + } + if gotToRemove := om.ToRemove(tt.args.newConfigs); !reflect.DeepEqual(gotToRemove, tt.wantToRemove) { + t.Errorf("TunnelHostnameMapper.ToRemove() = %v, want %v", gotToRemove, tt.wantToRemove) + } + }) + } +} + +func sampleConfig1() *pogs.ReverseProxyConfig { + return &pogs.ReverseProxyConfig{ + TunnelHostname: "mock.example.com", + OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{OriginConfig: &pogs.HelloWorldOriginConfig{}}, + Retries: 18, + ConnectionTimeout: 5 * time.Second, + CompressionQuality: 3, + } +} + +func sampleConfig2() *pogs.ReverseProxyConfig { + return &pogs.ReverseProxyConfig{ + TunnelHostname: "mock2.example.com", + OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{OriginConfig: &pogs.HelloWorldOriginConfig{}}, + Retries: 18, + ConnectionTimeout: 5 * time.Second, + CompressionQuality: 3, + } +}