package streamhandler

import (
	"context"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/cloudflare/cloudflared/h2mux"
	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"golang.org/x/sync/errgroup"
)

const (
	testOpenStreamTimeout = time.Millisecond * 5000
	testHandshakeTimeout  = time.Millisecond * 1000
)

var (
	testTunnelHostname = h2mux.TunnelHostname("123.cftunnel.com")
	baseHeaders        = []h2mux.Header{
		{Name: ":method", Value: "GET"},
		{Name: ":scheme", Value: "http"},
		{Name: ":authority", Value: "example.com"},
		{Name: ":path", Value: "/"},
	}
	tunnelHostnameHeader = h2mux.Header{
		Name:  h2mux.CloudflaredProxyTunnelHostnameHeader,
		Value: testTunnelHostname.String(),
	}
)

func TestServeRequest(t *testing.T) {
	configChan := make(chan *pogs.ClientConfig)
	useConfigResultChan := make(chan *pogs.UseConfigurationResult)
	streamHandler := NewStreamHandler(configChan, useConfigResultChan, logrus.New())

	message := []byte("Hello cloudflared")
	httpServer := httptest.NewServer(&mockHTTPHandler{message})

	reverseProxyConfigs := []*pogs.ReverseProxyConfig{
		{
			TunnelHostname: testTunnelHostname,
			OriginConfig: &pogs.HTTPOriginConfig{
				URLString: httpServer.URL,
			},
		},
	}
	streamHandler.UpdateConfig(reverseProxyConfigs)

	muxPair := NewDefaultMuxerPair(t, streamHandler)
	muxPair.Serve(t)

	ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
	defer cancel()

	headers := append(baseHeaders, tunnelHostnameHeader)
	stream, err := muxPair.EdgeMux.OpenStream(ctx, headers, nil)
	assert.NoError(t, err)
	assertStatusHeader(t, http.StatusOK, stream.Headers)
	assertRespBody(t, message, stream)
}

func TestServeBadRequest(t *testing.T) {
	configChan := make(chan *pogs.ClientConfig)
	useConfigResultChan := make(chan *pogs.UseConfigurationResult)
	streamHandler := NewStreamHandler(configChan, useConfigResultChan, logrus.New())

	muxPair := NewDefaultMuxerPair(t, streamHandler)
	muxPair.Serve(t)

	ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
	defer cancel()

	// No tunnel hostname header, expect to get 400 Bad Request
	stream, err := muxPair.EdgeMux.OpenStream(ctx, baseHeaders, nil)
	assert.NoError(t, err)
	assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
	assertRespBody(t, statusBadRequest.text, stream)

	// No mapping for the tunnel hostname, expect to get 404 Not Found
	headers := append(baseHeaders, tunnelHostnameHeader)
	stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil)
	assert.NoError(t, err)
	assertStatusHeader(t, http.StatusNotFound, stream.Headers)
	assertRespBody(t, statusNotFound.text, stream)

	// Nothing listening on empty url, so proxy would fail. Expect to get 502 Bad Gateway
	reverseProxyConfigs := []*pogs.ReverseProxyConfig{
		{
			TunnelHostname: testTunnelHostname,
			OriginConfig: &pogs.HTTPOriginConfig{
				URLString: "",
			},
		},
	}
	streamHandler.UpdateConfig(reverseProxyConfigs)
	stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil)
	assert.NoError(t, err)
	assertStatusHeader(t, http.StatusBadGateway, stream.Headers)
	assertRespBody(t, statusBadGateway.text, stream)

	// Invalid content-length, wouldn't not be able to create a request. Expect to get 400 Bad Request
	headers = append(headers, h2mux.Header{
		Name:  "content-length",
		Value: "x",
	})
	stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil)
	assert.NoError(t, err)
	assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
	assertRespBody(t, statusBadRequest.text, stream)
}

func assertStatusHeader(t *testing.T, expectedStatus int, headers []h2mux.Header) {
	assert.Equal(t, statusPseudoHeader, headers[0].Name)
	assert.Equal(t, strconv.Itoa(expectedStatus), headers[0].Value)
}

func assertRespBody(t *testing.T, expectedRespBody []byte, stream *h2mux.MuxedStream) {
	respBody := make([]byte, len(expectedRespBody))
	_, err := stream.Read(respBody)
	assert.NoError(t, err)
	assert.Equal(t, expectedRespBody, respBody)
}

type DefaultMuxerPair struct {
	OriginMuxConfig h2mux.MuxerConfig
	OriginMux       *h2mux.Muxer
	OriginConn      net.Conn
	EdgeMuxConfig   h2mux.MuxerConfig
	EdgeMux         *h2mux.Muxer
	EdgeConn        net.Conn
	doneC           chan struct{}
}

func NewDefaultMuxerPair(t *testing.T, h h2mux.MuxedStreamHandler) *DefaultMuxerPair {
	origin, edge := net.Pipe()
	p := &DefaultMuxerPair{
		OriginMuxConfig: h2mux.MuxerConfig{
			Timeout:                 testHandshakeTimeout,
			Handler:                 h,
			IsClient:                true,
			Name:                    "origin",
			Logger:                  logrus.NewEntry(logrus.New()),
			DefaultWindowSize:       (1 << 8) - 1,
			MaxWindowSize:           (1 << 15) - 1,
			StreamWriteBufferMaxLen: 1024,
		},
		OriginConn: origin,
		EdgeMuxConfig: h2mux.MuxerConfig{
			Timeout:                 testHandshakeTimeout,
			IsClient:                false,
			Name:                    "edge",
			Logger:                  logrus.NewEntry(logrus.New()),
			DefaultWindowSize:       (1 << 8) - 1,
			MaxWindowSize:           (1 << 15) - 1,
			StreamWriteBufferMaxLen: 1024,
		},
		EdgeConn: edge,
		doneC:    make(chan struct{}),
	}
	assert.NoError(t, p.Handshake(t.Name()))
	return p
}

func (p *DefaultMuxerPair) Handshake(testName string) error {
	ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout)
	defer cancel()
	errGroup, _ := errgroup.WithContext(ctx)
	errGroup.Go(func() (err error) {
		p.EdgeMux, err = h2mux.Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, h2mux.NewActiveStreamsMetrics(testName, "edge"))
		return errors.Wrap(err, "edge handshake failure")
	})
	errGroup.Go(func() (err error) {
		p.OriginMux, err = h2mux.Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, h2mux.NewActiveStreamsMetrics(testName, "origin"))
		return errors.Wrap(err, "origin handshake failure")
	})

	return errGroup.Wait()
}

func (p *DefaultMuxerPair) Serve(t assert.TestingT) {
	ctx := context.Background()
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		err := p.EdgeMux.Serve(ctx)
		if err != nil && err != io.EOF && err != io.ErrClosedPipe {
			t.Errorf("error in edge muxer Serve(): %s", err)
		}
		p.OriginMux.Shutdown()
		wg.Done()
	}()
	go func() {
		err := p.OriginMux.Serve(ctx)
		if err != nil && err != io.EOF && err != io.ErrClosedPipe {
			t.Errorf("error in origin muxer Serve(): %s", err)
		}
		p.EdgeMux.Shutdown()
		wg.Done()
	}()
	go func() {
		// notify when both muxes have stopped serving
		wg.Wait()
		close(p.doneC)
	}()
}

type mockHTTPHandler struct {
	message []byte
}

func (mth *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	w.Write(mth.message)
}