package connection

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/cloudflare/cloudflared/edgediscovery"
)

// Shared fixtures for the protocol selector tests in this file.
//
// testNoTTL is passed as the TTL argument of NewProtocolSelector by the
// refresh tests, which expect every Current() call to reflect the fetcher's
// latest state (presumably because a zero TTL expires immediately; confirm
// against the selector implementation).
//
// testAccountTag is the account tag passed to every NewProtocolSelector call.
const (
	testNoTTL      = 0
	testAccountTag = "testAccountTag"
)

// mockFetcher builds a PercentageFetcher with a fixed outcome. When getError
// is true every call fails with a "failed to fetch percentage" error;
// otherwise every call returns the supplied protocolPercent entries and a nil
// error.
func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) edgediscovery.PercentageFetcher {
	// getError cannot change after construction, so choose the closure once
	// instead of branching on every invocation.
	if getError {
		return func() (edgediscovery.ProtocolPercents, error) {
			return nil, fmt.Errorf("failed to fetch percentage")
		}
	}

	return func() (edgediscovery.ProtocolPercents, error) {
		return protocolPercent, nil
	}
}

// dynamicMockFetcher is a test double whose result can be changed between
// calls: the fetcher returned by fetch reads protocolPercents and err each
// time it is invoked, so a test can assign new values to these fields after
// handing the fetcher to a selector.
type dynamicMockFetcher struct {
	protocolPercents edgediscovery.ProtocolPercents
	err              error
}

// fetch returns a PercentageFetcher bound to this mock. The returned function
// reads the mock's fields at call time rather than copying them now, so later
// assignments to protocolPercents or err are visible on the next invocation.
func (m *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher {
	current := func() (edgediscovery.ProtocolPercents, error) {
		return m.protocolPercents, m.err
	}
	return current
}

// TestNewProtocolSelector checks the protocol and fallback reported by a
// freshly built selector for each protocol flag, with and without the
// post-quantum requirement, and that an unrecognised protocol name is
// rejected with an error.
//
// The fetcher always returns an empty percentage list, and no case sets
// tunnelTokenProvided, so every case runs with that flag false.
func TestNewProtocolSelector(t *testing.T) {
	tests := []struct {
		name                string
		protocol            string
		tunnelTokenProvided bool
		needPQ              bool
		expectedProtocol    Protocol
		hasFallback         bool
		expectedFallback    Protocol
		wantErr             bool
	}{
		{
			name:     "named tunnel with unknown protocol",
			protocol: "unknown",
			wantErr:  true,
		},
		{
			name:             "named tunnel with h2mux: force to http2",
			protocol:         "h2mux",
			expectedProtocol: HTTP2,
		},
		{
			name:             "named tunnel with http2: no fallback",
			protocol:         "http2",
			expectedProtocol: HTTP2,
		},
		{
			name:             "named tunnel with auto: quic",
			protocol:         AutoSelectFlag,
			expectedProtocol: QUIC,
			hasFallback:      true,
			expectedFallback: HTTP2,
		},
		{
			name:             "named tunnel (post quantum)",
			protocol:         AutoSelectFlag,
			needPQ:           true,
			expectedProtocol: QUIC,
		},
		{
			name:             "named tunnel (post quantum) w/http2",
			protocol:         "http2",
			needPQ:           true,
			expectedProtocol: QUIC,
		},
	}

	fetcher := dynamicMockFetcher{
		protocolPercents: edgediscovery.ProtocolPercents{},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			selector, err := NewProtocolSelector(test.protocol, testAccountTag, test.tunnelTokenProvided, test.needPQ, fetcher.fetch(), ResolveTTL, &log)
			if test.wantErr {
				assert.Error(t, err, "test %s failed", test.name)
				return
			}

			// assert.NoError only records the failure and keeps going. Stop
			// here so a failed construction is reported as such instead of
			// being followed by method calls on an unusable selector.
			if !assert.NoError(t, err, "test %s failed", test.name) {
				return
			}

			assert.Equal(t, test.expectedProtocol, selector.Current(), "test %s failed", test.name)

			fallback, ok := selector.Fallback()
			assert.Equal(t, test.hasFallback, ok, "test %s failed", test.name)
			// The fallback value is only meaningful when one is expected.
			if test.hasFallback {
				assert.Equal(t, test.expectedFallback, fallback, "test %s failed", test.name)
			}
		})
	}
}

