package v3_test

import (
	"context"
	"errors"
	"net"
	"net/netip"
	"slices"
	"sync/atomic"
	"testing"
	"time"

	"github.com/rs/zerolog"

	v3 "github.com/cloudflare/cloudflared/quic/v3"
)

// Fixtures shared by the session tests in this file.
var (
	// expectedContextCanceled is passed as the cancellation cause when a test deliberately cancels
	// a serving context, so the test can confirm the cancellation it observes is its own.
	expectedContextCanceled = errors.New("expected context canceled")

	// Loopback UDP addresses with port 0, handed to NewSession as the origin and local addresses.
	testOriginAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), 0))
	testLocalAddr  = net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), 0))
)

// TestSessionNew checks that a newly constructed session reports the request ID it was given.
func TestSessionNew(t *testing.T) {
	logger := zerolog.Nop()
	s := v3.NewSession(testRequestID, 5*time.Second, nil, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &logger)
	if id := s.ID(); id != testRequestID {
		t.Fatalf("session id doesn't match: %s != %s", testRequestID, id)
	}
}

// testSessionWrite writes payload through a session backed by a testOrigin and verifies that
// Write reports the whole payload as written and that the origin received exactly those bytes.
func testSessionWrite(t *testing.T, payload []byte) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()
	log := zerolog.Nop()
	origin := newTestOrigin(makePayload(1280))
	session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
	n, err := session.Write(payload)
	if err != nil {
		t.Fatal(err)
	}
	if n != len(payload) {
		t.Fatalf("unable to write the whole payload: wrote %d of %d bytes", n, len(payload))
	}
	// Guard the slice expression so a short write to the origin fails the test instead of panicking.
	if len(origin.write) < len(payload) || !slices.Equal(payload, origin.write[:len(payload)]) {
		t.Fatal("payload written to the origin and provided payload are not the same")
	}
}

func TestSessionWrite_Max(t *testing.T) {
	payload := makePayload(1280)
	testSessionWrite(t, payload)
}

func TestSessionWrite_Min(t *testing.T) {
	payload := makePayload(0)
	testSessionWrite(t, payload)
}

func TestSessionServe_OriginMax(t *testing.T) {
	payload := makePayload(1280)
	testSessionServe_Origin(t, payload)
}

func TestSessionServe_OriginMin(t *testing.T) {
	payload := makePayload(0)
	testSessionServe_Origin(t, payload)
}

// testSessionServe_Origin serves a session whose origin yields payload once. It verifies that the
// eyeball receives a datagram made of the payload header for testRequestID (17 bytes) followed by
// payload. It then cancels the serving context with expectedContextCanceled and checks that Serve
// returns context.Canceled and that the context records that cause.
func testSessionServe_Origin(t *testing.T, payload []byte) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()
	log := zerolog.Nop()
	eyeball := newMockEyeball()
	origin := newTestOrigin(payload)
	session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
	defer session.Close()

	ctx, cancel := context.WithCancelCause(context.Background())
	defer cancel(context.Canceled)
	// Buffered so the Serve goroutine can always deliver its result, even if the test fails early.
	done := make(chan error, 1)
	go func() {
		done <- session.Serve(ctx)
	}()

	select {
	case data := <-eyeball.recvData:
		// check received data matches provided from origin
		expectedData := makePayload(1500)
		v3.MarshalPayloadHeaderTo(testRequestID, expectedData)
		copy(expectedData[17:], payload)
		if !slices.Equal(expectedData[:17+len(payload)], data) {
			t.Fatal("expected datagram did not equal expected")
		}
		cancel(expectedContextCanceled)
	case err := <-done:
		// Serve must not finish before the origin payload has been proxied to the eyeball.
		t.Fatalf("session Serve returned before the origin payload was proxied: %v", err)
	case <-time.After(10 * time.Second):
		// Safety net so a stuck session fails the test instead of hanging it.
		t.Fatal("timed out waiting for the origin payload to be proxied to the eyeball")
	}

	err := <-done
	if !errors.Is(err, context.Canceled) {
		t.Fatal(err)
	}
	if cause := context.Cause(ctx); !errors.Is(cause, expectedContextCanceled) {
		t.Fatalf("unexpected context cancellation cause: %v", cause)
	}
}

