package connection

import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/cloudflare/cloudflared/h2mux"
	"github.com/gobwas/ws/wsutil"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// testMuxerConfig is the MuxerConfig that newH2MuxConnection passes to
// NewH2muxConnection: a 5s heartbeat interval, at most 5 heartbeats, a 5s
// metrics update frequency and compression setting 0.
var testMuxerConfig = &MuxerConfig{
	MaxHeartbeats:      5,
	HeartbeatInterval:  5 * time.Second,
	MetricsUpdateFreq:  5 * time.Second,
	CompressionSetting: 0,
}

// newH2MuxConnection connects an h2muxConnection (origin side) to an edge
// h2mux.Muxer over an in-memory net.Pipe and returns both once the handshakes
// on both ends have completed. It fails t if either side returns an error.
//
// All assertions run on the caller's goroutine: require's FailNow must not be
// called from a goroutine other than the one running the test or benchmark.
func newH2MuxConnection(ctx context.Context, t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
	edgeConn, originConn := net.Pipe()

	type edgeHandshake struct {
		mux *h2mux.Muxer
		err error
	}
	// Buffered so the handshake goroutine can always deliver its result and
	// exit, even when the origin side fails and nobody ever receives.
	edgeResult := make(chan edgeHandshake, 1)
	go func() {
		edgeMuxConfig := h2mux.MuxerConfig{
			Logger: testObserver,
		}
		edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
		if err != nil {
			// Close our end so the origin side, which may be blocked on the
			// other end of the pipe, sees an error instead of hanging.
			_ = edgeConn.Close()
		}
		edgeResult <- edgeHandshake{mux: edgeMux, err: err}
	}()

	const connIndex uint8 = 0
	// The third result of NewH2muxConnection is not needed by these tests.
	h2muxConn, err, _ := NewH2muxConnection(ctx, testConfig, testMuxerConfig, originConn, connIndex, testObserver)
	if err != nil {
		// Unblock the edge handshake so its goroutine does not leak.
		_ = originConn.Close()
	}
	require.NoError(t, err)

	edge := <-edgeResult
	require.NoError(t, edge.err)
	return h2muxConn, edge.mux
}

// TestServeStreamHTTP sends a series of plain HTTP requests through an h2mux
// connection. For each request it checks the ":status" header and the
// response-meta header, and for responses that are not proxy errors it also
// checks the body.
func TestServeStreamHTTP(t *testing.T) {
	tests := []testRequest{
		{
			name:           "ok",
			endpoint:       "/ok",
			expectedStatus: http.StatusOK,
			expectedBody:   []byte(http.StatusText(http.StatusOK)),
		},
		{
			name:           "large_file",
			endpoint:       "/large_file",
			expectedStatus: http.StatusOK,
			expectedBody:   testLargeResp,
		},
		{
			name:           "Bad request",
			endpoint:       "/400",
			expectedStatus: http.StatusBadRequest,
			expectedBody:   []byte(http.StatusText(http.StatusBadRequest)),
		},
		{
			name:           "Internal server error",
			endpoint:       "/500",
			expectedStatus: http.StatusInternalServerError,
			expectedBody:   []byte(http.StatusText(http.StatusInternalServerError)),
		},
		{
			name:           "Proxy error",
			endpoint:       "/error",
			expectedStatus: http.StatusBadGateway,
			expectedBody:   nil,
			isProxyError:   true,
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	h2muxConn, edgeMux := newH2MuxConnection(ctx, t)

	var wg sync.WaitGroup
	// Stop the serving goroutines and wait for them however the test exits.
	// This keeps them from reporting on t after the test has completed.
	defer func() {
		cancel()
		wg.Wait()
	}()
	wg.Add(2)
	go func() {
		defer wg.Done()
		edgeMux.Serve(ctx)
	}()
	go func() {
		defer wg.Done()
		// assert, not require: FailNow may only be called from the test goroutine.
		err := h2muxConn.serveMuxer(ctx)
		assert.Error(t, err)
	}()

	for _, test := range tests {
		// t.Run blocks until the subtest finishes, so capturing test is safe.
		t.Run(test.name, func(t *testing.T) {
			headers := []h2mux.Header{
				{
					Name:  ":path",
					Value: test.endpoint,
				},
			}
			stream, err := edgeMux.OpenStream(ctx, headers, nil)
			require.NoError(t, err)
			require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))

			if test.isProxyError {
				assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderCfd))
				return
			}

			assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
			// One Read may return fewer bytes than requested (notably for the
			// large file), so keep reading until the buffer is full.
			body := make([]byte, len(test.expectedBody))
			_, err = io.ReadFull(stream, body)
			require.NoError(t, err)
			require.Equal(t, test.expectedBody, body)
		})
	}
}

