cloudflared-mirror/h2mux/shared_buffer_test.go

130 lines
3.2 KiB
Go

package h2mux
import (
"bytes"
"io"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// AssertIOReturnIsGood returns a checker for the (n, err) pair returned by an
// io-style call, so the call can be wrapped directly:
//
//	AssertIOReturnIsGood(t, len(data))(w.Write(data))
//
// The checker fails the test immediately if err is non-nil or if n differs
// from expected. Because it uses t.Fatalf, it must only be invoked from the
// goroutine running the test, not from goroutines the test spawns.
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
	return func(actual int, err error) {
		// Report failures at the caller's line, not inside this helper.
		t.Helper()
		// Check the error first: a failed call usually also has a short
		// count, and the error is the more useful diagnostic.
		if err != nil {
			t.Fatalf("Unexpected error %v (got %d bytes, expected %d)", err, actual, expected)
		}
		if actual != expected {
			t.Fatalf("Expected %d bytes, got %d", expected, actual)
		}
	}
}
// TestSharedBuffer checks that a single Write followed by a Read of the same
// length round-trips the data unchanged.
func TestSharedBuffer(t *testing.T) {
	b := NewSharedBuffer()
	testData := []byte("Hello world")
	AssertIOReturnIsGood(t, len(testData))(b.Write(testData))

	bytesRead := make([]byte, len(testData))
	AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
	// Matching lengths alone would not catch corrupted or reordered data.
	if !bytes.Equal(bytesRead, testData) {
		t.Fatalf("read %q, expected %q", bytesRead, testData)
	}
}
// TestSharedBufferBlockingRead checks that Read blocks while the buffer is
// empty and returns each chunk once it has been written.
//
// All assertions run on the test goroutine. t.Fatalf, and therefore
// AssertIOReturnIsGood, must not be called from a goroutine spawned by the
// test, so the reader goroutine only reports what it read.
func TestSharedBufferBlockingRead(t *testing.T) {
	b := NewSharedBuffer()
	testData1 := []byte("Hello")
	testData2 := []byte(" world")

	type readResult struct {
		data []byte
		err  error
	}
	// The channel is buffered for both results, so the reader never blocks
	// on send even if the test fails before receiving them.
	result := make(chan readResult, 2)
	go func() {
		buf := make([]byte, len(testData1)+len(testData2))
		for i := 0; i < 2; i++ {
			n, err := b.Read(buf)
			// Copy out: buf is reused by the next Read.
			result <- readResult{data: append([]byte(nil), buf[:n]...), err: err}
			if err != nil {
				return
			}
		}
	}()

	// Give the reader time to reach Read; nothing has been written yet, so
	// it must still be blocked.
	time.Sleep(250 * time.Millisecond)
	select {
	case <-result:
		t.Fatalf("read returned early")
	default:
	}

	// expectRead waits (bounded) for the reader's next result and checks it.
	expectRead := func(want []byte) {
		t.Helper()
		select {
		case r := <-result:
			if r.err != nil {
				t.Fatalf("Unexpected error %v", r.err)
			}
			assert.Equal(t, want, r.data)
		case <-time.After(time.Second):
			t.Fatalf("read timed out")
		}
	}

	AssertIOReturnIsGood(t, len(testData1))(b.Write(testData1))
	expectRead(testData1)
	AssertIOReturnIsGood(t, len(testData2))(b.Write(testData2))
	expectRead(testData2)
}
// TestSharedBufferConcurrentReadWrite runs a writer and a reader concurrently.
// The writer writes 256 copies each of the prefixes of length 1..256 of the
// block 0x00..0xff. The reader reads the same total with io.ReadFull, using
// read sizes in the opposite order (256 down to 1). The test then checks that
// the read stream matches the written stream.
//
// This is quite slow under the race detector.
func TestSharedBufferConcurrentReadWrite(t *testing.T) {
	b := NewSharedBuffer()
	var expectedResult, actualResult bytes.Buffer
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		// Deferred so wg.Wait below does not hang when we bail out on error.
		defer wg.Done()
		block := make([]byte, 256)
		for i := range block {
			block[i] = byte(i)
		}
		for blockSize := 1; blockSize <= 256; blockSize++ {
			for i := 0; i < 256; i++ {
				expectedResult.Write(block[:blockSize])
				n, err := b.Write(block[:blockSize])
				if n != blockSize || err != nil {
					t.Errorf("write error: %d %v", n, err)
					// The reader would otherwise wait for data that never
					// arrives. NOTE(review): assumes Close wakes a blocked
					// Read; confirm against SharedBuffer's implementation.
					_ = b.Close()
					return
				}
			}
		}
	}()
	go func() {
		defer wg.Done()
		block := make([]byte, 256)
		// Change block sizes in opposition to the write thread, to test blocking for new data.
		for blockSize := 256; blockSize > 0; blockSize-- {
			for i := 0; i < 256; i++ {
				n, err := io.ReadFull(b, block[:blockSize])
				if n != blockSize || err != nil {
					t.Errorf("read error: %d %v", n, err)
					return
				}
				actualResult.Write(block[:blockSize])
			}
		}
	}()
	// wg.Wait also orders both goroutines' buffer writes before the reads below.
	wg.Wait()
	if !bytes.Equal(expectedResult.Bytes(), actualResult.Bytes()) {
		t.Fatal("Result diverged")
	}
}
// TestSharedBufferClose checks two things about Close. Data written before
// Close can still be read back intact. After that data is drained, Read
// returns (0, io.EOF).
func TestSharedBufferClose(t *testing.T) {
	b := NewSharedBuffer()
	testData := []byte("Hello world")
	AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
	if err := b.Close(); err != nil {
		t.Fatalf("unexpected error from Close: %v", err)
	}

	bytesRead := make([]byte, len(testData))
	AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
	if !bytes.Equal(bytesRead, testData) {
		t.Fatalf("read %q after Close, expected %q", bytesRead, testData)
	}

	n, err := b.Read(bytesRead)
	if n != 0 {
		t.Fatalf("extra bytes received: %d", n)
	}
	// io.Reader implementations must return io.EOF itself (not wrapped),
	// so a direct comparison is correct here.
	if err != io.EOF {
		t.Fatalf("expected EOF, got %v", err)
	}
}