2019-11-18 16:28:18 +00:00
|
|
|
package pogs
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"reflect"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
capnp "zombiezen.com/go/capnproto2"
|
2021-03-23 14:30:43 +00:00
|
|
|
|
|
|
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
2019-11-18 16:28:18 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
// Ensure the AuthOutcome sum is correct
|
|
|
|
var _ AuthOutcome = &AuthSuccess{}
|
|
|
|
var _ AuthOutcome = &AuthFail{}
|
|
|
|
var _ AuthOutcome = &AuthUnknown{}
|
|
|
|
|
|
|
|
// Unit tests for AuthenticateResponse.Outcome()
|
|
|
|
func TestAuthenticateResponseOutcome(t *testing.T) {
|
|
|
|
type fields struct {
|
|
|
|
PermanentErr string
|
|
|
|
RetryableErr string
|
|
|
|
Jwt []byte
|
|
|
|
HoursUntilRefresh uint8
|
|
|
|
}
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
fields fields
|
|
|
|
want AuthOutcome
|
|
|
|
}{
|
|
|
|
{"success",
|
|
|
|
fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
|
2019-11-20 18:12:08 +00:00
|
|
|
AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6},
|
2019-11-18 16:28:18 +00:00
|
|
|
},
|
|
|
|
{"fail",
|
|
|
|
fields{PermanentErr: "bad creds"},
|
2019-11-20 18:12:08 +00:00
|
|
|
AuthFail{err: fmt.Errorf("bad creds")},
|
2019-11-18 16:28:18 +00:00
|
|
|
},
|
|
|
|
{"error",
|
|
|
|
fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
|
2019-11-20 18:12:08 +00:00
|
|
|
AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6},
|
2019-11-18 16:28:18 +00:00
|
|
|
},
|
|
|
|
{"nil (no fields are set)",
|
|
|
|
fields{},
|
|
|
|
nil,
|
|
|
|
},
|
|
|
|
{"nil (too few fields are set)",
|
|
|
|
fields{HoursUntilRefresh: 6},
|
|
|
|
nil,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
ar := AuthenticateResponse{
|
|
|
|
PermanentErr: tt.fields.PermanentErr,
|
|
|
|
RetryableErr: tt.fields.RetryableErr,
|
|
|
|
Jwt: tt.fields.Jwt,
|
|
|
|
HoursUntilRefresh: tt.fields.HoursUntilRefresh,
|
|
|
|
}
|
2019-11-18 23:01:20 +00:00
|
|
|
got := ar.Outcome()
|
|
|
|
if !reflect.DeepEqual(got, tt.want) {
|
2019-11-18 16:28:18 +00:00
|
|
|
t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want)
|
|
|
|
}
|
2019-11-18 23:01:20 +00:00
|
|
|
if got != nil && !reflect.DeepEqual(got.Serialize(), ar) {
|
|
|
|
t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize())
|
|
|
|
}
|
2019-11-18 16:28:18 +00:00
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-11-20 18:12:08 +00:00
|
|
|
func TestAuthSuccess(t *testing.T) {
|
|
|
|
input := NewAuthSuccess([]byte("asdf"), 6)
|
|
|
|
output, ok := input.Serialize().Outcome().(AuthSuccess)
|
|
|
|
assert.True(t, ok)
|
|
|
|
assert.Equal(t, input, output)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAuthUnknown(t *testing.T) {
|
|
|
|
input := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6)
|
|
|
|
output, ok := input.Serialize().Outcome().(AuthUnknown)
|
|
|
|
assert.True(t, ok)
|
|
|
|
assert.Equal(t, input, output)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAuthFail(t *testing.T) {
|
|
|
|
input := NewAuthFail(fmt.Errorf("wrong creds"))
|
|
|
|
output, ok := input.Serialize().Outcome().(AuthFail)
|
|
|
|
assert.True(t, ok)
|
|
|
|
assert.Equal(t, input, output)
|
|
|
|
}
|
|
|
|
|
2019-11-18 16:28:18 +00:00
|
|
|
func TestWhenToRefresh(t *testing.T) {
|
|
|
|
expected := 4 * time.Hour
|
|
|
|
actual := hoursToTime(4)
|
|
|
|
if expected != actual {
|
|
|
|
t.Fatalf("expected %v hours, got %v", expected, actual)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test that serializing and deserializing AuthenticationResponse undo each other.
|
|
|
|
func TestSerializeAuthenticationResponse(t *testing.T) {
|
|
|
|
|
|
|
|
tests := []*AuthenticateResponse{
|
2020-11-25 06:55:13 +00:00
|
|
|
{
|
2019-11-18 16:28:18 +00:00
|
|
|
Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
|
|
|
|
HoursUntilRefresh: 24,
|
|
|
|
},
|
2020-11-25 06:55:13 +00:00
|
|
|
{
|
2019-11-18 16:28:18 +00:00
|
|
|
PermanentErr: "bad auth",
|
|
|
|
},
|
2020-11-25 06:55:13 +00:00
|
|
|
{
|
2019-11-18 16:28:18 +00:00
|
|
|
RetryableErr: "bad connection",
|
|
|
|
HoursUntilRefresh: 24,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for i, testCase := range tests {
|
|
|
|
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
|
|
|
capnpEntity, err := tunnelrpc.NewAuthenticateResponse(seg)
|
|
|
|
if !assert.NoError(t, err) {
|
|
|
|
t.Fatal("Couldn't initialize a new message")
|
|
|
|
}
|
|
|
|
err = MarshalAuthenticateResponse(capnpEntity, testCase)
|
|
|
|
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
result, err := UnmarshalAuthenticateResponse(capnpEntity)
|
|
|
|
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
|
|
|
}
|
|
|
|
}
|