// TestAutoProtocolSelectorRefresh verifies that a selector created with the
// auto flag, no tunnel token and testNoTTL re-evaluates its protocol from the
// fetcher on successive Current() calls.
//
// The steps are applied in order against the same selector, so each
// expectation depends on the state left behind by the previous steps.
func TestAutoProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
	// assert.NoError does not stop the test; return explicitly so a failed
	// construction is not followed by calls on an unusable selector.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, QUIC, selector.Current())

	steps := []struct {
		desc     string
		percents edgediscovery.ProtocolPercents
		err      error
		expected Protocol
	}{
		{
			desc:     "http2 at 100 percent",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}},
			expected: HTTP2,
		},
		{
			desc:     "http2 at 0 percent",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}},
			expected: QUIC,
		},
		{
			desc:     "http2 back at 100 percent",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}},
			expected: HTTP2,
		},
		{
			// A failing fetch is expected to leave the previous choice in place.
			desc:     "fetch error keeps the current protocol",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}},
			err:      fmt.Errorf("failed to fetch"),
			expected: HTTP2,
		},
		{
			// NOTE(review): a negative percentage presumably marks the
			// protocol as disabled; confirm against the selector.
			desc:     "http2 at -1 percent after the fetcher recovers",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: -1}},
			expected: QUIC,
		},
		{
			desc:     "http2 at 0 percent again",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}},
			expected: QUIC,
		},
		{
			desc:     "quic at 100 percent",
			percents: edgediscovery.ProtocolPercents{{Protocol: "quic", Percentage: 100}},
			expected: QUIC,
		},
	}

	for _, step := range steps {
		// The selector holds a fetcher bound to this mock, so updating the
		// fields changes what its next fetch returns.
		fetcher.protocolPercents = step.percents
		fetcher.err = step.err
		assert.Equal(t, step.expected, selector.Current(), step.desc)
	}
}

// TestHTTP2ProtocolSelectorRefresh verifies that a selector created with an
// explicit http2 protocol keeps reporting HTTP2 no matter what the fetcher
// returns afterwards, including fetch errors and negative percentages.
//
// The steps are applied in order against the same selector.
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	// Since the user chooses http2 on purpose, we always stick to it.
	selector, err := NewProtocolSelector(HTTP2.String(), testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
	// assert.NoError does not stop the test; return explicitly so a failed
	// construction is not followed by calls on an unusable selector.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, HTTP2, selector.Current())

	steps := []struct {
		desc     string
		percents edgediscovery.ProtocolPercents
		err      error
	}{
		{
			desc:     "http2 at 100 percent",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}},
		},
		{
			desc:     "http2 at 0 percent",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}},
		},
		{
			desc:     "fetch error",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}},
			err:      fmt.Errorf("failed to fetch"),
		},
		{
			desc:     "http2 at -1 percent after the fetcher recovers",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: -1}},
		},
		{
			desc:     "http2 at 0 percent again",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}},
		},
		{
			desc:     "http2 back at 100 percent",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}},
		},
		{
			desc:     "http2 at -1 percent again",
			percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: -1}},
		},
	}

	for _, step := range steps {
		// The selector holds a fetcher bound to this mock, so updating the
		// fields changes what its next fetch returns.
		fetcher.protocolPercents = step.percents
		fetcher.err = step.err
		assert.Equal(t, HTTP2, selector.Current(), step.desc)
	}
}

// TestAutoProtocolSelectorNoRefreshWithToken verifies that a selector created
// with the auto flag and tunnelTokenProvided set to true keeps reporting QUIC
// even after the fetcher starts advertising http2 at 100 percent.
func TestAutoProtocolSelectorNoRefreshWithToken(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, true, false, fetcher.fetch(), testNoTTL, &log)
	// assert.NoError does not stop the test; return explicitly so a failed
	// construction is not followed by calls on an unusable selector.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, QUIC, selector.Current())

	// The same percentages flip the selector to HTTP2 when no token is
	// provided (see TestAutoProtocolSelectorRefresh); with a token the
	// selection is expected to stay unchanged.
	fetcher.protocolPercents = edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}}
	assert.Equal(t, QUIC, selector.Current())
}