139 lines
3.3 KiB
Go
139 lines
3.3 KiB
Go
package packet
|
|
|
|
import (
|
|
"fmt"
|
|
"net/netip"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type mockFunnelUniPipe struct {
|
|
uniPipe chan RawPacket
|
|
}
|
|
|
|
func (mfui *mockFunnelUniPipe) SendPacket(dst netip.Addr, pk RawPacket) error {
|
|
mfui.uniPipe <- pk
|
|
return nil
|
|
}
|
|
|
|
func (mfui *mockFunnelUniPipe) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func TestFunnelRegistration(t *testing.T) {
|
|
id := testFunnelID{"id1"}
|
|
funnelErr := fmt.Errorf("expected error")
|
|
newFunnelFuncErr := func() (Funnel, error) { return nil, funnelErr }
|
|
newFunnelFuncUncalled := func() (Funnel, error) {
|
|
require.FailNow(t, "a new funnel should not be created")
|
|
panic("unreached")
|
|
}
|
|
funnel1, newFunnelFunc1 := newFunnelAndFunc("funnel1")
|
|
funnel2, newFunnelFunc2 := newFunnelAndFunc("funnel2")
|
|
|
|
ft := NewFunnelTracker()
|
|
// Register funnel1
|
|
funnel, new, err := ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc1)
|
|
require.NoError(t, err)
|
|
require.True(t, new)
|
|
require.Equal(t, funnel1, funnel)
|
|
// Register funnel, no replace
|
|
funnel, new, err = ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFuncUncalled)
|
|
require.NoError(t, err)
|
|
require.False(t, new)
|
|
require.Equal(t, funnel1, funnel)
|
|
// Register funnel2, replace
|
|
funnel, new, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFunc2)
|
|
require.NoError(t, err)
|
|
require.True(t, new)
|
|
require.Equal(t, funnel2, funnel)
|
|
require.True(t, funnel1.closed)
|
|
// Register funnel error, replace
|
|
funnel, new, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFuncErr)
|
|
require.ErrorIs(t, err, funnelErr)
|
|
require.False(t, new)
|
|
require.Nil(t, funnel)
|
|
require.True(t, funnel2.closed)
|
|
}
|
|
|
|
func TestFunnelUnregister(t *testing.T) {
|
|
id := testFunnelID{"id1"}
|
|
funnel1, newFunnelFunc1 := newFunnelAndFunc("funnel1")
|
|
funnel2, newFunnelFunc2 := newFunnelAndFunc("funnel2")
|
|
funnel3, newFunnelFunc3 := newFunnelAndFunc("funnel3")
|
|
|
|
ft := NewFunnelTracker()
|
|
// Register & unregister
|
|
_, _, err := ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc1)
|
|
require.NoError(t, err)
|
|
require.True(t, ft.Unregister(id, funnel1))
|
|
require.True(t, funnel1.closed)
|
|
require.True(t, ft.Unregister(id, funnel1))
|
|
// Register, replace, and unregister
|
|
_, _, err = ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc2)
|
|
require.NoError(t, err)
|
|
_, _, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFunc3)
|
|
require.NoError(t, err)
|
|
require.True(t, funnel2.closed)
|
|
require.False(t, ft.Unregister(id, funnel2))
|
|
require.True(t, ft.Unregister(id, funnel3))
|
|
require.True(t, funnel3.closed)
|
|
}
|
|
|
|
func shouldReplaceFalse(_ Funnel) bool {
|
|
return false
|
|
}
|
|
|
|
func shouldReplaceTrue(_ Funnel) bool {
|
|
return true
|
|
}
|
|
|
|
func newFunnelAndFunc(id string) (*testFunnel, func() (Funnel, error)) {
|
|
funnel := newTestFunnel(id)
|
|
funnelFunc := func() (Funnel, error) {
|
|
return funnel, nil
|
|
}
|
|
return funnel, funnelFunc
|
|
}
|
|
|
|
// testFunnelID is a minimal funnel identifier for tests, wrapping a single
// string.
type testFunnelID struct {
	// id is the value returned by String.
	id string
}
|
|
|
|
func (t testFunnelID) Type() string {
|
|
return "testFunnelID"
|
|
}
|
|
|
|
func (t testFunnelID) String() string {
|
|
return t.id
|
|
}
|
|
|
|
// testFunnel is a test funnel that records whether Close has been called.
type testFunnel struct {
	// id identifies the funnel; Equal compares funnels by this field.
	id string
	// closed is set to true by Close.
	closed bool
}
|
|
|
|
func newTestFunnel(id string) *testFunnel {
|
|
return &testFunnel{
|
|
id,
|
|
false,
|
|
}
|
|
}
|
|
|
|
func (tf *testFunnel) Close() error {
|
|
tf.closed = true
|
|
return nil
|
|
}
|
|
|
|
func (tf *testFunnel) Equal(other Funnel) bool {
|
|
return tf.id == other.(*testFunnel).id
|
|
}
|
|
|
|
func (tf *testFunnel) LastActive() time.Time {
|
|
return time.Now()
|
|
}
|
|
|
|
// UpdateLastActive is intentionally a no-op: the test funnel stores no
// activity timestamp.
func (*testFunnel) UpdateLastActive() {}
|