package connection

import (
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

const (
	testNoTTL            = 0
	noWarpRoutingEnabled = false
)

var (
	testNamedTunnelConfig = &NamedTunnelConfig{
		Credentials: Credentials{
			AccountTag: "testAccountTag",
		},
	}
)

func mockFetcher(percentage int32) PercentageFetcher {
	return func() (int32, error) {
		return percentage, nil
	}
}

func mockFetcherWithError() PercentageFetcher {
	return func() (int32, error) {
		return 0, fmt.Errorf("failed to fetch precentage")
	}
}

type dynamicMockFetcher struct {
	percentage int32
	err        error
}

func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
	return func() (int32, error) {
		if dmf.err != nil {
			return 0, dmf.err
		}
		return dmf.percentage, nil
	}
}

func TestNewProtocolSelector(t *testing.T) {
	tests := []struct {
		name               string
		protocol           string
		expectedProtocol   Protocol
		hasFallback        bool
		expectedFallback   Protocol
		warpRoutingEnabled bool
		namedTunnelConfig  *NamedTunnelConfig
		fetchFunc          PercentageFetcher
		wantErr            bool
	}{
		{
			name:              "classic tunnel",
			protocol:          "h2mux",
			expectedProtocol:  H2mux,
			namedTunnelConfig: nil,
		},
		{
			name:              "named tunnel over h2mux",
			protocol:          "h2mux",
			expectedProtocol:  H2mux,
			namedTunnelConfig: testNamedTunnelConfig,
		},
		{
			name:              "named tunnel over http2",
			protocol:          "http2",
			expectedProtocol:  HTTP2,
			hasFallback:       true,
			expectedFallback:  H2mux,
			fetchFunc:         mockFetcher(0),
			namedTunnelConfig: testNamedTunnelConfig,
		},
		{
			name:              "named tunnel http2 disabled",
			protocol:          "http2",
			expectedProtocol:  H2mux,
			fetchFunc:         mockFetcher(-1),
			namedTunnelConfig: testNamedTunnelConfig,
		},
		{
			name:              "named tunnel auto all http2 disabled",
			protocol:          "auto",
			expectedProtocol:  H2mux,
			fetchFunc:         mockFetcher(-1),
			namedTunnelConfig: testNamedTunnelConfig,
		},
		{
			name:              "named tunnel auto to h2mux",
			protocol:          "auto",
			expectedProtocol:  H2mux,
			fetchFunc:         mockFetcher(0),
			namedTunnelConfig: testNamedTunnelConfig,
		},
		{
			name:              "named tunnel auto to http2",
			protocol:          "auto",
			expectedProtocol:  HTTP2,
			hasFallback:       true,
			expectedFallback:  H2mux,
			fetchFunc:         mockFetcher(100),
			namedTunnelConfig: testNamedTunnelConfig,
		},
		{
			name:               "warp routing requesting h2mux",
			protocol:           "h2mux",
			expectedProtocol:   HTTP2,
			hasFallback:        false,
			expectedFallback:   H2mux,
			fetchFunc:          mockFetcher(100),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelConfig,
		},
		{
			name:               "warp routing http2",
			protocol:           "http2",
			expectedProtocol:   HTTP2,
			hasFallback:        false,
			expectedFallback:   H2mux,
			fetchFunc:          mockFetcher(100),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelConfig,
		},
		{
			name:               "warp routing auto",
			protocol:           "auto",
			expectedProtocol:   HTTP2,
			hasFallback:        false,
			expectedFallback:   H2mux,
			fetchFunc:          mockFetcher(100),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelConfig,
		},
		{
			// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
			name:             "classic tunnel unknown protocol",
			protocol:         "unknown",
			expectedProtocol: H2mux,
		},
		{
			name:              "named tunnel unknown protocol",
			protocol:          "unknown",
			fetchFunc:         mockFetcher(100),
			namedTunnelConfig: testNamedTunnelConfig,
			wantErr:           true,
		},
		{
			name:              "named tunnel fetch error",
			protocol:          "unknown",
			fetchFunc:         mockFetcherWithError(),
			namedTunnelConfig: testNamedTunnelConfig,
			expectedProtocol:  HTTP2,
			wantErr:           false,
		},
	}

	for _, test := range tests {
		selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
		if test.wantErr {
			assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
		} else {
			assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
			assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
			fallback, ok := selector.Fallback()
			assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
			if test.hasFallback {
				assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
			}
		}
	}
}

func TestAutoProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
	assert.NoError(t, err)
	assert.Equal(t, H2mux, selector.Current())

	fetcher.percentage = 100
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = 0
	assert.Equal(t, H2mux, selector.Current())

	fetcher.percentage = 100
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.err = fmt.Errorf("failed to fetch")
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = -1
	fetcher.err = nil
	assert.Equal(t, H2mux, selector.Current())

	fetcher.percentage = 0
	assert.Equal(t, H2mux, selector.Current())

	fetcher.percentage = 100
	assert.Equal(t, HTTP2, selector.Current())
}

func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
	assert.NoError(t, err)
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = 100
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = 0
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.err = fmt.Errorf("failed to fetch")
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = -1
	fetcher.err = nil
	assert.Equal(t, H2mux, selector.Current())

	fetcher.percentage = 0
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = 100
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = -1
	assert.Equal(t, H2mux, selector.Current())
}

func TestProtocolSelectorRefreshTTL(t *testing.T) {
	fetcher := dynamicMockFetcher{percentage: 100}
	selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
	assert.NoError(t, err)
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.percentage = 0
	assert.Equal(t, HTTP2, selector.Current())
}