// TestSessionServe_OriginTooLarge feeds the session a 1281-byte origin packet. It checks that the
// packet is never forwarded to the eyeball and that Serve ends with SessionIdleErr.
func TestSessionServe_OriginTooLarge(t *testing.T) {
	logger := zerolog.Nop()
	eyeball := newMockEyeball()
	origin := newTestOrigin(makePayload(1281))
	s := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &logger)
	defer s.Close()

	result := make(chan error)
	go func() { result <- s.Serve(context.Background()) }()

	select {
	case err := <-result:
		if !errors.Is(err, v3.SessionIdleErr{}) {
			t.Error(err)
		}
	case data := <-eyeball.recvData:
		// The oversized origin payload must be dropped, so nothing should ever reach the eyeball.
		t.Fatalf("we should never proxy a payload of this size: %d", len(data))
	}
}

// TestSessionServe_Migrate migrates a serving session to a second eyeball and then checks three
// things. Canceling the first eyeball's context does not end the session. Origin data is written
// only to the second eyeball. Serve finally returns SessionIdleErr while the second eyeball's
// context stays live.
func TestSessionServe_Migrate(t *testing.T) {
	log := zerolog.Nop()
	eyeball := newMockEyeball()
	pipe1, pipe2 := net.Pipe()
	session := v3.NewSession(testRequestID, 2*time.Second, pipe2, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
	defer session.Close()

	// Buffered so the Serve goroutine can always deliver its result, even if the test fails early.
	done := make(chan error, 1)
	eyeball1Ctx, cancel := context.WithCancelCause(context.Background())
	go func() {
		done <- session.Serve(eyeball1Ctx)
	}()

	// Migrate the session to a new connection before origin sends data
	eyeball2 := newMockEyeball()
	eyeball2.connID = 1
	eyeball2Ctx := context.Background()
	session.Migrate(&eyeball2, eyeball2Ctx, &log)

	// Cancel the origin eyeball context; this should not cancel the session
	contextCancelErr := errors.New("context canceled for first eyeball connection")
	cancel(contextCancelErr)
	select {
	case <-done:
		t.Fatalf("expected session to still be running")
	default:
	}
	if context.Cause(eyeball1Ctx) != contextCancelErr {
		t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx))
	}

	// Origin sends data. net.Pipe is unbuffered, so a failed write means the session never read it.
	payload2 := []byte{0xde}
	if _, err := pipe1.Write(payload2); err != nil {
		t.Fatalf("unable to write origin payload: %v", err)
	}

	// Expect write to eyeball2
	data := <-eyeball2.recvData
	if len(data) <= 17 || !slices.Equal(payload2, data[17:]) {
		t.Fatalf("expected data to write to eyeball2 after migration: %+v", data)
	}

	select {
	case data := <-eyeball.recvData:
		t.Fatalf("expected no data to write to eyeball1 after migration: %+v", data)
	default:
	}

	err := <-done
	if !errors.Is(err, v3.SessionIdleErr{}) {
		t.Error(err)
	}
	if eyeball2Ctx.Err() != nil {
		t.Fatalf("second eyeball context should be not be cancelled")
	}
}

