cloudflared-mirror/supervisor/supervisor_test.go

75 lines
2.3 KiB
Go

package supervisor
import (
"context"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/signal"
)
// mockProtocolSelector is a test ProtocolSelector that walks through a fixed,
// ordered list of protocols.
type mockProtocolSelector struct {
	// protocols is the fallback order; with a zero index, protocols[0] is the
	// protocol Current reports first.
	protocols []connection.Protocol
	// index points at the protocol Current returns.
	index int
}

// Current returns the protocol the selector currently points at.
func (m *mockProtocolSelector) Current() connection.Protocol {
	return m.protocols[m.index]
}

// Fallback advances to the next protocol in the list and reports true. When no
// further protocol remains it returns the last protocol and false, and keeps
// index on that last entry so later Current or Fallback calls cannot index past
// the end of protocols. With an empty list it returns the zero Protocol and
// false.
func (m *mockProtocolSelector) Fallback() (connection.Protocol, bool) {
	if len(m.protocols) == 0 {
		var zero connection.Protocol
		return zero, false
	}
	last := len(m.protocols) - 1
	if m.index >= last {
		// Exhausted: pin to the last protocol instead of running off the end.
		m.index = last
		return m.protocols[last], false
	}
	m.index++
	return m.protocols[m.index], true
}
// mockEdgeTunnelServer stands in for the edge tunnel server in tests. Its Serve
// method does not connect anywhere; it only simulates a connection that had to
// fall back to another protocol before connecting.
type mockEdgeTunnelServer struct {
	// config supplies the ProtocolSelector whose Fallback result Serve reports.
	config *TunnelConfig
}

// Serve simulates the first connection falling back because of connectivity
// issues. It stores the selector's next fallback protocol on fallback, then
// notifies connectedSignal. It ignores ctx and connIndex and always returns nil.
func (m *mockEdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, fallback *protocolFallback, connectedSignal *signal.Signal) error {
	selector := m.config.ProtocolSelector
	// Whether another fallback was actually available does not matter to this
	// mock, so the boolean result is deliberately discarded.
	next, _ := selector.Fallback()
	fallback.protocol = next

	connectedSignal.Notify()
	return nil
}
// Test_Initialize_Same_Protocol checks that when the first tunnel falls back to
// another protocol during initialize, the supervisor records that fallback
// protocol for tunnel 0, and every entry in tunnelsProtocolFallback uses the
// same protocol as tunnel 0.
//
// NOTE(review): edgediscovery.ResolveEdge is called against the real "us"
// region, so this test presumably needs network/DNS access — confirm.
// NOTE(review): HAConnections is not set on the supervisor's TunnelConfig, so
// the map may only contain tunnel 0, which would make the final loop trivially
// true; consider setting HAConnections to exercise the other tunnels.
func Test_Initialize_Same_Protocol(t *testing.T) {
	edgeIPs, err := edgediscovery.ResolveEdge(&zerolog.Logger{}, "us", allregions.Auto)
	// Stop here on failure rather than building the supervisor from an
	// unusable edgeIPs value.
	if !assert.NoError(t, err) {
		return
	}

	s := Supervisor{
		edgeIPs: edgeIPs,
		config: &TunnelConfig{
			ProtocolSelector: &mockProtocolSelector{protocols: []connection.Protocol{connection.QUIC, connection.HTTP2, connection.H2mux}},
		},
		tunnelsProtocolFallback: make(map[int]*protocolFallback),
		edgeTunnelServer: &mockEdgeTunnelServer{
			config: &TunnelConfig{
				ProtocolSelector: &mockProtocolSelector{protocols: []connection.Protocol{connection.QUIC, connection.HTTP2, connection.H2mux}},
			},
		},
	}

	ctx := context.Background()
	connectedSignal := signal.New(make(chan struct{}))
	assert.NoError(t, s.initialize(ctx, connectedSignal))

	first := s.tunnelsProtocolFallback[0]
	if !assert.NotNil(t, first, "initialize did not record tunnel 0") {
		return
	}

	// The mock Serve falls back once from QUIC, so tunnel 0 should be on HTTP2.
	assert.Equal(t, connection.HTTP2, first.protocol)

	// Every tunnel should use the protocol the first tunnel fell back to.
	for i, pf := range s.tunnelsProtocolFallback {
		assert.Equal(t, first.protocol, pf.protocol, "tunnel %d", i)
	}
}