130 lines
3.2 KiB
Go
130 lines
3.2 KiB
Go
package h2mux
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
|
|
return func(actual int, err error) {
|
|
if expected != actual {
|
|
t.Fatalf("Expected %d bytes, got %d", expected, actual)
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error %s", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSharedBuffer(t *testing.T) {
|
|
b := NewSharedBuffer()
|
|
testData := []byte("Hello world")
|
|
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
|
|
bytesRead := make([]byte, len(testData))
|
|
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
|
}
|
|
|
|
func TestSharedBufferBlockingRead(t *testing.T) {
|
|
b := NewSharedBuffer()
|
|
testData1 := []byte("Hello")
|
|
testData2 := []byte(" world")
|
|
result := make(chan []byte)
|
|
go func() {
|
|
bytesRead := make([]byte, len(testData1)+len(testData2))
|
|
nRead, err := b.Read(bytesRead)
|
|
AssertIOReturnIsGood(t, len(testData1))(nRead, err)
|
|
result <- bytesRead[:nRead]
|
|
nRead, err = b.Read(bytesRead)
|
|
AssertIOReturnIsGood(t, len(testData2))(nRead, err)
|
|
result <- bytesRead[:nRead]
|
|
}()
|
|
time.Sleep(time.Millisecond * 250)
|
|
select {
|
|
case <-result:
|
|
t.Fatalf("read returned early")
|
|
default:
|
|
}
|
|
AssertIOReturnIsGood(t, len(testData1))(b.Write([]byte(testData1)))
|
|
select {
|
|
case r := <-result:
|
|
assert.Equal(t, testData1, r)
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("read timed out")
|
|
}
|
|
AssertIOReturnIsGood(t, len(testData2))(b.Write([]byte(testData2)))
|
|
select {
|
|
case r := <-result:
|
|
assert.Equal(t, testData2, r)
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("read timed out")
|
|
}
|
|
}
|
|
|
|
// This is quite slow under the race detector
|
|
func TestSharedBufferConcurrentReadWrite(t *testing.T) {
|
|
b := NewSharedBuffer()
|
|
var expectedResult, actualResult bytes.Buffer
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
block := make([]byte, 256)
|
|
for i := range block {
|
|
block[i] = byte(i)
|
|
}
|
|
for blockSize := 1; blockSize <= 256; blockSize++ {
|
|
for i := 0; i < 256; i++ {
|
|
expectedResult.Write(block[:blockSize])
|
|
n, err := b.Write(block[:blockSize])
|
|
if n != blockSize || err != nil {
|
|
t.Errorf("write error: %d %s", n, err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
wg.Done()
|
|
}()
|
|
go func() {
|
|
block := make([]byte, 256)
|
|
// Change block sizes in opposition to the write thread, to test blocking for new data.
|
|
for blockSize := 256; blockSize > 0; blockSize-- {
|
|
for i := 0; i < 256; i++ {
|
|
n, err := io.ReadFull(b, block[:blockSize])
|
|
if n != blockSize || err != nil {
|
|
t.Errorf("read error: %d %s", n, err)
|
|
return
|
|
}
|
|
actualResult.Write(block[:blockSize])
|
|
}
|
|
}
|
|
wg.Done()
|
|
}()
|
|
wg.Wait()
|
|
if bytes.Compare(expectedResult.Bytes(), actualResult.Bytes()) != 0 {
|
|
t.Fatal("Result diverged")
|
|
}
|
|
}
|
|
|
|
func TestSharedBufferClose(t *testing.T) {
|
|
b := NewSharedBuffer()
|
|
testData := []byte("Hello world")
|
|
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
|
|
err := b.Close()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error from Close: %s", err)
|
|
}
|
|
bytesRead := make([]byte, len(testData))
|
|
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
|
n, err := b.Read(bytesRead)
|
|
if n != 0 {
|
|
t.Fatalf("extra bytes received: %d", n)
|
|
}
|
|
if err != io.EOF {
|
|
t.Fatalf("expected EOF, got %s", err)
|
|
}
|
|
}
|