package connection

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"
	"time"

	"github.com/gobwas/ws/wsutil"
	"github.com/google/uuid"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/http2"

	"github.com/cloudflare/cloudflared/tunnelrpc"
	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// testTransport is the HTTP/2 client transport that several tests in this file
// use to open a client connection over the edge side of the in-memory pipe.
// Its zero value is used as-is.
var testTransport http2.Transport

// newTestHTTP2Connection builds an HTTP2Connection on one end of an in-memory
// net.Pipe and returns it together with the other end of the pipe, which the
// tests use as the edge side of the connection.
//
// The connection uses connection index 0, a no-op logger, the package-level
// testOrchestrator, and a control stream built with mostly nil collaborators;
// tests that need a working control stream replace controlStreamHandler.
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
	edgeSide, cfdSide := net.Pipe()

	connIndex := uint8(0)
	nopLog := zerolog.Nop()
	observer := NewObserver(&nopLog, &nopLog)

	stream := NewControlStream(
		observer,
		mockConnectedFuse{},
		&TunnelProperties{},
		connIndex,
		nil,
		nil,
		time.Second,
		nil,
		time.Second,
		HTTP2,
	)

	conn := NewHTTP2Connection(
		cfdSide,
		// OriginProxy is set in testConfigManager
		testOrchestrator,
		&pogs.ConnectionOptions{},
		observer,
		connIndex,
		stream,
		&nopLog,
	)
	return conn, edgeSide
}

// TestHTTP2ConfigurationSet sends a PUT request carrying the
// ConfigurationUpdate upgrade header and a version-2 configuration payload
// over the test HTTP/2 connection, and expects a 200 response whose body
// reports version 2 as the last applied version with no error.
func TestHTTP2ConfigurationSet(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	ctx, cancel := context.WithCancel(context.Background())
	// Make sure the serving goroutine is told to stop even when a require
	// below aborts the test early; calling cancel twice is harmless.
	defer cancel()
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		http2Conn.Serve(ctx)
	}()

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	// The URL is constant, so no formatting call is needed.
	const endpoint = "http://localhost:8080/ok"
	reqBody := []byte(`{
"version": 2, 
"config": {"warp-routing": {"enabled": true},  "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
`)
	reader := bytes.NewReader(reqBody)
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)

	resp, err := edgeHTTP2Conn.RoundTrip(req)
	require.NoError(t, err)
	// The caller of RoundTrip owns the response body and must close it.
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)
	bdy, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))

	// Stop serving and wait for the serving goroutine to exit.
	cancel()
	wg.Wait()
}

// TestServeHTTP issues a series of GET requests over a single test HTTP/2
// connection and checks, for each case, the response status, the response
// body (when one is expected) and the ResponseMetaHeader value: proxy errors
// must carry responseMetaHeaderCfd, every other response
// responseMetaHeaderOrigin.
//
// Each case runs as a named subtest so a failure is reported against the case
// that produced it, and every response body is closed before the next request.
func TestServeHTTP(t *testing.T) {
	tests := []testRequest{
		{
			name:           "ok",
			endpoint:       "ok",
			expectedStatus: http.StatusOK,
			expectedBody:   []byte(http.StatusText(http.StatusOK)),
		},
		{
			name:           "large_file",
			endpoint:       "large_file",
			expectedStatus: http.StatusOK,
			expectedBody:   testLargeResp,
		},
		{
			name:           "Bad request",
			endpoint:       "400",
			expectedStatus: http.StatusBadRequest,
			expectedBody:   []byte(http.StatusText(http.StatusBadRequest)),
		},
		{
			name:           "Internal server error",
			endpoint:       "500",
			expectedStatus: http.StatusInternalServerError,
			expectedBody:   []byte(http.StatusText(http.StatusInternalServerError)),
		},
		{
			name:           "Proxy error",
			endpoint:       "error",
			expectedStatus: http.StatusBadGateway,
			expectedBody:   nil,
			isProxyError:   true,
		},
	}

	http2Conn, edgeConn := newTestHTTP2Connection()

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		http2Conn.Serve(ctx)
	}()

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	for _, test := range tests {
		// The subtest runs synchronously (no t.Parallel), so capturing the
		// loop variable in the closure is safe on every Go version.
		t.Run(test.name, func(t *testing.T) {
			endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
			req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
			require.NoError(t, err)

			resp, err := edgeHTTP2Conn.RoundTrip(req)
			require.NoError(t, err)
			// The caller of RoundTrip owns the body; close it when this case
			// finishes rather than leaving the stream open.
			defer resp.Body.Close()

			require.Equal(t, test.expectedStatus, resp.StatusCode)
			if test.expectedBody != nil {
				respBody, err := io.ReadAll(resp.Body)
				require.NoError(t, err)
				require.Equal(t, test.expectedBody, respBody)
			}
			if test.isProxyError {
				require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
			} else {
				require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
			}
		})
	}

	// Stop serving and wait for the serving goroutine to exit.
	cancel()
	wg.Wait()
}

