130 lines
3.0 KiB
Go
130 lines
3.0 KiB
Go
package pogs
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
capnp "zombiezen.com/go/capnproto2"
|
|
"zombiezen.com/go/capnproto2/rpc"
|
|
|
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
)
|
|
|
|
func TestMarshalConnectionOptions(t *testing.T) {
|
|
clientID := uuid.New()
|
|
orig := ConnectionOptions{
|
|
Client: ClientInfo{
|
|
ClientID: clientID[:],
|
|
Features: []string{"a", "b"},
|
|
Version: "1.2.3",
|
|
Arch: "macos",
|
|
},
|
|
OriginLocalIP: []byte{10, 2, 3, 4},
|
|
ReplaceExisting: false,
|
|
CompressionQuality: 1,
|
|
}
|
|
|
|
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
|
require.NoError(t, err)
|
|
capnpOpts, err := tunnelrpc.NewConnectionOptions(seg)
|
|
require.NoError(t, err)
|
|
|
|
err = orig.MarshalCapnproto(capnpOpts)
|
|
assert.NoError(t, err)
|
|
|
|
var pogsOpts ConnectionOptions
|
|
err = pogsOpts.UnmarshalCapnproto(capnpOpts)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, orig, pogsOpts)
|
|
}
|
|
|
|
func TestConnectionRegistrationRPC(t *testing.T) {
|
|
p1, p2 := net.Pipe()
|
|
t1, t2 := rpc.StreamTransport(p1), rpc.StreamTransport(p2)
|
|
|
|
// Server-side
|
|
testImpl := testConnectionRegistrationServer{}
|
|
srv := TunnelServer_ServerToClient(&testImpl)
|
|
serverConn := rpc.NewConn(t1, rpc.MainInterface(srv.Client))
|
|
defer serverConn.Wait()
|
|
|
|
ctx := context.Background()
|
|
clientConn := rpc.NewConn(t2)
|
|
defer clientConn.Close()
|
|
client := TunnelServer_PogsClient{
|
|
Client: clientConn.Bootstrap(ctx),
|
|
Conn: clientConn,
|
|
}
|
|
defer client.Close()
|
|
|
|
clientID := uuid.New()
|
|
options := &ConnectionOptions{
|
|
Client: ClientInfo{
|
|
ClientID: clientID[:],
|
|
Features: []string{"foo"},
|
|
Version: "1.2.3",
|
|
Arch: "macos",
|
|
},
|
|
OriginLocalIP: net.IP{10, 20, 30, 40},
|
|
ReplaceExisting: true,
|
|
CompressionQuality: 0,
|
|
}
|
|
|
|
expectedDetails := ConnectionDetails{
|
|
UUID: uuid.New(),
|
|
Location: "TEST",
|
|
}
|
|
testImpl.details = &expectedDetails
|
|
testImpl.err = nil
|
|
|
|
// success
|
|
tunnelID := uuid.New()
|
|
details, err := client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedDetails, *details)
|
|
|
|
// regular error
|
|
testImpl.details = nil
|
|
testImpl.err = errors.New("internal")
|
|
|
|
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
|
assert.EqualError(t, err, "internal")
|
|
|
|
// retriable error
|
|
testImpl.details = nil
|
|
const delay = 27*time.Second
|
|
testImpl.err = RetryErrorAfter(errors.New("retryable"), delay)
|
|
|
|
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
|
assert.EqualError(t, err, "retryable")
|
|
|
|
re, ok := err.(*RetryableError)
|
|
assert.True(t, ok)
|
|
assert.Equal(t, delay, re.Delay)
|
|
}
|
|
|
|
type testConnectionRegistrationServer struct {
|
|
mockTunnelServerBase
|
|
|
|
details *ConnectionDetails
|
|
err error
|
|
}
|
|
|
|
func (t testConnectionRegistrationServer) Register(ctx context.Context, auth []byte, tunnelUUID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
|
if t.err != nil {
|
|
return nil, t.err
|
|
}
|
|
if t.details != nil {
|
|
return t.details, nil
|
|
}
|
|
|
|
panic("either details or err mush be set")
|
|
}
|