package pogs import ( "fmt" "reflect" "testing" "time" "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/stretchr/testify/assert" capnp "zombiezen.com/go/capnproto2" ) // Ensure the AuthOutcome sum is correct var _ AuthOutcome = &AuthSuccess{} var _ AuthOutcome = &AuthFail{} var _ AuthOutcome = &AuthUnknown{} // Unit tests for AuthenticateResponse.Outcome() func TestAuthenticateResponseOutcome(t *testing.T) { type fields struct { PermanentErr string RetryableErr string Jwt []byte HoursUntilRefresh uint8 } tests := []struct { name string fields fields want AuthOutcome }{ {"success", fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6}, AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6}, }, {"fail", fields{PermanentErr: "bad creds"}, AuthFail{err: fmt.Errorf("bad creds")}, }, {"error", fields{RetryableErr: "bad conn", HoursUntilRefresh: 6}, AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6}, }, {"nil (no fields are set)", fields{}, nil, }, {"nil (too few fields are set)", fields{HoursUntilRefresh: 6}, nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ar := AuthenticateResponse{ PermanentErr: tt.fields.PermanentErr, RetryableErr: tt.fields.RetryableErr, Jwt: tt.fields.Jwt, HoursUntilRefresh: tt.fields.HoursUntilRefresh, } got := ar.Outcome() if !reflect.DeepEqual(got, tt.want) { t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want) } if got != nil && !reflect.DeepEqual(got.Serialize(), ar) { t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize()) } }) } } func TestAuthSuccess(t *testing.T) { input := NewAuthSuccess([]byte("asdf"), 6) output, ok := input.Serialize().Outcome().(AuthSuccess) assert.True(t, ok) assert.Equal(t, input, output) } func TestAuthUnknown(t *testing.T) { input := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6) output, ok := input.Serialize().Outcome().(AuthUnknown) assert.True(t, ok) assert.Equal(t, input, output) } func TestAuthFail(t *testing.T) { input := NewAuthFail(fmt.Errorf("wrong creds")) output, ok := input.Serialize().Outcome().(AuthFail) assert.True(t, ok) assert.Equal(t, input, output) } func TestWhenToRefresh(t *testing.T) { expected := 4 * time.Hour actual := hoursToTime(4) if expected != actual { t.Fatalf("expected %v hours, got %v", expected, actual) } } // Test that serializing and deserializing AuthenticationResponse undo each other. func TestSerializeAuthenticationResponse(t *testing.T) { tests := []*AuthenticateResponse{ &AuthenticateResponse{ Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"), HoursUntilRefresh: 24, }, &AuthenticateResponse{ PermanentErr: "bad auth", }, &AuthenticateResponse{ RetryableErr: "bad connection", HoursUntilRefresh: 24, }, } for i, testCase := range tests { _, seg, err := capnp.NewMessage(capnp.SingleSegment(nil)) capnpEntity, err := tunnelrpc.NewAuthenticateResponse(seg) if !assert.NoError(t, err) { t.Fatal("Couldn't initialize a new message") } err = MarshalAuthenticateResponse(capnpEntity, testCase) if !assert.NoError(t, err, "testCase index %v failed to marshal", i) { continue } result, err := UnmarshalAuthenticateResponse(capnpEntity) if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) { continue } assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i) } }