package quic

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// testCloseIdleAfterHint is the idle-timeout hint the RPC tests pass to
// RegisterUdpSession and that mockSessionRPCServer is configured to expect.
const testCloseIdleAfterHint = 2 * time.Minute

// TestConnectRequestData round-trips connect-request data through an in-memory
// buffer: the client side writes it, the server side detects the protocol
// signature and reads it back, and hostname, connection type and metadata must
// all survive unchanged.
func TestConnectRequestData(t *testing.T) {
	tests := []struct {
		name           string
		hostname       string
		connectionType ConnectionType
		metadata       []Metadata
	}{
		{
			name:           "Signature verified and request metadata is unmarshaled and read correctly",
			hostname:       "tunnel.com",
			connectionType: ConnectionTypeHTTP,
			metadata:       []Metadata{{Key: "key", Val: "1234"}},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// An in-memory buffer plays the role of the stream; noopCloser
			// supplies a Close that does nothing.
			buf := new(bytes.Buffer)

			writer := RequestClientStream{noopCloser{buf}}
			require.NoError(t, writer.WriteConnectRequestData(tc.hostname, tc.connectionType, tc.metadata...))

			protocol, err := DetermineProtocol(buf)
			require.NoError(t, err)
			reader, err := NewRequestServerStream(noopCloser{buf}, protocol)
			require.NoError(t, err)

			got, err := reader.ReadConnectRequestData()
			require.NoError(t, err)

			assert.Equal(t, tc.metadata, got.Metadata)
			assert.Equal(t, tc.hostname, got.Dest)
			assert.Equal(t, tc.connectionType, got.Type)
		})
	}
}

// TestConnectResponseMeta round-trips connect-response data through an
// in-memory buffer. When the server reports an error, the error text is
// expected to arrive and the metadata to be dropped; otherwise the metadata
// must arrive unchanged together with an empty error.
func TestConnectResponseMeta(t *testing.T) {
	var tests = []struct {
		name     string
		err      error
		metadata []Metadata
	}{
		{
			name: "Signature verified and response metadata is unmarshaled and read correctly",
			metadata: []Metadata{
				{
					Key: "key",
					Val: "1234",
				},
			},
		},
		{
			name: "If error is not empty, other fields should be blank",
			err:  errors.New("something happened"),
			metadata: []Metadata{
				{
					Key: "key",
					Val: "1234",
				},
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			b := &bytes.Buffer{}
			reqServerStream := RequestServerStream{noopCloser{b}}
			err := reqServerStream.WriteConnectResponseData(test.err, test.metadata...)
			require.NoError(t, err)

			reqClientStream := RequestClientStream{noopCloser{b}}
			respMeta, err := reqClientStream.ReadConnectResponseData()
			require.NoError(t, err)

			// Branch on the input, not on the decoded output. Otherwise a response
			// that silently lost its error would be checked against the success
			// expectations and could pass.
			if test.err == nil {
				assert.Empty(t, respMeta.Error)
				assert.Equal(t, test.metadata, respMeta.Metadata)
			} else {
				assert.Equal(t, test.err.Error(), respMeta.Error)
				assert.Empty(t, respMeta.Metadata)
			}
		})
	}
}

// TestUnregisterUdpSession runs a real RPC client against mockSessionRPCServer
// over in-memory pipes. Registration must succeed. UnregisterUdpSession must
// then fail because the client's RPC timeout (test.timeout, passed to
// NewRPCClientStream) is set low enough to always trigger.
func TestUnregisterUdpSession(t *testing.T) {
	unregisterMessage := "closed by eyeball"

	var tests = []struct {
		name             string
		sessionRPCServer mockSessionRPCServer
		timeout          time.Duration
	}{

		{
			name: "UnregisterUdpSessionTimesout if the RPC server does not respond",
			sessionRPCServer: mockSessionRPCServer{
				sessionID:         uuid.New(),
				dstIP:             net.IP{172, 16, 0, 1},
				dstPort:           8000,
				closeIdleAfter:    testCloseIdleAfterHint,
				unregisterMessage: unregisterMessage,
				traceContext:      "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
			},
			// very very low value so we trigger the timeout every time.
			timeout: time.Nanosecond,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			logger := zerolog.Nop()
			clientStream, serverStream := newMockRPCStreams()
			serverDone := make(chan struct{})
			go func() {
				defer close(serverDone)
				defer serverStream.Close()
				// require must not be used off the test goroutine, so bail out
				// manually rather than calling Serve on a nil stream.
				protocol, err := DetermineProtocol(serverStream)
				if !assert.NoError(t, err) {
					return
				}
				rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
				if !assert.NoError(t, err) {
					return
				}
				assert.NoError(t, rpcServerStream.Serve(test.sessionRPCServer, nil, &logger))
			}()
			// Registered first so it runs last. Closing the client end of the
			// pipes unblocks the server goroutine even if a require below aborts
			// the test early. Waiting keeps the goroutine's assertions inside
			// this test's lifetime.
			defer func() {
				_ = clientStream.Close()
				<-serverDone
			}()

			rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger)
			require.NoError(t, err)
			defer rpcClientStream.Close()

			reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
			// reg is dereferenced below, so a transport error must stop the test.
			require.NoError(t, err)
			assert.NoError(t, reg.Err)

			assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
		})
	}

}