// mockNamedTunnelRPCClient is a test double for the registration RPC client.
// It reports its calls to the test through two signalling channels.
type mockNamedTunnelRPCClient struct {
	// shouldFail, when non-nil, is the error RegisterConnection returns.
	shouldFail error
	// registered is closed by a successful RegisterConnection;
	// unregistered is closed by GracefulShutdown.
	registered, unregistered chan struct{}
}

// SendLocalConfiguration ignores both arguments and always reports success.
func (mockNamedTunnelRPCClient) SendLocalConfiguration(_ context.Context, _ []byte) error {
	return nil
}

// RegisterConnection simulates a registration attempt and ignores all of its
// arguments. When shouldFail is set it returns that error without signalling.
// Otherwise it closes the registered channel and returns fixed connection
// details (location "LIS", not remotely managed) with a newly generated UUID.
func (mc mockNamedTunnelRPCClient) RegisterConnection(
	_ context.Context,
	_ pogs.TunnelAuth,
	_ uuid.UUID,
	_ *pogs.ConnectionOptions,
	_ uint8,
	_ net.IP,
) (*pogs.ConnectionDetails, error) {
	if failure := mc.shouldFail; failure != nil {
		return nil, failure
	}

	// Signal the waiting test. NOTE: a second successful call on the same
	// channel would panic, because closing a closed channel panics.
	close(mc.registered)

	details := &pogs.ConnectionDetails{
		Location:                "LIS",
		UUID:                    uuid.New(),
		TunnelIsRemotelyManaged: false,
	}
	return details, nil
}

// GracefulShutdown ignores its arguments, closes the unregistered channel to
// signal the waiting test, and always reports success.
// NOTE: calling it twice on the same channel would panic on the second close.
func (mc mockNamedTunnelRPCClient) GracefulShutdown(_ context.Context, _ time.Duration) error {
	close(mc.unregistered)
	return nil
}

// Close does nothing.
func (mockNamedTunnelRPCClient) Close() {}

// mockRPCClientFactory holds the settings that newMockRPCClient copies into
// every mockNamedTunnelRPCClient it creates.
type mockRPCClientFactory struct {
	// shouldFail is handed to each created client as its shouldFail error.
	shouldFail error
	// registered and unregistered are shared with each created client, so the
	// test can wait on them directly.
	registered, unregistered chan struct{}
}

// newMockRPCClient ignores its arguments and returns a new
// mockNamedTunnelRPCClient that carries the factory's failure setting and
// shares its signalling channels.
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient {
	client := new(mockNamedTunnelRPCClient)
	client.shouldFail = mf.shouldFail
	client.registered = mf.registered
	client.unregistered = mf.unregistered
	return client
}

