cloudflared-mirror/datagramsession/session_test.go

178 lines
4.1 KiB
Go

package datagramsession
import (
"bytes"
"context"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
// TestSessionCtxDone makes sure a session will stop after the context passed to
// Serve is cancelled.
func TestSessionCtxDone(t *testing.T) {
	testSessionReturns(t, closeByContext, time.Minute*2)
}
// TestCloseSession verifies that calling close on a session makes Serve return.
func TestCloseSession(t *testing.T) {
	// Idle timeout long enough that it should not be what ends the session here.
	const idleTimeout = 2 * time.Minute
	testSessionReturns(t, closeByCallingClose, idleTimeout)
}
// TestCloseIdle makes sure a session will stop after there has been no
// read/write for the period defined by closeAfterIdle.
func TestCloseIdle(t *testing.T) {
	testSessionReturns(t, closeByTimeout, time.Millisecond*100)
}
// testSessionReturns starts a session over an in-memory pipe and pushes one
// payload from the transport side to the destination. It then ends the session
// using closeBy and checks that Serve returns.
//
// closeAfterIdle is passed to Serve as its idle timeout:
//   - closeByTimeout checks that Serve did not return before closeAfterIdle
//     elapsed since the payload was read.
//   - closeByContext and closeByCallingClose check that Serve returned before
//     that deadline, i.e. because of the requested close.
func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
	sessionID := uuid.New()
	cfdConn, originConn := net.Pipe()
	// Closing a net.Pipe end more than once is harmless; this only ensures
	// nothing stays blocked on the pipe after the test ends.
	defer cfdConn.Close()
	defer originConn.Close()

	payload := testPayload(sessionID)
	transport := &mockQUICTransport{
		reqChan:  newDatagramChannel(1),
		respChan: newDatagramChannel(1),
	}
	session := newSession(sessionID, transport, cfdConn)

	ctx, cancel := context.WithCancel(context.Background())
	// Cancel on every exit path, including early returns from a failed require.
	defer cancel()

	sessionDone := make(chan struct{})
	go func() {
		session.Serve(ctx, closeAfterIdle)
		close(sessionDone)
	}()

	// require calls t.FailNow, which must only run on the test goroutine, so
	// the writer reports its result over a channel instead of asserting itself.
	writeErr := make(chan error, 1)
	go func() {
		n, err := session.transportToDst(payload)
		if err == nil && n != len(payload) {
			err = fmt.Errorf("transportToDst wrote %d bytes, expected %d", n, len(payload))
		}
		writeErr <- err
	}()

	readBuffer := make([]byte, len(payload)+1)
	n, err := originConn.Read(readBuffer)
	require.NoError(t, err)
	require.Equal(t, len(payload), n)
	lastRead := time.Now()
	require.NoError(t, <-writeErr)

	switch closeBy {
	case closeByContext:
		cancel()
	case closeByCallingClose:
		session.close()
	}
	<-sessionDone

	if closeBy == closeByTimeout {
		// An idle close must not fire before closeAfterIdle has elapsed.
		require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
	} else {
		// Serve must have returned because of closeBy, not the idle timeout.
		require.True(t, time.Now().Before(lastRead.Add(closeAfterIdle)))
	}
}
// closeMethod selects how testSessionReturns ends a running session.
type closeMethod int

const (
	// closeByContext cancels the context passed to Serve.
	closeByContext closeMethod = 0
	// closeByCallingClose calls session.close directly.
	closeByCallingClose closeMethod = 1
	// closeByTimeout takes no action and waits for the idle timeout to end the session.
	closeByTimeout closeMethod = 2
)
// TestWriteToDstSessionPreventClosed checks that a session kept busy by writes
// towards its destination is not closed by the idle timeout.
func TestWriteToDstSessionPreventClosed(t *testing.T) {
	const readFromDst, writeToDst = false, true
	testActiveSessionNotClosed(t, readFromDst, writeToDst)
}
// TestReadFromDstSessionPreventClosed checks that a session kept busy by data
// arriving from its destination is not closed by the idle timeout.
func TestReadFromDstSessionPreventClosed(t *testing.T) {
	const readFromDst, writeToDst = true, false
	testActiveSessionNotClosed(t, readFromDst, writeToDst)
}
// testActiveSessionNotClosed keeps a session busy for activeTime by sending a
// payload every closeAfterIdle/2. It checks that Serve does not return before
// that activity stops, even though closeAfterIdle is shorter than activeTime.
//
// readFromDst: the origin side writes payloads into the pipe, so the session
// reads them from its destination connection.
// writeToDst: payloads are pushed through session.transportToDst, and the
// origin side reads each one and compares it with the sent payload.
func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
	const closeAfterIdle = time.Millisecond * 100
	const activeTime = time.Millisecond * 500

	sessionID := uuid.New()
	cfdConn, originConn := net.Pipe()
	payload := testPayload(sessionID)
	transport := &mockQUICTransport{
		reqChan:  newDatagramChannel(100),
		respChan: newDatagramChannel(100),
	}
	session := newSession(sessionID, transport, cfdConn)

	startTime := time.Now()
	activeUntil := startTime.Add(activeTime)
	ctx, cancel := context.WithCancel(context.Background())
	// Cancel on every exit path, including an early failure of require.
	defer cancel()
	errGroup, ctx := errgroup.WithContext(ctx)

	// sendUntilInactive calls send every closeAfterIdle/2 until activeUntil.
	// It stops early without error when the group context is cancelled (e.g.
	// Serve already failed), so senders do not outlive a failed session.
	sendUntilInactive := func(send func() error) error {
		for !time.Now().After(activeUntil) {
			if err := send(); err != nil {
				return err
			}
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(closeAfterIdle / 2):
			}
		}
		return nil
	}

	errGroup.Go(func() error {
		session.Serve(ctx, closeAfterIdle)
		if time.Now().Before(activeUntil) {
			return fmt.Errorf("session closed while it's still active")
		}
		return nil
	})

	if readFromDst {
		errGroup.Go(func() error {
			return sendUntilInactive(func() error {
				_, err := originConn.Write(payload)
				return err
			})
		})
	}

	if writeToDst {
		errGroup.Go(func() error {
			readBuffer := make([]byte, len(payload))
			for {
				n, err := originConn.Read(readBuffer)
				if err != nil {
					// The pipe is expected to be closed once the session ends.
					if err == io.EOF || err == io.ErrClosedPipe {
						return nil
					}
					return err
				}
				if !bytes.Equal(payload, readBuffer[:n]) {
					return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
				}
			}
		})
		errGroup.Go(func() error {
			return sendUntilInactive(func() error {
				_, err := session.transportToDst(payload)
				return err
			})
		})
	}

	require.NoError(t, errGroup.Wait())
}
// TestMarkActiveNotBlocking makes sure that many concurrent calls to markActive
// all return when no Serve loop is running for the session. A blocking
// implementation fails the test after a deadline instead of hanging until the
// global go test timeout.
func TestMarkActiveNotBlocking(t *testing.T) {
	const concurrentCalls = 50
	const deadline = 5 * time.Second

	session := newSession(uuid.New(), nil, nil)
	var wg sync.WaitGroup
	wg.Add(concurrentCalls)
	for i := 0; i < concurrentCalls; i++ {
		go func() {
			defer wg.Done()
			session.markActive()
		}()
	}

	allReturned := make(chan struct{})
	go func() {
		wg.Wait()
		close(allReturned)
	}()

	select {
	case <-allReturned:
	case <-time.After(deadline):
		t.Fatalf("markActive calls did not return within %v", deadline)
	}
}