diff --git a/tunnelrpc/pogs/auth_outcome.go b/tunnelrpc/pogs/auth_outcome.go index 9f72711e..56001f10 100644 --- a/tunnelrpc/pogs/auth_outcome.go +++ b/tunnelrpc/pogs/auth_outcome.go @@ -43,6 +43,8 @@ func (ar AuthenticateResponse) Outcome() AuthOutcome { //go-sumtype:decl AuthOutcome type AuthOutcome interface { isAuthOutcome() + // Serialize into an AuthenticateResponse which can be sent via Capnp + Serialize() AuthenticateResponse } // AuthSuccess means the backend successfully authenticated this cloudflared. @@ -56,6 +58,14 @@ func (ao *AuthSuccess) RefreshAfter() time.Duration { return hoursToTime(ao.HoursUntilRefresh) } +// Serialize into an AuthenticateResponse which can be sent via Capnp +func (ao *AuthSuccess) Serialize() AuthenticateResponse { + return AuthenticateResponse{ + Jwt: ao.Jwt, + HoursUntilRefresh: ao.HoursUntilRefresh, + } +} + func (ao *AuthSuccess) isAuthOutcome() {} // AuthFail means this cloudflared has the wrong auth and should exit. @@ -63,6 +73,13 @@ type AuthFail struct { Err error } +// Serialize into an AuthenticateResponse which can be sent via Capnp +func (ao *AuthFail) Serialize() AuthenticateResponse { + return AuthenticateResponse{ + PermanentErr: ao.Err.Error(), + } +} + func (ao *AuthFail) isAuthOutcome() {} // AuthUnknown means the backend couldn't finish checking authentication. Try again later. @@ -76,6 +93,14 @@ func (ao *AuthUnknown) RefreshAfter() time.Duration { return hoursToTime(ao.HoursUntilRefresh) } +// Serialize into an AuthenticateResponse which can be sent via Capnp +func (ao *AuthUnknown) Serialize() AuthenticateResponse { + return AuthenticateResponse{ + RetryableErr: ao.Err.Error(), + HoursUntilRefresh: ao.HoursUntilRefresh, + } +} + func (ao *AuthUnknown) isAuthOutcome() {} func hoursToTime(hours uint8) time.Duration { diff --git a/tunnelrpc/pogs/auth_test.go b/tunnelrpc/pogs/auth_test.go index 50ae1ec1..3a5590e1 100644 --- a/tunnelrpc/pogs/auth_test.go +++ b/tunnelrpc/pogs/auth_test.go @@ -58,9 +58,13 @@ func TestAuthenticateResponseOutcome(t *testing.T) { Jwt: tt.fields.Jwt, HoursUntilRefresh: tt.fields.HoursUntilRefresh, } - if got := ar.Outcome(); !reflect.DeepEqual(got, tt.want) { + got := ar.Outcome() + if !reflect.DeepEqual(got, tt.want) { t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want) } + if got != nil && !reflect.DeepEqual(got.Serialize(), ar) { + t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize()) + } }) } }