package supervisor

import (
	"crypto/tls"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/cloudflare/cloudflared/features"
)

// TestCurvePreferences verifies that curvePreference yields the expected
// ordered list of TLS curve IDs for each combination of post-quantum mode,
// FIPS flag and pre-existing curve list exercised in the table below.
func TestCurvePreferences(t *testing.T) {
	t.Parallel()

	// curveCase describes one invocation of curvePreference together with
	// the curve list the call is expected to return.
	type curveCase struct {
		name  string
		mode  features.PostQuantumMode
		fips  bool
		input []tls.CurveID
		want  []tls.CurveID
	}

	cases := []curveCase{
		// FIPS enabled.
		{
			name:  "FIPS with Prefer PQ",
			mode:  features.PostQuantumPrefer,
			fips:  true,
			input: []tls.CurveID{tls.CurveP384},
			want:  []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256},
		},
		{
			name:  "FIPS with Strict PQ",
			mode:  features.PostQuantumStrict,
			fips:  true,
			input: []tls.CurveID{tls.CurveP256, tls.CurveP384},
			want:  []tls.CurveID{P256Kyber768Draft00PQKex},
		},
		{
			name:  "FIPS with Prefer PQ - no duplicates",
			mode:  features.PostQuantumPrefer,
			fips:  true,
			input: []tls.CurveID{tls.CurveP256},
			want:  []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256},
		},

		// FIPS disabled.
		{
			name:  "Non FIPS with Prefer PQ",
			mode:  features.PostQuantumPrefer,
			fips:  false,
			input: []tls.CurveID{tls.CurveP256},
			want:  []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
		},
		{
			name:  "Non FIPS with Prefer PQ - no duplicates",
			mode:  features.PostQuantumPrefer,
			fips:  false,
			input: []tls.CurveID{X25519Kyber768Draft00PQKex, tls.CurveP256},
			want:  []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
		},
		{
			name:  "Non FIPS with Prefer PQ - correct preference order",
			mode:  features.PostQuantumPrefer,
			fips:  false,
			input: []tls.CurveID{tls.CurveP256, X25519Kyber768Draft00PQKex},
			want:  []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
		},
		{
			name:  "Non FIPS with Strict PQ",
			mode:  features.PostQuantumStrict,
			fips:  false,
			input: []tls.CurveID{tls.CurveP256},
			want:  []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex},
		},
	}

	for i := range cases {
		// Take a per-iteration copy so each parallel subtest closure
		// works on its own case value.
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got, err := curvePreference(tc.mode, tc.fips, tc.input)
			// A failing call aborts the subtest before the comparison.
			require.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}