// TestRegisterUdpSession runs a real RPC client against mockSessionRPCServer
// over in-memory pipes. It checks both the accepted path (matching session ID)
// and the rejected path (unknown session ID) for registration and
// unregistration, with and without a trace context.
func TestRegisterUdpSession(t *testing.T) {
	unregisterMessage := "closed by eyeball"

	var tests = []struct {
		name             string
		sessionRPCServer mockSessionRPCServer
	}{
		{
			name: "RegisterUdpSession (no trace context)",
			sessionRPCServer: mockSessionRPCServer{
				sessionID:         uuid.New(),
				dstIP:             net.IP{172, 16, 0, 1},
				dstPort:           8000,
				closeIdleAfter:    testCloseIdleAfterHint,
				unregisterMessage: unregisterMessage,
				traceContext:      "",
			},
		},
		{
			name: "RegisterUdpSession (with trace context)",
			sessionRPCServer: mockSessionRPCServer{
				sessionID:         uuid.New(),
				dstIP:             net.IP{172, 16, 0, 1},
				dstPort:           8000,
				closeIdleAfter:    testCloseIdleAfterHint,
				unregisterMessage: unregisterMessage,
				traceContext:      "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			logger := zerolog.Nop()
			clientStream, serverStream := newMockRPCStreams()
			serverDone := make(chan struct{})
			go func() {
				defer close(serverDone)
				defer serverStream.Close()
				// require must not be used off the test goroutine, so bail out
				// manually rather than calling Serve on a nil stream.
				protocol, err := DetermineProtocol(serverStream)
				if !assert.NoError(t, err) {
					return
				}
				rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
				if !assert.NoError(t, err) {
					return
				}
				assert.NoError(t, rpcServerStream.Serve(test.sessionRPCServer, nil, &logger))
			}()
			// Registered first so it runs last. Closing the client end of the
			// pipes unblocks the server goroutine even if a require below aborts
			// the test early. Waiting keeps the goroutine's assertions inside
			// this test's lifetime.
			defer func() {
				_ = clientStream.Close()
				<-serverDone
			}()

			rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, 5*time.Second, &logger)
			require.NoError(t, err)
			defer rpcClientStream.Close()

			reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
			// reg is dereferenced below, so a transport error must stop the test.
			require.NoError(t, err)
			assert.NoError(t, reg.Err)

			// Different sessionID, the RPC server should reject the registration
			reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
			require.NoError(t, err)
			assert.Error(t, reg.Err)

			assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))

			// Different sessionID, the RPC server should reject the unregistration
			assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
		})
	}
}

// TestManageConfiguration sends one UpdateConfiguration call through a real RPC
// client to mockConfigRPCServer over in-memory pipes. It expects the server to
// acknowledge the same version with no error.
func TestManageConfiguration(t *testing.T) {
	var (
		version int32 = 168
		config        = []byte(t.Name())
	)
	clientStream, serverStream := newMockRPCStreams()

	configRPCServer := mockConfigRPCServer{
		version: version,
		config:  config,
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	logger := zerolog.Nop()
	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		defer serverStream.Close()
		// require must not be used off the test goroutine, so bail out
		// manually rather than calling Serve on a nil stream.
		protocol, err := DetermineProtocol(serverStream)
		if !assert.NoError(t, err) {
			return
		}
		rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
		if !assert.NoError(t, err) {
			return
		}
		assert.NoError(t, rpcServerStream.Serve(nil, configRPCServer, &logger))
	}()
	// Closing the client end of the pipes unblocks the server goroutine even if
	// a require below aborts the test early. Waiting keeps the goroutine's
	// assertions inside this test's lifetime.
	defer func() {
		_ = clientStream.Close()
		<-serverDone
	}()

	rpcClientStream, err := NewRPCClientStream(ctx, clientStream, 5*time.Second, &logger)
	require.NoError(t, err)
	defer rpcClientStream.Close()

	result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
	// result is dereferenced below, so a transport error must stop the test.
	require.NoError(t, err)

	require.Equal(t, version, result.LastAppliedVersion)
	require.NoError(t, result.Err)
}

