package h2mux

import (
	"bytes"
	"io"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// AssertIOReturnIsGood returns a checker for the (n, err) pair produced by an
// I/O call such as Read or Write. The checker fails the test immediately if
// err is non-nil or if n differs from expected. Because it uses t.Fatalf, it
// must only be invoked from the goroutine running the test.
//
// Typical use wraps the call directly:
//
//	AssertIOReturnIsGood(t, len(data))(b.Write(data))
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
	return func(actual int, err error) {
		// Attribute failures to the caller's line rather than to this helper.
		t.Helper()
		// Check the error first: when the call failed, the error explains the
		// short count far better than the count alone does.
		if err != nil {
			t.Fatalf("Unexpected error %v (got %d of %d bytes)", err, actual, expected)
		}
		if expected != actual {
			t.Fatalf("Expected %d bytes, got %d", expected, actual)
		}
	}
}

// TestSharedBuffer checks that data written to a SharedBuffer can be read
// back with the same length and the same content.
func TestSharedBuffer(t *testing.T) {
	b := NewSharedBuffer()
	testData := []byte("Hello world")
	AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
	bytesRead := make([]byte, len(testData))
	AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
	// A matching byte count alone would not catch corrupted or reordered data.
	assert.Equal(t, testData, bytesRead)
}

// TestSharedBufferBlockingRead checks that Read blocks while the buffer is
// empty, then returns each chunk promptly once it has been written.
func TestSharedBufferBlockingRead(t *testing.T) {
	b := NewSharedBuffer()
	testData1 := []byte("Hello")
	testData2 := []byte(" world")

	// readResult carries one Read outcome back to the test goroutine, which
	// does all of the checking: t.Fatalf (FailNow) may only be called from
	// the goroutine running the test.
	type readResult struct {
		data []byte
		err  error
	}
	// Buffered so the reader never blocks on send, even if the test has
	// already failed and stopped receiving.
	result := make(chan readResult, 2)
	go func() {
		bytesRead := make([]byte, len(testData1)+len(testData2))
		for i := 0; i < 2; i++ {
			nRead, err := b.Read(bytesRead)
			// Copy out, since bytesRead is reused by the next Read.
			data := append([]byte(nil), bytesRead[:nRead]...)
			result <- readResult{data: data, err: err}
			if err != nil {
				return
			}
		}
	}()

	// expectRead waits up to a second for the next Read to complete and
	// checks that it returned exactly want with no error.
	expectRead := func(want []byte) {
		t.Helper()
		select {
		case r := <-result:
			AssertIOReturnIsGood(t, len(want))(len(r.data), r.err)
			assert.Equal(t, want, r.data)
		case <-time.After(time.Second):
			t.Fatalf("read timed out")
		}
	}

	// Nothing has been written yet, so the reader must still be blocked.
	time.Sleep(time.Millisecond * 250)
	select {
	case <-result:
		t.Fatalf("read returned early")
	default:
	}

	AssertIOReturnIsGood(t, len(testData1))(b.Write(testData1))
	expectRead(testData1)
	AssertIOReturnIsGood(t, len(testData2))(b.Write(testData2))
	expectRead(testData2)
}

// TestSharedBufferConcurrentReadWrite streams data through a SharedBuffer
// from a writer goroutine to a reader goroutine. The writer's block size
// grows while the reader's shrinks, so reads regularly have to wait for more
// data to arrive.
//
// This is quite slow under the race detector.
func TestSharedBufferConcurrentReadWrite(t *testing.T) {
	b := NewSharedBuffer()
	// expectedResult is written only by the writer goroutine and actualResult
	// only by the reader; both are inspected after wg.Wait().
	var expectedResult, actualResult bytes.Buffer
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		// Deferred so that an early return still releases wg.Wait below.
		defer wg.Done()
		block := make([]byte, 256)
		for i := range block {
			block[i] = byte(i)
		}
		for blockSize := 1; blockSize <= 256; blockSize++ {
			for i := 0; i < 256; i++ {
				expectedResult.Write(block[:blockSize])
				n, err := b.Write(block[:blockSize])
				if n != blockSize || err != nil {
					// Errorf, not Fatalf: FailNow must not be called outside
					// the test goroutine.
					t.Errorf("write error: %d %v", n, err)
					// Close so the reader is not left waiting for data that
					// will never arrive; its own failure is reported instead.
					// NOTE(review): assumes Close wakes a blocked Read —
					// confirm against SharedBuffer. The Close error is
					// irrelevant because the test has already failed.
					_ = b.Close()
					return
				}
			}
		}
	}()
	go func() {
		defer wg.Done()
		block := make([]byte, 256)
		// Change block sizes in opposition to the write thread, to test blocking for new data.
		for blockSize := 256; blockSize > 0; blockSize-- {
			for i := 0; i < 256; i++ {
				n, err := io.ReadFull(b, block[:blockSize])
				if n != blockSize || err != nil {
					t.Errorf("read error: %d %v", n, err)
					return
				}
				actualResult.Write(block[:blockSize])
			}
		}
	}()
	wg.Wait()
	if !bytes.Equal(expectedResult.Bytes(), actualResult.Bytes()) {
		t.Fatal("Result diverged")
	}
}

// TestSharedBufferClose checks two things. Data written before Close can
// still be read back intact. Once that data is drained, Read returns
// (0, io.EOF).
func TestSharedBufferClose(t *testing.T) {
	b := NewSharedBuffer()
	testData := []byte("Hello world")
	AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
	if err := b.Close(); err != nil {
		t.Fatalf("unexpected error from Close: %v", err)
	}
	bytesRead := make([]byte, len(testData))
	AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
	assert.Equal(t, testData, bytesRead)
	n, err := b.Read(bytesRead)
	if n != 0 {
		t.Fatalf("extra bytes received: %d", n)
	}
	// %v rather than %s: err may be nil here, which %s renders as %!s(<nil>).
	if err != io.EOF {
		t.Fatalf("expected EOF, got %v", err)
	}
}