// TestSessionServe_Migrate_CloseContext2 migrates a serving session to a second eyeball. It checks
// that canceling the first eyeball's context leaves the session running and that origin data
// reaches only the second eyeball. It then checks that canceling the second eyeball's context
// makes Serve return context.Canceled.
func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
	log := zerolog.Nop()
	eyeball := newMockEyeball()
	pipe1, pipe2 := net.Pipe()
	session := v3.NewSession(testRequestID, 2*time.Second, pipe2, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
	defer session.Close()

	// Buffered so the Serve goroutine can always deliver its result, even if the test fails early.
	done := make(chan error, 1)
	eyeball1Ctx, cancel := context.WithCancelCause(context.Background())
	go func() {
		done <- session.Serve(eyeball1Ctx)
	}()

	// Migrate the session to a new connection before origin sends data
	eyeball2 := newMockEyeball()
	eyeball2.connID = 1
	eyeball2Ctx, cancel2 := context.WithCancelCause(context.Background())
	session.Migrate(&eyeball2, eyeball2Ctx, &log)

	// Cancel the origin eyeball context; this should not cancel the session
	contextCancelErr := errors.New("context canceled for first eyeball connection")
	cancel(contextCancelErr)
	select {
	case <-done:
		t.Fatalf("expected session to still be running")
	default:
	}
	if context.Cause(eyeball1Ctx) != contextCancelErr {
		t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx))
	}

	// Origin sends data. net.Pipe is unbuffered, so a failed write means the session never read it.
	payload2 := []byte{0xde}
	if _, err := pipe1.Write(payload2); err != nil {
		t.Fatalf("unable to write origin payload: %v", err)
	}

	// Expect write to eyeball2
	data := <-eyeball2.recvData
	if len(data) <= 17 || !slices.Equal(payload2, data[17:]) {
		t.Fatalf("expected data to write to eyeball2 after migration: %+v", data)
	}

	select {
	case data := <-eyeball.recvData:
		t.Fatalf("expected no data to write to eyeball1 after migration: %+v", data)
	default:
	}

	// Close the connection2 context manually
	contextCancel2Err := errors.New("context canceled for second eyeball connection")
	cancel2(contextCancel2Err)
	err := <-done
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("session Serve should be done: %+v", err)
	}
	if !errors.Is(context.Cause(eyeball2Ctx), contextCancel2Err) {
		t.Fatalf("second eyeball context should have been cancelled manually: %+v", context.Cause(eyeball2Ctx))
	}
}

// TestSessionClose_Multiple checks that the first Close closes the origin. It also checks that
// later Close calls return nil without closing the origin a second time.
func TestSessionClose_Multiple(t *testing.T) {
	logger := zerolog.Nop()
	origin := newTestOrigin(makePayload(128))
	s := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &logger)

	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
	if !origin.closed.Load() {
		t.Fatal("origin wasn't closed")
	}

	// Clear the flag so that any second close of the origin becomes observable.
	origin.closed.Store(false)

	// Repeated closes must be no-ops that report no error.
	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
	if origin.closed.Load() {
		t.Fatal("origin was incorrectly closed twice")
	}
}

// TestSessionServe_IdleTimeout serves a session whose origin never delivers data. It checks that
// Serve gives up with SessionIdleErr after closeAfterIdle, that the origin has been closed, and
// that a later Close still returns nil.
func TestSessionServe_IdleTimeout(t *testing.T) {
	logger := zerolog.Nop()
	closeAfterIdle := 2 * time.Second
	// Each origin read blocks far longer than closeAfterIdle.
	origin := newTestIdleOrigin(10 * time.Second)
	s := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &logger)

	if err := s.Serve(context.Background()); !errors.Is(err, v3.SessionIdleErr{}) {
		t.Fatal(err)
	}
	if !origin.closed {
		t.Fatalf("session should be closed after Serve returns")
	}
	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
}

// TestSessionServe_ParentContextCanceled serves a session under a 2-second deadline. Both the idle
// timeout and the origin's read delay are longer than that deadline. The test checks that Serve
// returns context.DeadlineExceeded, that the origin is closed, and that a later Close returns nil.
func TestSessionServe_ParentContextCanceled(t *testing.T) {
	logger := zerolog.Nop()
	origin := newTestIdleOrigin(10 * time.Second)
	closeAfterIdle := 10 * time.Second
	s := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &logger)

	parent, stop := context.WithTimeout(context.Background(), 2*time.Second)
	defer stop()

	if err := s.Serve(parent); !errors.Is(err, context.DeadlineExceeded) {
		t.Fatal(err)
	}
	if !origin.closed {
		t.Fatalf("session should be closed after Serve returns")
	}
	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
}

// TestSessionServe_ReadErrors checks that when the origin read fails with net.ErrClosed, Serve
// returns that error.
func TestSessionServe_ReadErrors(t *testing.T) {
	logger := zerolog.Nop()
	origin := newTestErrOrigin(net.ErrClosed, nil)
	s := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &logger)
	if err := s.Serve(context.Background()); !errors.Is(err, net.ErrClosed) {
		t.Fatal(err)
	}
}