// mockSessionRPCServer is a test double passed to Serve on the RPC server
// stream. It checks every incoming RegisterUdpSession / UnregisterUdpSession
// call against the values stored here and rejects any mismatch with an error.
type mockSessionRPCServer struct {
	sessionID         uuid.UUID     // only session ID accepted by Register/Unregister
	dstIP             net.IP        // expected destination IP on registration
	dstPort           uint16        // expected destination port on registration
	closeIdleAfter    time.Duration // expected idle-timeout hint on registration
	unregisterMessage string        // expected message passed to UnregisterUdpSession
	traceContext      string        // expected trace context on registration ("" is a valid value)
}

// RegisterUdpSession accepts a registration only when every argument matches
// the values the mock was built with. On the first mismatch it returns a nil
// response and an error describing the expected and received values. Strings
// are quoted with %q so an empty trace context is visible in failure output.
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) {
	if s.sessionID != sessionID {
		return nil, fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
	}
	if !s.dstIP.Equal(dstIP) {
		return nil, fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
	}
	if s.dstPort != dstPort {
		return nil, fmt.Errorf("expect destination port %d, got %d", s.dstPort, dstPort)
	}
	if s.closeIdleAfter != closeIdleAfter {
		// %v renders a Duration as e.g. "2m0s" instead of raw nanoseconds.
		return nil, fmt.Errorf("expect closeIdleAfter %v, got %v", s.closeIdleAfter, closeIdleAfter)
	}
	if s.traceContext != traceContext {
		return nil, fmt.Errorf("expect traceContext %q, got %q", s.traceContext, traceContext)
	}
	return &tunnelpogs.RegisterUdpSessionResponse{}, nil
}

// UnregisterUdpSession succeeds only for the configured session ID and
// unregister message; otherwise it returns an error describing the mismatch.
// The message is quoted with %q so empty or whitespace-only values are
// visible in failure output.
func (s mockSessionRPCServer) UnregisterUdpSession(_ context.Context, sessionID uuid.UUID, message string) error {
	if s.sessionID != sessionID {
		return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
	}
	if s.unregisterMessage != message {
		return fmt.Errorf("expect unregister message %q, got %q", s.unregisterMessage, message)
	}
	return nil
}

// mockConfigRPCServer is a test double passed to Serve on the RPC server
// stream. Its UpdateConfiguration reports success only for the exact version
// and config bytes stored here.
type mockConfigRPCServer struct {
	version int32  // expected configuration version
	config  []byte // expected configuration payload
}

// UpdateConfiguration acknowledges the update (LastAppliedVersion = version)
// only when both version and config match the mock's expectations. Otherwise
// it returns a response whose Err describes the mismatch. The config is
// formatted with %q so the payload reads as text rather than a list of bytes.
func (s mockConfigRPCServer) UpdateConfiguration(_ context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
	if s.version != version {
		return &tunnelpogs.UpdateConfigurationResponse{
			Err: fmt.Errorf("expect version %d, got %d", s.version, version),
		}
	}
	if !bytes.Equal(s.config, config) {
		return &tunnelpogs.UpdateConfigurationResponse{
			Err: fmt.Errorf("expect config %q, got %q", s.config, config),
		}
	}
	return &tunnelpogs.UpdateConfigurationResponse{LastAppliedVersion: version}
}

// mockRPCStream joins an independent read half and write half into a single
// io.ReadWriteCloser, so a pair of io.Pipe()s can act as a bidirectional stream
// (see newMockRPCStreams). Both embedded types expose Close, so the type
// defines its own Close method to resolve the ambiguity.
// Field order matters: newMockRPCStreams builds values with positional literals.
type mockRPCStream struct {
	io.ReadCloser
	io.WriteCloser
}

// newMockRPCStreams returns two connected in-memory streams: whatever one side
// writes, the other side reads. It is built from two unbuffered io.Pipe()s,
// one per direction.
func newMockRPCStreams() (client io.ReadWriteCloser, server io.ReadWriteCloser) {
	// Bytes written by the server surface on the client's read side.
	serverToClientR, serverToClientW := io.Pipe()
	// Bytes written by the client surface on the server's read side.
	clientToServerR, clientToServerW := io.Pipe()

	return mockRPCStream{ReadCloser: serverToClientR, WriteCloser: clientToServerW},
		mockRPCStream{ReadCloser: clientToServerR, WriteCloser: serverToClientW}
}

// Close closes the read half and then the write half, and always reports
// success. Errors from the individual halves are deliberately discarded.
func (s mockRPCStream) Close() error {
	for _, half := range []io.Closer{s.ReadCloser, s.WriteCloser} {
		_ = half.Close() // best effort; this test double has no use for the error
	}
	return nil
}

// noopCloser adapts an io.ReadWriter (a *bytes.Buffer in the connect
// request/response tests) into an io.ReadWriteCloser whose Close does nothing.
// It is built with positional literals (noopCloser{b}), so keep the single
// embedded field.
type noopCloser struct {
	io.ReadWriter
}

// Close satisfies io.Closer without touching the wrapped ReadWriter; it always
// returns nil.
func (noopCloser) Close() error { return nil }