// TestServeStreamWS opens a websocket-upgrade stream through an h2mux
// connection. It checks for a 101 status and the origin response-meta header,
// writes one client text frame, and expects the same payload back as a server
// text frame.
func TestServeStreamWS(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	h2muxConn, edgeMux := newH2MuxConnection(ctx, t)

	var wg sync.WaitGroup
	// Stop the serving goroutines and wait for them however the test exits.
	// This keeps them from reporting on t after the test has completed.
	defer func() {
		cancel()
		wg.Wait()
	}()
	wg.Add(2)
	go func() {
		defer wg.Done()
		edgeMux.Serve(ctx)
	}()
	go func() {
		defer wg.Done()
		// assert, not require: FailNow may only be called from the test goroutine.
		err := h2muxConn.serveMuxer(ctx)
		assert.Error(t, err)
	}()

	headers := []h2mux.Header{
		{
			Name:  ":path",
			Value: "/ws",
		},
		{
			Name:  "connection",
			Value: "upgrade",
		},
		{
			Name:  "upgrade",
			Value: "websocket",
		},
	}

	readPipe, writePipe := io.Pipe()
	// Close the request-body pipe on exit so any reader of readPipe sees EOF
	// instead of blocking forever. PipeWriter.Close always returns nil.
	defer writePipe.Close()
	stream, err := edgeMux.OpenStream(ctx, headers, readPipe)
	require.NoError(t, err)

	require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
	assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))

	data := []byte("test websocket")
	err = wsutil.WriteClientText(writePipe, data)
	require.NoError(t, err)

	respBody, err := wsutil.ReadServerText(stream)
	require.NoError(t, err)
	require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
}

// hasHeader reports whether any of stream's headers has exactly the given
// name and value. Both comparisons are case-sensitive.
func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
	for i := range stream.Headers {
		if h := stream.Headers[i]; h.Name == name && h.Value == val {
			return true
		}
	}
	return false
}

// benchmarkServeStreamHTTPSimple times opening a stream for test.endpoint over
// an h2mux connection and reading test.expectedBody back. Only OpenStream and
// the body read are timed. The status, response-meta header and body checks
// run with the timer stopped.
func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	h2muxConn, edgeMux := newH2MuxConnection(ctx, b)

	var wg sync.WaitGroup
	// Stop the serving goroutines and wait for them however the benchmark exits.
	defer func() {
		cancel()
		wg.Wait()
	}()
	wg.Add(2)
	go func() {
		defer wg.Done()
		edgeMux.Serve(ctx)
	}()
	go func() {
		defer wg.Done()
		// assert, not require: FailNow may only be called from the benchmark goroutine.
		err := h2muxConn.serveMuxer(ctx)
		assert.Error(b, err)
	}()

	headers := []h2mux.Header{
		{
			Name:  ":path",
			Value: test.endpoint,
		},
	}
	expectedStatus := strconv.Itoa(test.expectedStatus)

	body := make([]byte, len(test.expectedBody))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		b.StartTimer()
		stream, openStreamErr := edgeMux.OpenStream(ctx, headers, nil)
		if openStreamErr != nil {
			// stream is unusable: fail before dereferencing it.
			b.StopTimer()
			require.NoError(b, openStreamErr)
		}
		// One Read may return fewer bytes than requested for large bodies.
		_, readBodyErr := io.ReadFull(stream, body)
		b.StopTimer()

		assert.True(b, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
		require.True(b, hasHeader(stream, ":status", expectedStatus))
		require.NoError(b, readBodyErr)
		require.Equal(b, test.expectedBody, body)
	}
}

func BenchmarkServeStreamHTTPSimple(b *testing.B) {
	test := testRequest{
		name:           "ok",
		endpoint:       "/ok",
		expectedStatus: http.StatusOK,
		expectedBody:   []byte(http.StatusText(http.StatusOK)),
	}

	benchmarkServeStreamHTTPSimple(b, test)
}

func BenchmarkServeStreamHTTPLargeFile(b *testing.B) {
	test := testRequest{
		name:           "large_file",
		endpoint:       "/large_file",
		expectedStatus: http.StatusOK,
		expectedBody:   testLargeResp,
	}

	benchmarkServeStreamHTTPSimple(b, test)
}