diff --git a/tunnelrpc/pogs/auth_outcome.go b/tunnelrpc/pogs/auth_outcome.go index 56001f10..6626ed4f 100644 --- a/tunnelrpc/pogs/auth_outcome.go +++ b/tunnelrpc/pogs/auth_outcome.go @@ -21,18 +21,18 @@ func (ar AuthenticateResponse) Outcome() AuthOutcome { // If there was a network error, then cloudflared should retry later, // because origintunneld couldn't prove whether auth was correct or not. if ar.RetryableErr != "" { - return &AuthUnknown{Err: fmt.Errorf(ar.RetryableErr), HoursUntilRefresh: ar.HoursUntilRefresh} + return NewAuthUnknown(fmt.Errorf(ar.RetryableErr), ar.HoursUntilRefresh) } // If the user's authentication was unsuccessful, the server will return an error explaining why. // cloudflared should fatal with this error. if ar.PermanentErr != "" { - return &AuthFail{Err: fmt.Errorf(ar.PermanentErr)} + return NewAuthFail(fmt.Errorf(ar.PermanentErr)) } // If auth succeeded, return the token and refresh it when instructed. if ar.PermanentErr == "" && len(ar.Jwt) > 0 { - return &AuthSuccess{Jwt: ar.Jwt, HoursUntilRefresh: ar.HoursUntilRefresh} + return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh) } // Otherwise the state got messed up. @@ -49,59 +49,71 @@ type AuthOutcome interface { // AuthSuccess means the backend successfully authenticated this cloudflared. type AuthSuccess struct { - Jwt []byte - HoursUntilRefresh uint8 + jwt []byte + hoursUntilRefresh uint8 +} + +func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess { + return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh} } // RefreshAfter is how long cloudflared should wait before rerunning Authenticate. -func (ao *AuthSuccess) RefreshAfter() time.Duration { - return hoursToTime(ao.HoursUntilRefresh) +func (ao AuthSuccess) RefreshAfter() time.Duration { + return hoursToTime(ao.hoursUntilRefresh) } // Serialize into an AuthenticateResponse which can be sent via Capnp -func (ao *AuthSuccess) Serialize() AuthenticateResponse { +func (ao AuthSuccess) Serialize() AuthenticateResponse { return AuthenticateResponse{ - Jwt: ao.Jwt, - HoursUntilRefresh: ao.HoursUntilRefresh, + Jwt: ao.jwt, + HoursUntilRefresh: ao.hoursUntilRefresh, } } -func (ao *AuthSuccess) isAuthOutcome() {} +func (ao AuthSuccess) isAuthOutcome() {} // AuthFail means this cloudflared has the wrong auth and should exit. type AuthFail struct { - Err error + err error +} + +func NewAuthFail(err error) AuthFail { + return AuthFail{err: err} } // Serialize into an AuthenticateResponse which can be sent via Capnp -func (ao *AuthFail) Serialize() AuthenticateResponse { +func (ao AuthFail) Serialize() AuthenticateResponse { return AuthenticateResponse{ - PermanentErr: ao.Err.Error(), + PermanentErr: ao.err.Error(), } } -func (ao *AuthFail) isAuthOutcome() {} +func (ao AuthFail) isAuthOutcome() {} // AuthUnknown means the backend couldn't finish checking authentication. Try again later. type AuthUnknown struct { - Err error - HoursUntilRefresh uint8 + err error + hoursUntilRefresh uint8 +} + +func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown { + return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh} } // RefreshAfter is how long cloudflared should wait before rerunning Authenticate. -func (ao *AuthUnknown) RefreshAfter() time.Duration { - return hoursToTime(ao.HoursUntilRefresh) +func (ao AuthUnknown) RefreshAfter() time.Duration { + return hoursToTime(ao.hoursUntilRefresh) } // Serialize into an AuthenticateResponse which can be sent via Capnp -func (ao *AuthUnknown) Serialize() AuthenticateResponse { +func (ao AuthUnknown) Serialize() AuthenticateResponse { return AuthenticateResponse{ - RetryableErr: ao.Err.Error(), - HoursUntilRefresh: ao.HoursUntilRefresh, + RetryableErr: ao.err.Error(), + HoursUntilRefresh: ao.hoursUntilRefresh, } } -func (ao *AuthUnknown) isAuthOutcome() {} +func (ao AuthUnknown) isAuthOutcome() {} func hoursToTime(hours uint8) time.Duration { return time.Duration(hours) * time.Hour diff --git a/tunnelrpc/pogs/auth_test.go b/tunnelrpc/pogs/auth_test.go index 3a5590e1..f30f3977 100644 --- a/tunnelrpc/pogs/auth_test.go +++ b/tunnelrpc/pogs/auth_test.go @@ -31,15 +31,15 @@ func TestAuthenticateResponseOutcome(t *testing.T) { }{ {"success", fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6}, - &AuthSuccess{Jwt: []byte("asdf"), HoursUntilRefresh: 6}, + AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6}, }, {"fail", fields{PermanentErr: "bad creds"}, - &AuthFail{Err: fmt.Errorf("bad creds")}, + AuthFail{err: fmt.Errorf("bad creds")}, }, {"error", fields{RetryableErr: "bad conn", HoursUntilRefresh: 6}, - &AuthUnknown{Err: fmt.Errorf("bad conn"), HoursUntilRefresh: 6}, + AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6}, }, {"nil (no fields are set)", fields{}, @@ -69,6 +69,27 @@ func TestAuthenticateResponseOutcome(t *testing.T) { } } +func TestAuthSuccess(t *testing.T) { + input := NewAuthSuccess([]byte("asdf"), 6) + output, ok := input.Serialize().Outcome().(AuthSuccess) + assert.True(t, ok) + assert.Equal(t, input, output) +} + +func TestAuthUnknown(t *testing.T) { + input := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6) + output, ok := input.Serialize().Outcome().(AuthUnknown) + assert.True(t, ok) + assert.Equal(t, input, output) +} + +func TestAuthFail(t *testing.T) { + input := NewAuthFail(fmt.Errorf("wrong creds")) + output, ok := input.Serialize().Outcome().(AuthFail) + assert.True(t, ok) + assert.Equal(t, input, output) +} + func TestWhenToRefresh(t *testing.T) { expected := 4 * time.Hour actual := hoursToTime(4)