cloudflared-mirror/tunnelrpc/pogs/connectionrpc_test.go

141 lines
3.3 KiB
Go
Raw Permalink Normal View History

package pogs
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/tunnelrpc"
)
const testAccountTag = "abc123"
func TestMarshalConnectionOptions(t *testing.T) {
clientID := uuid.New()
orig := ConnectionOptions{
Client: ClientInfo{
ClientID: clientID[:],
Features: []string{"a", "b"},
Version: "1.2.3",
Arch: "macos",
},
OriginLocalIP: []byte{10, 2, 3, 4},
ReplaceExisting: false,
CompressionQuality: 1,
}
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
require.NoError(t, err)
capnpOpts, err := tunnelrpc.NewConnectionOptions(seg)
require.NoError(t, err)
err = orig.MarshalCapnproto(capnpOpts)
assert.NoError(t, err)
var pogsOpts ConnectionOptions
err = pogsOpts.UnmarshalCapnproto(capnpOpts)
assert.NoError(t, err)
assert.Equal(t, orig, pogsOpts)
}
func TestConnectionRegistrationRPC(t *testing.T) {
p1, p2 := net.Pipe()
t1, t2 := rpc.StreamTransport(p1), rpc.StreamTransport(p2)
// Server-side
testImpl := testConnectionRegistrationServer{}
srv := TunnelServer_ServerToClient(&testImpl)
serverConn := rpc.NewConn(t1, rpc.MainInterface(srv.Client))
defer serverConn.Wait()
ctx := context.Background()
clientConn := rpc.NewConn(t2)
defer clientConn.Close()
client := TunnelServer_PogsClient{
Client: clientConn.Bootstrap(ctx),
Conn: clientConn,
}
defer client.Close()
clientID := uuid.New()
options := &ConnectionOptions{
Client: ClientInfo{
ClientID: clientID[:],
Features: []string{"foo"},
Version: "1.2.3",
Arch: "macos",
},
OriginLocalIP: net.IP{10, 20, 30, 40},
ReplaceExisting: true,
CompressionQuality: 0,
}
expectedDetails := ConnectionDetails{
UUID: uuid.New(),
Location: "TEST",
}
testImpl.details = &expectedDetails
testImpl.err = nil
auth := TunnelAuth{
AccountTag: testAccountTag,
TunnelSecret: []byte{1, 2, 3, 4},
}
// success
tunnelID := uuid.New()
details, err := client.RegisterConnection(ctx, auth, tunnelID, 2, options)
assert.NoError(t, err)
assert.Equal(t, expectedDetails, *details)
// regular error
testImpl.details = nil
testImpl.err = errors.New("internal")
_, err = client.RegisterConnection(ctx, auth, tunnelID, 2, options)
assert.EqualError(t, err, "internal")
// retriable error
testImpl.details = nil
const delay = 27 * time.Second
testImpl.err = RetryErrorAfter(errors.New("retryable"), delay)
_, err = client.RegisterConnection(ctx, auth, tunnelID, 2, options)
assert.EqualError(t, err, "retryable")
re, ok := err.(*RetryableError)
assert.True(t, ok)
assert.Equal(t, delay, re.Delay)
}
type testConnectionRegistrationServer struct {
mockTunnelServerBase
details *ConnectionDetails
err error
}
func (t *testConnectionRegistrationServer) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
if auth.AccountTag != testAccountTag {
panic("bad account tag: " + auth.AccountTag)
}
if t.err != nil {
return nil, t.err
}
if t.details != nil {
return t.details, nil
}
panic("either details or err mush be set")
}