// wsRespWriter is a response writer for websocket tests. It embeds a
// ResponseRecorder and pairs it with an in-memory pipe: its own Write method
// sends data into writePipe, and RespBody exposes readPipe for reading.
//
// Field order matters: newWSRespWriter may build this struct positionally.
type wsRespWriter struct {
	*httptest.ResponseRecorder
	readPipe  *io.PipeReader
	writePipe *io.PipeWriter
	// closed is set by close(); panicked is set when Write is called after
	// close() has been called.
	closed, panicked bool
}

// newWSRespWriter returns a wsRespWriter backed by a new in-memory pipe and a
// new response recorder. The closed and panicked flags start out false.
func newWSRespWriter() *wsRespWriter {
	pipeReader, pipeWriter := io.Pipe()
	return &wsRespWriter{
		ResponseRecorder: httptest.NewRecorder(),
		readPipe:         pipeReader,
		writePipe:        pipeWriter,
	}
}

// nowriter wraps an io.Reader so it can be passed where an io.ReadWriter is
// required, while rejecting every write.
type nowriter struct {
	io.Reader
}

// Write never writes anything: it always returns 0 and an error.
func (nowriter) Write([]byte) (int, error) {
	err := fmt.Errorf("writer not implemented")
	return 0, err
}

// RespBody returns a view of the response body: reads come from the read end
// of the writer's pipe, and any write through the returned value fails.
func (w *wsRespWriter) RespBody() io.ReadWriter {
	body := nowriter{Reader: w.readPipe}
	return body
}

// Write shadows the embedded recorder's Write. While the writer is open the
// data is sent into the pipe. After close() has been called it records the
// late write by setting panicked and returns an error instead.
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
	if !w.closed {
		return w.writePipe.Write(data)
	}
	w.panicked = true
	return 0, errors.New("wsRespWriter panicked")
}

// close marks the writer as closed. It does not close the underlying pipe;
// it only makes every later Write fail and set the panicked flag.
func (w *wsRespWriter) close() {
	w.closed = true
}

// TestServeWS calls HTTP2Connection.ServeHTTP directly with a
// websocket-upgrade request. It checks that a binary frame written by the
// client comes back with the same payload, that the recorded status is 200
// with the origin response-meta header, and that nothing was written to the
// response writer after ServeHTTP returned.
func TestServeWS(t *testing.T) {
	// The edge side of the pipe is unused: ServeHTTP is invoked directly
	// instead of going through an HTTP/2 client connection.
	http2Conn, _ := newTestHTTP2Connection()

	ctx, cancel := context.WithCancel(context.Background())

	respWriter := newWSRespWriter()
	// Request body pipe: the test writes frames into writePipe and the
	// request body reads them from readPipe.
	readPipe, writePipe := io.Pipe()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		http2Conn.ServeHTTP(respWriter, req)
		// From here on, any Write to respWriter sets its panicked flag.
		respWriter.close()
	}()

	data := []byte("test websocket")
	err = wsutil.WriteClientBinary(writePipe, data)
	require.NoError(t, err)

	respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
	require.NoError(t, err)
	require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))

	// Cancel the request context; the test then waits on serveDone below for
	// ServeHTTP to return.
	cancel()
	resp := respWriter.Result()
	// http2RespWriter should rewrite status 101 to 200
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))

	<-serveDone
	// panicked is only set when Write is called after respWriter.close().
	require.False(t, respWriter.panicked)
}

// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns.
//
// It opens 100 concurrent websocket-upgrade requests over one HTTP/2
// connection. For each request a writer goroutine keeps sending binary frames
// until the request context is cancelled, which happens about 100ms after the
// response headers arrive.
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
	cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup

	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		cfdHTTP2Conn.Serve(ctx)
	}()

	edgeTransport := http2.Transport{}
	edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
	require.NoError(t, err)
	message := []byte(t.Name())

	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			readPipe, writePipe := io.Pipe()
			reqCtx, reqCancel := context.WithCancel(ctx)
			// Release the request context on every exit path, including the
			// early returns below; calling it twice is harmless.
			defer reqCancel()

			// assert is used instead of require inside this goroutine:
			// require calls t.FailNow, which must only be called from the
			// goroutine running the test function.
			req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
			if !assert.NoError(t, err) {
				return
			}
			req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)

			resp, err := edgeHTTP2Conn.RoundTrip(req)
			if !assert.NoError(t, err) {
				return
			}
			// The caller of RoundTrip owns the response body and must close it.
			defer resp.Body.Close()
			// http2RespWriter should rewrite status 101 to 200
			if !assert.Equal(t, http.StatusOK, resp.StatusCode) {
				return
			}

			// The outer goroutine has not called Done yet, so adding to the
			// WaitGroup here cannot race with wg.Wait returning.
			wg.Add(1)
			go func() {
				defer wg.Done()
				for {
					select {
					case <-reqCtx.Done():
						return
					default:
					}
					// Write errors are expected once the request is
					// cancelled, so they are ignored on purpose.
					_ = wsutil.WriteClientBinary(writePipe, message)
				}
			}()

			time.Sleep(time.Millisecond * 100)
			reqCancel()
		}()
	}

	wg.Wait()
	cancel()
	<-serverDone
}

// TestServeControlStream installs a control stream backed by the mock RPC
// client factory and sends a control-stream upgrade request. It waits for the
// mock to report registration, cancels the context, waits for the mock to
// report unregistration, and checks that the connection was not marked as
// stopped gracefully.
func TestServeControlStream(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	rpcClientFactory := mockRPCClientFactory{
		registered:   make(chan struct{}),
		unregistered: make(chan struct{}),
	}

	obs := NewObserver(&log, &log)
	controlStream := NewControlStream(
		obs,
		mockConnectedFuse{},
		&TunnelProperties{},
		1,
		nil,
		rpcClientFactory.newMockRPCClient,
		1*time.Second,
		nil,
		1*time.Second,
		HTTP2,
	)
	http2Conn.controlStreamHandler = controlStream

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		http2Conn.Serve(ctx)
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	wg.Add(1)
	go func() {
		defer wg.Done()
		// The round-trip result is deliberately discarded: the test observes
		// the control stream only through the mock's channels, and the
		// request context is cancelled below.
		_, _ = edgeHTTP2Conn.RoundTrip(req)
	}()

	// Closed by the mock client on a successful RegisterConnection.
	<-rpcClientFactory.registered
	cancel()
	// Closed by the mock client when GracefulShutdown is called.
	<-rpcClientFactory.unregistered
	assert.False(t, http2Conn.stoppedGracefully)

	wg.Wait()
}

// TestFailRegistration installs a control stream whose mock RPC client fails
// registration with errDuplicationConnection. It expects the control-stream
// upgrade request to be answered with 502 Bad Gateway and the connection to
// record a non-nil controlStreamErr.
func TestFailRegistration(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	rpcClientFactory := mockRPCClientFactory{
		shouldFail:   errDuplicationConnection,
		registered:   make(chan struct{}),
		unregistered: make(chan struct{}),
	}

	obs := NewObserver(&log, &log)
	controlStream := NewControlStream(
		obs,
		mockConnectedFuse{},
		&TunnelProperties{},
		http2Conn.connIndex,
		nil,
		rpcClientFactory.newMockRPCClient,
		1*time.Second,
		nil,
		1*time.Second,
		HTTP2,
	)
	http2Conn.controlStreamHandler = controlStream

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		http2Conn.Serve(ctx)
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)
	resp, err := edgeHTTP2Conn.RoundTrip(req)
	require.NoError(t, err)
	// The caller of RoundTrip owns the response body and must close it.
	defer resp.Body.Close()
	require.Equal(t, http.StatusBadGateway, resp.StatusCode)

	assert.NotNil(t, http2Conn.controlStreamErr)

	// Stop serving and wait for the serving goroutine to exit.
	cancel()
	wg.Wait()
}

