diff --git a/token/token.go b/token/token.go index d561dc38..30ab9366 100644 --- a/token/token.go +++ b/token/token.go @@ -53,7 +53,7 @@ type signalHandler struct { } type jwtPayload struct { - Aud []string `json:"aud"` + Aud []string `json:"-"` Email string `json:"email"` Exp int `json:"exp"` Iat int `json:"iat"` @@ -68,6 +68,34 @@ type transferServiceResponse struct { OrgToken string `json:"org_token"` } +func (p *jwtPayload) UnmarshalJSON(data []byte) error { + type Alias jwtPayload + if err := json.Unmarshal(data, (*Alias)(p)); err != nil { + return err + } + var audParser struct { + Aud any `json:"aud"` + } + if err := json.Unmarshal(data, &audParser); err != nil { + return err + } + switch aud := audParser.Aud.(type) { + case string: + p.Aud = []string{aud} + case []any: + for _, a := range aud { + s, ok := a.(string) + if !ok { + return errors.New("aud array contains non-string elements") + } + p.Aud = append(p.Aud, s) + } + default: + return errors.New("aud field is not a string or an array of strings") + } + return nil +} + func (p jwtPayload) isExpired() bool { return int(time.Now().Unix()) > p.Exp } @@ -182,7 +210,9 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog. if err = fileLockAppToken.Acquire(); err != nil { return "", errors.Wrap(err, "failed to acquire app token lock") } - defer fileLockAppToken.Release() + defer func() { + _ = fileLockAppToken.Release() + }() // check to see if another process has gotten a token while we waited for the lock if token, err := GetAppTokenIfExists(appInfo); token != "" && err == nil { @@ -202,7 +232,9 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog. if err = fileLockOrgToken.Acquire(); err != nil { return "", errors.Wrap(err, "failed to acquire org token lock") } - defer fileLockOrgToken.Release() + defer func() { + _ = fileLockOrgToken.Release() + }() // check if an org token has been created since the lock was acquired orgToken, err = GetOrgTokenIfExists(appInfo.AuthDomain) } @@ -218,7 +250,6 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog. } } return getTokensFromEdge(appURL, appInfo.AppAUD, appTokenPath, orgTokenPath, useHostOnly, log) - } // getTokensFromEdge will attempt to use the transfer service to retrieve an app and org token, save them to disk, @@ -250,7 +281,6 @@ func getTokensFromEdge(appURL *url.URL, appAUD, appTokenPath, orgTokenPath strin } return resp.AppToken, nil - } // GetAppInfo makes a request to the appURL and stops at the first redirect. The 302 location header will contain the @@ -320,7 +350,6 @@ func handleRedirects(req *http.Request, via []*http.Request, orgToken string) er } } } - } // stop after hitting authorized endpoint since it will contain the app token @@ -408,7 +437,6 @@ func GetAppTokenIfExists(appInfo *AppInfo) (string, error) { return "", err } return token.CompactSerialize() - } // GetTokenIfExists will return the token from local storage if it exists and not expired diff --git a/token/token_test.go b/token/token_test.go index 5c69352d..da92ed73 100644 --- a/token/token_test.go +++ b/token/token_test.go @@ -1,6 +1,7 @@ package token import ( + "encoding/json" "net/http" "net/url" "testing" @@ -11,7 +12,7 @@ func TestHandleRedirects_AttachOrgToken(t *testing.T) { via := []*http.Request{} orgToken := "orgTokenValue" - handleRedirects(req, via, orgToken) + _ = handleRedirects(req, via, orgToken) // Check if the orgToken cookie is attached cookies := req.Cookies() @@ -80,3 +81,55 @@ func TestHandleRedirects_StopAtAuthorizedEndpoint(t *testing.T) { t.Errorf("Expected ErrUseLastResponse, got %v", err) } } + +func TestJwtPayloadUnmarshal_AudAsString(t *testing.T) { + jwt := `{"aud":"7afbdaf987054f889b3bdd0d29ebfcd2"}` + var payload jwtPayload + if err := json.Unmarshal([]byte(jwt), &payload); err != nil { + t.Errorf("Expected no error, got %v", err) + } + if len(payload.Aud) != 1 || payload.Aud[0] != "7afbdaf987054f889b3bdd0d29ebfcd2" { + t.Errorf("Expected aud to be 7afbdaf987054f889b3bdd0d29ebfcd2, got %v", payload.Aud) + } +} + +func TestJwtPayloadUnmarshal_AudAsSlice(t *testing.T) { + jwt := `{"aud":["7afbdaf987054f889b3bdd0d29ebfcd2", "f835c0016f894768976c01e076844efe"]}` + var payload jwtPayload + if err := json.Unmarshal([]byte(jwt), &payload); err != nil { + t.Errorf("Expected no error, got %v", err) + } + if len(payload.Aud) != 2 || payload.Aud[0] != "7afbdaf987054f889b3bdd0d29ebfcd2" || payload.Aud[1] != "f835c0016f894768976c01e076844efe" { + t.Errorf("Expected aud to be [7afbdaf987054f889b3bdd0d29ebfcd2, f835c0016f894768976c01e076844efe], got %v", payload.Aud) + } +} + +func TestJwtPayloadUnmarshal_FailsWhenAudIsInt(t *testing.T) { + jwt := `{"aud":123}` + var payload jwtPayload + err := json.Unmarshal([]byte(jwt), &payload) + wantErr := "aud field is not a string or an array of strings" + if err.Error() != wantErr { + t.Errorf("Expected %v, got %v", wantErr, err) + } +} + +func TestJwtPayloadUnmarshal_FailsWhenAudIsArrayOfInts(t *testing.T) { + jwt := `{"aud": [999, 123] }` + var payload jwtPayload + err := json.Unmarshal([]byte(jwt), &payload) + wantErr := "aud array contains non-string elements" + if err.Error() != wantErr { + t.Errorf("Expected %v, got %v", wantErr, err) + } +} + +func TestJwtPayloadUnmarshal_FailsWhenAudIsOmitted(t *testing.T) { + jwt := `{}` + var payload jwtPayload + err := json.Unmarshal([]byte(jwt), &payload) + wantErr := "aud field is not a string or an array of strings" + if err.Error() != wantErr { + t.Errorf("Expected %v, got %v", wantErr, err) + } +}