75 lines
2.3 KiB
Go
75 lines
2.3 KiB
Go
package supervisor
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/cloudflare/cloudflared/connection"
|
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
|
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
|
"github.com/cloudflare/cloudflared/signal"
|
|
)
|
|
|
|
type mockProtocolSelector struct {
|
|
protocols []connection.Protocol
|
|
index int
|
|
}
|
|
|
|
func (m *mockProtocolSelector) Current() connection.Protocol {
|
|
return m.protocols[m.index]
|
|
}
|
|
|
|
func (m *mockProtocolSelector) Fallback() (connection.Protocol, bool) {
|
|
m.index++
|
|
if m.index == len(m.protocols) {
|
|
return m.protocols[len(m.protocols)-1], false
|
|
}
|
|
|
|
return m.protocols[m.index], true
|
|
}
|
|
|
|
type mockEdgeTunnelServer struct {
|
|
config *TunnelConfig
|
|
}
|
|
|
|
func (m *mockEdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error {
|
|
// This is to mock the first connection falling back because of connectivity issues.
|
|
protocolFallback.protocol, _ = m.config.ProtocolSelector.Fallback()
|
|
connectedSignal.Notify()
|
|
return nil
|
|
}
|
|
|
|
// Test to check if initialize sets all the different connections to the same protocol should the first
|
|
// tunnel fall back.
|
|
func Test_Initialize_Same_Protocol(t *testing.T) {
|
|
edgeIPs, err := edgediscovery.ResolveEdge(&zerolog.Logger{}, "us", allregions.Auto)
|
|
assert.Nil(t, err)
|
|
s := Supervisor{
|
|
edgeIPs: edgeIPs,
|
|
config: &TunnelConfig{
|
|
ProtocolSelector: &mockProtocolSelector{protocols: []connection.Protocol{connection.QUIC, connection.HTTP2, connection.H2mux}},
|
|
},
|
|
tunnelsProtocolFallback: make(map[int]*protocolFallback),
|
|
edgeTunnelServer: &mockEdgeTunnelServer{
|
|
config: &TunnelConfig{
|
|
ProtocolSelector: &mockProtocolSelector{protocols: []connection.Protocol{connection.QUIC, connection.HTTP2, connection.H2mux}},
|
|
},
|
|
},
|
|
}
|
|
|
|
ctx := context.Background()
|
|
connectedSignal := signal.New(make(chan struct{}))
|
|
s.initialize(ctx, connectedSignal)
|
|
|
|
// Make sure we fell back to http2 as the mock Serve is wont to do.
|
|
assert.Equal(t, s.tunnelsProtocolFallback[0].protocol, connection.HTTP2)
|
|
|
|
// Ensure all the protocols we set to try are the same as what the first tunnel has fallen back to.
|
|
for _, protocolFallback := range s.tunnelsProtocolFallback {
|
|
assert.Equal(t, protocolFallback.protocol, s.tunnelsProtocolFallback[0].protocol)
|
|
}
|
|
}
|