// testOrigin is an in-memory stand-in for an origin connection.
//   - The first Read returns the configured payload. Every later Read sleeps for 10 seconds, so
//     effectively only one packet is delivered.
//   - Write keeps a copy of the most recently written bytes.
//   - After Close, Read and Write fail with net.ErrClosed.
type testOrigin struct {
	// bytes from the most recent Write
	write []byte
	// bytes provided to Read
	read     []byte
	readOnce atomic.Bool
	closed   atomic.Bool
}

// newTestOrigin returns an open testOrigin whose first Read yields payload.
func newTestOrigin(payload []byte) testOrigin {
	return testOrigin{
		read: payload,
	}
}

// Read copies the configured payload into p and returns the number of bytes copied.
// Only the first call returns immediately. Later calls first sleep for 10 seconds.
// A closed origin returns 0 and net.ErrClosed. This includes a read whose origin was closed while
// it slept, so a closed session is never handed a stale payload.
func (o *testOrigin) Read(p []byte) (n int, err error) {
	if o.closed.Load() {
		// io.Reader requires 0 <= n, even when an error is returned.
		return 0, net.ErrClosed
	}
	// Swap atomically marks the first read as done and reports whether it already happened.
	if o.readOnce.Swap(true) {
		// We only want to provide one read so all other reads will be blocked
		time.Sleep(10 * time.Second)
		if o.closed.Load() {
			return 0, net.ErrClosed
		}
	}
	return copy(p, o.read), nil
}

// Write stores a copy of p and reports it as fully written.
// A closed origin returns 0 and net.ErrClosed.
func (o *testOrigin) Write(p []byte) (n int, err error) {
	if o.closed.Load() {
		// io.Writer requires 0 <= n, even when an error is returned.
		return 0, net.ErrClosed
	}
	o.write = make([]byte, len(p))
	copy(o.write, p)
	return len(p), nil
}

// Close marks the origin closed. It always returns nil.
func (o *testOrigin) Close() error {
	o.closed.Store(true)
	return nil
}

// testIdleOrigin is an origin stub that never produces a packet.
//   - Each Read sleeps for duration and then returns -1 with a nil error.
//   - Write reports 0 bytes written and no error.
//   - Close only records that it was called.
type testIdleOrigin struct {
	duration time.Duration // how long every Read blocks before returning
	closed   bool          // set to true by Close
}

// newTestIdleOrigin returns an open testIdleOrigin whose reads each block for d.
func newTestIdleOrigin(d time.Duration) testIdleOrigin {
	var o testIdleOrigin
	o.duration = d
	return o
}

// Read ignores p, blocks for the configured duration, then returns -1 and a nil error.
func (o *testIdleOrigin) Read(_ []byte) (int, error) {
	time.Sleep(o.duration)
	return -1, nil
}

// Write discards p and reports 0 bytes written with no error.
func (o *testIdleOrigin) Write(_ []byte) (int, error) {
	return 0, nil
}

// Close records that the origin was closed. It always returns nil.
func (o *testIdleOrigin) Close() error {
	o.closed = true
	return nil
}

// testErrOrigin is an origin stub that returns fixed errors.
//   - Read reports 0 bytes and returns readErr.
//   - Write reports the whole input as written and returns writeErr.
//   - Close always succeeds.
type testErrOrigin struct {
	readErr  error // returned by every Read
	writeErr error // returned by every Write
}

// newTestErrOrigin returns a testErrOrigin that fails reads with readErr and writes with writeErr.
func newTestErrOrigin(readErr error, writeErr error) testErrOrigin {
	return testErrOrigin{readErr: readErr, writeErr: writeErr}
}

// Read ignores p and returns 0 with the configured read error.
func (o *testErrOrigin) Read(_ []byte) (int, error) {
	return 0, o.readErr
}

// Write reports len(p) bytes written together with the configured write error.
func (o *testErrOrigin) Write(p []byte) (int, error) {
	return len(p), o.writeErr
}

// Close always returns nil.
func (*testErrOrigin) Close() error {
	return nil
}