// TestGracefulShutdownHTTP2 registers a control stream through the mock RPC
// client and then closes the shutdown channel handed to the control stream.
// It expects the mock to report unregistration within one second, the control
// stream to report IsStopped, and the observer sink to have seen an
// Unregistering event for this connection index.
func TestGracefulShutdownHTTP2(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	rpcClientFactory := mockRPCClientFactory{
		registered:   make(chan struct{}),
		unregistered: make(chan struct{}),
	}
	events := &eventCollectorSink{}

	shutdownC := make(chan struct{})
	obs := NewObserver(&log, &log)
	obs.RegisterSink(events)
	controlStream := NewControlStream(
		obs,
		mockConnectedFuse{},
		&TunnelProperties{},
		http2Conn.connIndex,
		nil,
		rpcClientFactory.newMockRPCClient,
		1*time.Second,
		shutdownC,
		1*time.Second,
		HTTP2,
	)

	http2Conn.controlStreamHandler = controlStream

	ctx, cancel := context.WithCancel(context.Background())
	// The t.Fatal paths below exit the test before the explicit cancel; this
	// makes sure the context is cancelled then too. Calling it twice is harmless.
	defer cancel()
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		http2Conn.Serve(ctx)
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	wg.Add(1)
	go func() {
		defer wg.Done()
		_, _ = edgeHTTP2Conn.RoundTrip(req)
	}()

	// time.After is the one-shot timeout; time.Tick would start a repeating
	// ticker that is never stopped.
	select {
	case <-rpcClientFactory.registered:
		// ok
	case <-time.After(time.Second):
		t.Fatal("timeout out waiting for registration")
	}

	// signal graceful shutdown
	close(shutdownC)

	select {
	case <-rpcClientFactory.unregistered:
		// ok
	case <-time.After(time.Second):
		t.Fatal("timeout out waiting for unregistered signal")
	}
	assert.True(t, controlStream.IsStopped())

	cancel()
	wg.Wait()

	events.assertSawEvent(t, Event{
		Index:     http2Conn.connIndex,
		EventType: Unregistering,
	})
}

// benchmarkServeHTTP measures repeated GET round trips for the given test
// request over a single test HTTP/2 connection. Only the RoundTrip call is
// timed: the timer is started right before it and stopped right after, so the
// status check, body verification and body close are excluded.
func benchmarkServeHTTP(b *testing.B, test testRequest) {
	cfdConn, edgePipe := newTestHTTP2Connection()

	ctx, stop := context.WithCancel(context.Background())
	var serving sync.WaitGroup
	serving.Add(1)
	go func() {
		defer serving.Done()
		cfdConn.Serve(ctx)
	}()

	url := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
	// The same request value is reused for every iteration.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	require.NoError(b, err)

	client, err := testTransport.NewClientConn(edgePipe)
	require.NoError(b, err)

	// Discard the setup time accumulated so far.
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		b.StartTimer()
		resp, err := client.RoundTrip(req)
		b.StopTimer()

		require.NoError(b, err)
		require.Equal(b, test.expectedStatus, resp.StatusCode)
		if want := test.expectedBody; want != nil {
			got, err := io.ReadAll(resp.Body)
			require.NoError(b, err)
			require.Equal(b, want, got)
		}
		resp.Body.Close()
	}

	// Stop serving and wait for the serving goroutine to exit.
	stop()
	serving.Wait()
}

func BenchmarkServeHTTPSimple(b *testing.B) {
	test := testRequest{
		name:           "ok",
		endpoint:       "ok",
		expectedStatus: http.StatusOK,
		expectedBody:   []byte(http.StatusText(http.StatusOK)),
	}

	benchmarkServeHTTP(b, test)
}

func BenchmarkServeHTTPLargeFile(b *testing.B) {
	test := testRequest{
		name:           "large_file",
		endpoint:       "large_file",
		expectedStatus: http.StatusOK,
		expectedBody:   testLargeResp,
	}

	benchmarkServeHTTP(b, test)
}